cast/cast.go

802 lines
18 KiB
Go
Raw Normal View History

package cast
import (
	"bytes"
	"encoding/json"
	"fmt"
	"reflect"
	"slices"
	"sort"
	"strconv"
	"strings"
	"time"
	"unicode/utf8"
)
// If is a generic stand-in for the ternary operator: it returns trueVal when
// condition holds and falseVal otherwise. Both values are evaluated by the
// caller before If runs.
func If[T any](condition bool, trueVal, falseVal T) T {
	if !condition {
		return falseVal
	}
	return trueVal
}
// In reports whether val occurs in arr, comparing elements with ==.
func In[T comparable](arr []T, val T) bool {
	return slices.Index(arr, val) >= 0
}
// As drops the error from a (value, error) pair. It returns v when err is nil
// and T's zero value otherwise, so calls like As(strconv.Atoi(s)) can be used
// inline.
func As[T any](v T, err error) T {
	if err == nil {
		return v
	}
	var zero T
	return zero
}
// RealValue follows pointers and interfaces until it reaches a non-pointer,
// non-interface value. It stops early and returns the current value when it
// meets a nil pointer or nil interface. Callers must therefore still check
// IsValid/Kind on the result.
func RealValue(v reflect.Value) reflect.Value {
	for {
		switch v.Kind() {
		case reflect.Pointer, reflect.Interface:
			if v.IsNil() {
				return v
			}
			v = v.Elem()
		default:
			return v
		}
	}
}
// --- Core Cast Logic ---
// To converts v to T in "zero-friction" mode. It never returns an error and
// falls back to T's zero value whenever a conversion is not possible.
//
// Conversion order:
//  1. JSON text (a string or []byte whose first non-space byte is '{' or '[')
//     going into a struct, map or non-byte slice target (optionally behind
//     pointers) is decoded with FromJSON.
//  2. A struct, map, non-byte slice, pointer or nil input going into a string
//     or byte-slice target is encoded with ToJSON.
//  3. Non-byte slice and map targets are built element by element
//     (fillToSlice / fillToMap). Slices come back non-nil, even when empty.
//  4. Everything else goes through reflectCast (numbers, strings, bools,
//     interfaces, ...).
func To[T any](v any) T {
	var zero T
	targetType := reflect.TypeOf((*T)(nil)).Elem()
	kind := targetType.Kind()
	isByteSlice := kind == reflect.Slice && targetType.Elem().Kind() == reflect.Uint8

	// 1. JSON text -> complex value.
	if isJSONText(v) && isComplexType(targetType) {
		return FromJSON[T](v)
	}

	// 2. Complex value -> JSON text. The result is built directly rather than
	// through reflectCast, so byte-slice targets are handled here as well.
	if isComplexValue(v) && (kind == reflect.String || isByteSlice) {
		text := ToJSON(v)
		var out reflect.Value
		if kind == reflect.String {
			out = reflect.ValueOf(text)
		} else {
			out = reflect.ValueOf([]byte(text))
		}
		if !out.Type().ConvertibleTo(targetType) {
			return zero // e.g. []MyByte, which []byte cannot be converted to
		}
		return out.Convert(targetType).Interface().(T)
	}

	// 3. Containers are filled element by element.
	if kind == reflect.Slice && !isByteSlice {
		out := reflect.New(targetType).Elem()
		out.Set(reflect.MakeSlice(targetType, 0, 0))
		fillToSlice(out, v)
		return out.Interface().(T)
	}
	if kind == reflect.Map {
		out := reflect.MakeMap(targetType)
		fillToMap(out, v)
		return out.Interface().(T)
	}

	// 4. Scalars. The assertion is checked because, for interface targets, the
	// cast result may be a nil interface (To[any](nil)) or a value that does
	// not implement T. An unchecked .(T) would panic on either.
	res := reflectCast(v, targetType)
	if !res.IsValid() {
		return zero
	}
	out, ok := res.Interface().(T)
	if !ok {
		return zero
	}
	return out
}
func reflectCastE(value any, t reflect.Type) (reflect.Value, error) {
switch t.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
i, err := ToInt64E(value)
if err != nil {
return reflect.Value{}, err
}
return reflect.ValueOf(i).Convert(t), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
i, err := ToUint64E(value)
if err != nil {
return reflect.Value{}, err
}
return reflect.ValueOf(i).Convert(t), nil
case reflect.Float32, reflect.Float64:
f, err := ToFloat64E(value)
if err != nil {
return reflect.Value{}, err
}
return reflect.ValueOf(f).Convert(t), nil
case reflect.String:
return reflect.ValueOf(String(value)).Convert(t), nil
case reflect.Bool:
return reflect.ValueOf(Bool(value)).Convert(t), nil
case reflect.Interface:
if value == nil {
return reflect.Zero(t), nil
}
return reflect.ValueOf(value), nil
}
if t == reflect.TypeOf(time.Duration(0)) {
return reflect.ValueOf(Duration(value)), nil
}
return reflect.Zero(t), fmt.Errorf("unsupported target type: %v", t)
}
// ToInt64E converts v to int64.
//
// Built-in numbers and bools are converted directly (see toInt64). Anything
// else is rendered with String and parsed first as a base-10 integer, then as
// a float truncated toward zero. If neither parse succeeds it returns an
// error; this includes nil, which String renders as "".
func ToInt64E(v any) (int64, error) {
	if n, ok := toInt64(v); ok {
		return n, nil
	}
	text := String(v)
	if n, err := strconv.ParseInt(text, 10, 64); err == nil {
		return n, nil
	}
	f, err := strconv.ParseFloat(text, 64)
	if err != nil {
		return 0, fmt.Errorf("cannot cast %v to int64", v)
	}
	return int64(f), nil
}
func ToUint64E(v any) (uint64, error) {
switch realValue := v.(type) {
case int, int8, int16, int32, int64:
return uint64(Int64(realValue)), nil
case uint, uint8, uint16, uint32, uint64:
return Uint64(realValue), nil
case float32, float64:
return uint64(Float64(realValue)), nil
case bool:
return If(realValue, uint64(1), uint64(0)), nil
}
s := String(v)
i, err := strconv.ParseUint(s, 10, 64)
if err == nil {
return i, nil
}
if f, err := strconv.ParseFloat(s, 64); err == nil {
return uint64(f), nil
}
return 0, fmt.Errorf("cannot cast %v to uint64", v)
}
// ToFloat64E converts v to float64.
//
// Built-in numbers are converted directly and bools map to 1/0. Anything else
// is rendered with String and parsed with strconv.ParseFloat. If that parse
// fails it returns an error.
func ToFloat64E(v any) (float64, error) {
	switch n := v.(type) {
	case float64:
		return n, nil
	case float32:
		return float64(n), nil
	case int, int8, int16, int32, int64:
		return float64(Int64(n)), nil
	case uint, uint8, uint16, uint32, uint64:
		return float64(Uint64(n)), nil
	case bool:
		if n {
			return 1, nil
		}
		return 0, nil
	}
	f, err := strconv.ParseFloat(String(v), 64)
	if err != nil {
		return 0, fmt.Errorf("cannot cast %v to float64", v)
	}
	return f, nil
}
// isJSONText reports whether v is a string or []byte whose first
// non-whitespace character opens a JSON object ('{') or array ('[').
// It is a cheap prefix check, not a validation of the JSON.
func isJSONText(v any) bool {
	var first byte
	switch val := v.(type) {
	case string:
		s := strings.TrimSpace(val)
		if s == "" {
			return false
		}
		first = s[0]
	case []byte:
		b := bytes.TrimSpace(val)
		if len(b) == 0 {
			return false
		}
		first = b[0]
	default:
		return false
	}
	return first == '{' || first == '['
}
// isComplexType reports whether t, after stripping any number of pointer
// levels, is a struct, a map, or a slice whose elements are not bytes.
// These are the target types To decodes from JSON text.
func isComplexType(t reflect.Type) bool {
	for t.Kind() == reflect.Pointer {
		t = t.Elem()
	}
	switch t.Kind() {
	case reflect.Struct, reflect.Map:
		return true
	case reflect.Slice:
		return t.Elem().Kind() != reflect.Uint8
	default:
		return false
	}
}
// isComplexValue reports whether v should be rendered as JSON when it is cast
// to a string or []byte. That covers nil, nil pointers/interfaces (whatever
// RealValue cannot unwrap), structs, maps, and slices whose elements are not
// bytes.
func isComplexValue(v any) bool {
	if v == nil {
		return true
	}
	rv := RealValue(reflect.ValueOf(v))
	if !rv.IsValid() {
		return true
	}
	switch rv.Kind() {
	case reflect.Pointer, reflect.Interface, reflect.Struct, reflect.Map:
		return true
	case reflect.Slice:
		return rv.Type().Elem().Kind() != reflect.Uint8
	}
	return false
}
// parseInt parses s as a base-10 int64. If that fails it parses s as a float
// and truncates toward zero. Unparsable input yields 0.
func parseInt(s string) int64 {
	if n, err := strconv.ParseInt(s, 10, 64); err == nil {
		return n
	}
	f, err := strconv.ParseFloat(s, 64)
	if err != nil {
		return 0
	}
	return int64(f)
}
// parseUint parses s as a base-10 uint64. If that fails it parses s as a
// float and truncates toward zero. Unparsable input yields 0.
func parseUint(s string) uint64 {
	if n, err := strconv.ParseUint(s, 10, 64); err == nil {
		return n
	}
	f, err := strconv.ParseFloat(s, 64)
	if err != nil {
		return 0
	}
	return uint64(f)
}
// Int converts value to int by way of Int64. On 32-bit platforms the result
// is truncated to the platform's int width.
func Int(value any) int {
	n := Int64(value)
	return int(n)
}
// toInt64 converts the built-in numeric types and bool to int64 without any
// parsing or reflection. The boolean result is false for every other type.
// Plain Go conversion semantics apply: floats are truncated toward zero,
// uint64 values above math.MaxInt64 wrap around, and true becomes 1.
func toInt64(value any) (int64, bool) {
	switch n := value.(type) {
	case int64:
		return n, true
	case int:
		return int64(n), true
	case int32:
		return int64(n), true
	case int16:
		return int64(n), true
	case int8:
		return int64(n), true
	case uint64:
		return int64(n), true
	case uint:
		return int64(n), true
	case uint32:
		return int64(n), true
	case uint16:
		return int64(n), true
	case uint8:
		return int64(n), true
	case float64:
		return int64(n), true
	case float32:
		return int64(n), true
	case bool:
		if n {
			return 1, true
		}
		return 0, true
	default:
		return 0, false
	}
}
// Int64 converts value to int64. It returns 0 for nil and for anything it
// cannot interpret.
//
// Built-in numbers and bools are converted directly (see toInt64). Strings
// and []byte are parsed with parseInt. Non-nil pointers and interfaces are
// dereferenced and converted recursively.
func Int64(value any) int64 {
	if value == nil {
		return 0
	}
	if n, ok := toInt64(value); ok {
		return n
	}
	switch text := value.(type) {
	case string:
		return parseInt(text)
	case []byte:
		return parseInt(string(text))
	}
	rv := reflect.ValueOf(value)
	if k := rv.Kind(); k != reflect.Pointer && k != reflect.Interface {
		return 0
	}
	inner := RealValue(rv)
	if !inner.IsValid() || !inner.CanInterface() || inner.Kind() == reflect.Pointer {
		return 0
	}
	return Int64(inner.Interface())
}
// Uint converts value to uint by way of Uint64. On 32-bit platforms the
// result is truncated to the platform's uint width.
func Uint(value any) uint {
	n := Uint64(value)
	return uint(n)
}
func Uint64(value any) uint64 {
if value == nil {
return 0
}
switch realValue := value.(type) {
case int:
return uint64(realValue)
case int8:
return uint64(realValue)
case int16:
return uint64(realValue)
case int32:
return uint64(realValue)
case int64:
return uint64(realValue)
case uint:
return uint64(realValue)
case uint8:
return uint64(realValue)
case uint16:
return uint64(realValue)
case uint32:
return uint64(realValue)
case uint64:
return realValue
case float32:
return uint64(realValue)
case float64:
return uint64(realValue)
case bool:
return If(realValue, uint64(1), uint64(0))
case []byte:
return parseUint(string(realValue))
case string:
return parseUint(realValue)
}
rv := reflect.ValueOf(value)
if rv.Kind() == reflect.Pointer || rv.Kind() == reflect.Interface {
if rv = RealValue(rv); rv.IsValid() && rv.CanInterface() && rv.Kind() != reflect.Pointer {
return Uint64(rv.Interface())
}
}
return 0
}
// Float converts value to float32 by way of Float64, rounding away any
// precision float32 cannot hold.
func Float(value any) float32 {
	f := Float64(value)
	return float32(f)
}
// Float64 converts value to float64. It returns 0 for nil and for anything it
// cannot interpret.
//
// Numbers are converted directly and bools map to 1/0. Strings and []byte
// are parsed with strconv.ParseFloat. Non-nil pointers and interfaces are
// dereferenced and converted recursively.
func Float64(value any) float64 {
	switch n := value.(type) {
	case nil:
		return 0
	case float64:
		return n
	case float32:
		return float64(n)
	case int, int8, int16, int32, int64:
		return float64(Int64(n))
	case uint, uint8, uint16, uint32, uint64:
		return float64(Uint64(n))
	case bool:
		if n {
			return 1
		}
		return 0
	case string:
		if f, err := strconv.ParseFloat(n, 64); err == nil {
			return f
		}
		return 0
	case []byte:
		if f, err := strconv.ParseFloat(string(n), 64); err == nil {
			return f
		}
		return 0
	}
	rv := reflect.ValueOf(value)
	if k := rv.Kind(); k != reflect.Pointer && k != reflect.Interface {
		return 0
	}
	if inner := RealValue(rv); inner.IsValid() && inner.CanInterface() && inner.Kind() != reflect.Pointer {
		return Float64(inner.Interface())
	}
	return 0
}
// String renders value as text.
//
//   - nil becomes "".
//   - Integers are written in base 10.
//   - Floats use the shortest 'f' (no exponent) representation for their size.
//   - Bools become "true"/"false".
//   - string and []byte are returned as text.
//   - Non-nil pointers and interfaces are dereferenced recursively.
//   - Everything else falls back to fmt.Sprint.
func String(value any) string {
	switch v := value.(type) {
	case nil:
		return ""
	case string:
		return v
	case []byte:
		return string(v)
	case bool:
		return strconv.FormatBool(v)
	case float64:
		return strconv.FormatFloat(v, 'f', -1, 64)
	case float32:
		return strconv.FormatFloat(float64(v), 'f', -1, 32)
	case int, int8, int16, int32, int64:
		return strconv.FormatInt(Int64(v), 10)
	case uint, uint8, uint16, uint32, uint64:
		return strconv.FormatUint(Uint64(v), 10)
	}
	rv := reflect.ValueOf(value)
	if rv.Kind() == reflect.Pointer || rv.Kind() == reflect.Interface {
		if inner := RealValue(rv); inner.IsValid() && inner.CanInterface() && inner.Kind() != reflect.Pointer {
			return String(inner.Interface())
		}
	}
	return fmt.Sprint(value)
}
// Bool converts value to a bool. It returns false for nil and for anything it
// cannot interpret.
//
//   - bool values are returned as-is.
//   - Integers are true when non-zero.
//   - Floats are true when non-zero, so 0.5 is true; NaN also counts as
//     non-zero.
//   - Strings and []byte are true only for "1", "t" or "true" (compared after
//     strings.ToLower, without trimming).
//   - Non-nil pointers and interfaces are dereferenced recursively.
func Bool(value any) bool {
	switch v := value.(type) {
	case bool:
		return v
	case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
		// Converting any integer to uint64 keeps zero at zero and everything
		// else non-zero, so this is an exact non-zero test.
		return Uint64(v) != 0
	case float32, float64:
		// Floats must not go through Uint64: uint64(0.5) truncates to 0, and
		// converting a negative float to uint64 is implementation-dependent.
		return Float64(v) != 0
	case []byte:
		s := strings.ToLower(string(v))
		return s == "1" || s == "t" || s == "true"
	case string:
		s := strings.ToLower(v)
		return s == "1" || s == "t" || s == "true"
	}
	rv := reflect.ValueOf(value)
	if rv.Kind() == reflect.Pointer || rv.Kind() == reflect.Interface {
		if inner := RealValue(rv); inner.IsValid() && inner.CanInterface() && inner.Kind() != reflect.Pointer {
			return Bool(inner.Interface())
		}
	}
	return false
}
// Duration converts value to a time.Duration. It returns 0 for nil, for
// unparsable text, and for anything else it cannot interpret.
//
//   - Integers are taken as nanoseconds (converted through Int64).
//   - Floats are taken as seconds (1.5 becomes 1.5s).
//   - Strings and []byte are parsed with time.ParseDuration ("300ms",
//     "2h45m"). Unit-less numbers such as "100" fail to parse and give 0.
//   - Non-nil pointers and interfaces are dereferenced recursively.
func Duration(value any) time.Duration {
	switch v := value.(type) {
	case nil:
		return 0
	case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
		return time.Duration(Int64(v))
	case float32, float64:
		seconds := Float64(v)
		return time.Duration(seconds * float64(time.Second))
	case string, []byte:
		if d, err := time.ParseDuration(String(v)); err == nil {
			return d
		}
		return 0
	}
	rv := reflect.ValueOf(value)
	switch rv.Kind() {
	case reflect.Pointer, reflect.Interface:
		if inner := RealValue(rv); inner.IsValid() && inner.CanInterface() && inner.Kind() != reflect.Pointer {
			return Duration(inner.Interface())
		}
	}
	return 0
}
func ToJSONBytes(value any) []byte {
return As(fastToJSONBytes(value))
}
func ToJSONDesensitizeBytes(value any, keys []string) []byte {
return As(fastToJSONBytes(value, keys...))
}
// ToJSON is ToJSONBytes returning a string. A failed encoding gives "".
func ToJSON(value any) string {
	encoded := ToJSONBytes(value)
	return string(encoded)
}
// PrettyToJSON encodes value as JSON and indents it with json.Indent. If the
// indentation step fails, the compact encoding is returned unchanged.
func PrettyToJSON(value any) string {
	raw := ToJSONBytes(value)
	var out bytes.Buffer
	if err := json.Indent(&out, raw, "", " "); err != nil {
		return string(raw)
	}
	return out.String()
}
func toBytes(data any) []byte {
if data == nil {
return nil
}
2026-05-04 09:22:34 +08:00
switch v := data.(type) {
case []byte:
return v
case string:
return []byte(v)
}
return []byte(String(data))
}
func UnmarshalJSON(data any, value any) error {
b := toBytes(data)
if b == nil {
return fmt.Errorf("nil data")
}
2026-05-04 09:22:34 +08:00
return fastUnmarshalJSONBytes(b, value)
}
func FromJSON[T any](data any) T {
2026-05-04 09:22:34 +08:00
var v T
_ = UnmarshalJSON(data, &v)
return v
}
// --- Others (Keep logic but clean style) ---
// Split splits s on sep, trims surrounding whitespace from every piece, and
// drops pieces that are empty after trimming. The result is never nil.
func Split(s, sep string) []string {
	parts := strings.Split(s, sep)
	kept := make([]string, 0, len(parts))
	for _, part := range parts {
		part = strings.TrimSpace(part)
		if part == "" {
			continue
		}
		kept = append(kept, part)
	}
	return kept
}
// SplitArgs splits a command line into arguments at ASCII spaces. Tabs and
// other whitespace are not separators.
//
// Double quotes group text that contains spaces; the quotes themselves are
// dropped. A backslash makes the next rune literal anywhere, including inside
// quotes. Empty arguments (e.g. "") are dropped, and a trailing lone
// backslash is ignored. It returns nil when there are no arguments.
func SplitArgs(s string) []string {
	var (
		args    []string
		current strings.Builder
		quoted  bool
		escape  bool
	)
	flush := func() {
		if current.Len() > 0 {
			args = append(args, current.String())
			current.Reset()
		}
	}
	for _, r := range s {
		switch {
		case escape:
			current.WriteRune(r)
			escape = false
		case r == '\\':
			escape = true
		case r == '"':
			quoted = !quoted
		case r == ' ' && !quoted:
			flush()
		default:
			current.WriteRune(r)
		}
	}
	flush()
	return args
}
// JoinArgs joins arr with sep, quoting elements so that SplitArgs can recover
// them.
//
// An element containing a space, a double quote or a backslash is wrapped in
// double quotes, and every '\' and '"' inside it is backslash-escaped. All
// other elements are written unchanged.
func JoinArgs(arr []string, sep string) string {
	var builder strings.Builder
	for i, s := range arr {
		if i > 0 {
			builder.WriteString(sep)
		}
		// SplitArgs treats '\' as an escape character everywhere, so a bare
		// backslash must be escaped too; otherwise `C:\dir` comes back as
		// `C:dir`. Backslashes are escaped before quotes so the backslashes
		// added for quotes are not doubled.
		if strings.ContainsAny(s, " \"\\") {
			escaped := strings.ReplaceAll(s, `\`, `\\`)
			escaped = strings.ReplaceAll(escaped, `"`, `\"`)
			builder.WriteByte('"')
			builder.WriteString(escaped)
			builder.WriteByte('"')
		} else {
			builder.WriteString(s)
		}
	}
	return builder.String()
}
// UniqueAppend appends String(item) for each item in from to `to`, skipping
// values already present in `to` or added earlier in the same call.
// Duplicates that were already inside `to` are kept. The (possibly
// reallocated) slice is returned.
func UniqueAppend(to []string, from ...any) []string {
	seen := make(map[string]struct{}, len(to)+len(from))
	for _, existing := range to {
		seen[existing] = struct{}{}
	}
	for _, item := range from {
		s := String(item)
		if _, dup := seen[s]; dup {
			continue
		}
		seen[s] = struct{}{}
		to = append(to, s)
	}
	return to
}
// ArrayToBoolMap builds a set-like map in which every element of arr maps to
// true. The result is non-nil even for an empty arr.
func ArrayToBoolMap[T comparable](arr []T) map[T]bool {
	set := make(map[T]bool, len(arr))
	for i := range arr {
		set[arr[i]] = true
	}
	return set
}
// ToMap builds a new map[K]V from source using fillToMap. source may be a map,
// a struct, or a flat key/value slice; keys and values are converted to K and
// V along the way.
func ToMap[K comparable, V any](source any) map[K]V {
	out := map[K]V{}
	fillToMap(reflect.ValueOf(out), source)
	return out
}
// ToSlice builds a new []T from source using fillToSlice, converting every
// element to T. The result is nil when fillToSlice appends nothing.
func ToSlice[T any](source any) []T {
	var out []T
	target := reflect.ValueOf(&out).Elem()
	fillToSlice(target, source)
	return out
}
// FillMap 将 source 填充到目标 map 中 (兼容旧 API 逻辑)
func FillMap(target any, source any) {
2026-05-04 09:22:34 +08:00
rv := reflect.ValueOf(target)
for rv.Kind() == reflect.Pointer {
if rv.IsNil() {
if !rv.CanSet() {
return
}
elemType := rv.Type().Elem()
if elemType.Kind() == reflect.Map || elemType.Kind() == reflect.Pointer || elemType.Kind() == reflect.Struct {
rv.Set(reflect.New(elemType))
} else {
return
}
}
rv = rv.Elem()
}
if rv.Kind() != reflect.Map {
return
}
if rv.IsNil() {
if rv.CanSet() {
rv.Set(reflect.MakeMap(rv.Type()))
} else {
return
}
}
fillToMap(rv, source)
}
2026-05-04 09:22:34 +08:00
func fillToMap(rv reflect.Value, source any) {
2026-05-04 09:22:34 +08:00
kt := rv.Type().Key()
vt := rv.Type().Elem()
sv := RealValue(reflect.ValueOf(source))
switch sv.Kind() {
case reflect.Struct:
fillMapFromStruct(rv, sv, kt, vt)
case reflect.Slice, reflect.Array:
for i := 0; i < sv.Len(); i += 2 {
k := sv.Index(i).Interface()
var v any
if i+1 < sv.Len() {
v = sv.Index(i+1).Interface()
}
rv.SetMapIndex(reflectCast(k, kt), reflectCast(v, vt))
}
case reflect.Map:
iter := sv.MapRange()
for iter.Next() {
rv.SetMapIndex(reflectCast(iter.Key().Interface(), kt), reflectCast(iter.Value().Interface(), vt))
}
}
}
// fillMapFromStruct stores every exported field of the struct value sv into
// targetMap. The key is GetLowerName(field name) converted to kt, and the
// field value is converted to vt; both conversions use reflectCast.
//
// Exported embedded (anonymous) structs, either direct or behind pointers,
// are flattened into the same map. Nil embedded pointers contribute nothing.
// Embedded fields that are not structs, such as an embedded named string or
// an embedded interface, are stored like ordinary fields, keyed by
// GetLowerName of their type name.
func fillMapFromStruct(targetMap, sv reflect.Value, kt, vt reflect.Type) {
	st := sv.Type()
	for i := 0; i < sv.NumField(); i++ {
		field := st.Field(i)
		if !field.IsExported() {
			continue
		}
		fv := sv.Field(i)
		if field.Anonymous {
			// Recursing on an embedded *T or a non-struct embed would call
			// NumField on a non-struct and panic, so unwrap pointers first
			// and only recurse into real structs.
			inner := fv
			for inner.Kind() == reflect.Pointer && !inner.IsNil() {
				inner = inner.Elem()
			}
			switch inner.Kind() {
			case reflect.Struct:
				fillMapFromStruct(targetMap, inner, kt, vt)
				continue
			case reflect.Pointer:
				continue // nil embedded pointer: nothing to flatten
			}
			// Any other embedded kind falls through and is stored as a field.
		}
		targetMap.SetMapIndex(reflectCast(GetLowerName(field.Name), kt), reflectCast(fv.Interface(), vt))
	}
}
// FillSlice 将 source 填充到目标 slice 中 (兼容旧 API 逻辑)
func FillSlice(target any, source any) {
2026-05-04 09:22:34 +08:00
rv := reflect.ValueOf(target)
if rv.Kind() != reflect.Pointer || rv.Elem().Kind() != reflect.Slice {
return
}
fillToSlice(rv.Elem(), source)
}
2026-05-04 09:22:34 +08:00
func fillToSlice(sliceRv reflect.Value, source any) {
et := sliceRv.Type().Elem()
2026-05-04 09:22:34 +08:00
sv := RealValue(reflect.ValueOf(source))
switch sv.Kind() {
case reflect.Map:
keys := sv.MapKeys()
sort.Slice(keys, func(i, j int) bool {
return String(keys[i].Interface()) < String(keys[j].Interface())
})
for _, key := range keys {
sliceRv.Set(reflect.Append(sliceRv, reflectCast(key.Interface(), et)))
sliceRv.Set(reflect.Append(sliceRv, reflectCast(sv.MapIndex(key).Interface(), et)))
2026-05-04 09:22:34 +08:00
}
case reflect.Slice, reflect.Array:
for i := 0; i < sv.Len(); i++ {
sliceRv.Set(reflect.Append(sliceRv, reflectCast(sv.Index(i).Interface(), et)))
}
case reflect.Invalid:
// Nil source, do nothing
default:
sliceRv.Set(reflect.Append(sliceRv, reflectCast(source, et)))
}
}
func reflectCast(value any, t reflect.Type) reflect.Value {
return As(reflectCastE(value, t))
2026-05-04 09:22:34 +08:00
}
// GetLowerName lower-cases the first letter of s when s starts with an ASCII
// upper-case letter and also contains at least one ASCII lower-case letter.
// For example "UserName" becomes "userName". All-caps names such as "ID" or
// "URL", and anything not starting with 'A'..'Z', are returned unchanged.
func GetLowerName(s string) string {
	if s == "" || s[0] < 'A' || s[0] > 'Z' {
		return s
	}
	hasLower := strings.IndexFunc(s, func(r rune) bool { return r >= 'a' && r <= 'z' }) >= 0
	if !hasLower {
		return s
	}
	return string(s[0]+'a'-'A') + s[1:]
}
// GetUpperName returns s with its first character upper-cased, e.g.
// "userName" becomes "UserName".
//
// The first character is taken as a whole UTF-8 rune, so non-ASCII initials
// work ("éclair" becomes "Éclair"). Slicing off only the first byte would
// split a multi-byte rune and corrupt it.
func GetUpperName(s string) string {
	if s == "" {
		return ""
	}
	_, size := utf8.DecodeRuneInString(s)
	return strings.ToUpper(s[:size]) + s[size:]
}
// Ptr returns a pointer to a copy of v. It is handy for taking the address of
// a literal or function result, e.g. Ptr(3) or Ptr("x").
func Ptr[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}
// FixUpperCase rewrites JSON object keys in place, lower-casing an ASCII
// upper-case first letter. It is kept to support legacy key-conversion
// requirements.
//
// A key is changed only when its first byte is 'A'..'Z' and the key also
// contains at least one ASCII lower-case letter. So "UserName" becomes
// "userName", while all-caps keys such as "ID" are left alone.
//
// excludesKeys lists dotted key paths (e.g. "a.B") whose key must not be
// changed. An entry ending in "." excludes every path that starts with it;
// any other entry must match the path exactly. A path joins the most recent
// key seen at each enclosing depth, and array depths contribute an empty
// segment.
//
// String contents are skipped with backslash escapes honoured, so braces and
// commas inside strings do not disturb the structure tracking.
//
// NOTE(review): a string is treated as a key only when its opening quote
// directly follows '{' or a ',' inside an object. This appears to assume
// compact JSON with no whitespace between tokens; confirm against callers.
// The scan stops at len(data)-1, so the final byte is never examined.
func FixUpperCase(data []byte, excludesKeys []string) {
// Original logic kept as-is.
n := len(data)
// types[d]: true when depth d is an object, false when it is an array.
// keys[d]: last key seen at depth d (only recorded when excludesKeys is non-empty).
// tpos: current depth, -1 = outside any container.
types, keys, tpos := make([]bool, 0), make([]string, 0), -1
for i := 0; i < n-1; i++ {
// Grow the per-depth stacks so index tpos+1 is valid before a '{' or '['.
if tpos+1 >= len(types) {
types = append(types, false)
keys = append(keys, "")
}
switch data[i] {
case '{':
tpos++
types[tpos] = true
keys[tpos] = ""
case '}':
tpos--
case '[':
tpos++
types[tpos] = false
keys[tpos] = ""
case ']':
tpos--
case '"':
// The string is a key when it directly follows '{' or a ',' at an object depth.
keyPos := -1
if i > 0 && (data[i-1] == '{' || (data[i-1] == ',' && tpos >= 0 && types[tpos])) {
keyPos = i + 1
}
// Advance i to the closing quote, jumping over escaped characters.
i++
for ; i < n-1; i++ {
if data[i] == '\\' {
i++
continue
}
if data[i] == '"' {
if keyPos >= 0 && len(excludesKeys) > 0 {
keys[tpos] = string(data[keyPos:i])
}
break
}
}
if keyPos >= 0 && (data[keyPos] >= 'A' && data[keyPos] <= 'Z') {
excluded := false
if len(excludesKeys) > 0 {
// Dotted path of keys from the outermost depth down to this key.
checkStr := strings.Join(keys[0:tpos+1], ".")
for _, ek := range excludesKeys {
if strings.HasSuffix(ek, ".") {
excluded = strings.HasPrefix(checkStr, ek)
} else {
excluded = checkStr == ek
}
if excluded {
break
}
}
}
if !excluded {
// Only change keys that are not all-caps (e.g. leave "ID" alone).
hasLower := false
for _, b := range data[keyPos:i] {
if b >= 'a' && b <= 'z' {
hasLower = true
break
}
}
if hasLower {
// ASCII upper -> lower by setting bit 0x20.
data[keyPos] |= 0x20
}
}
}
}
}
}