397 lines
8.6 KiB
Go
397 lines
8.6 KiB
Go
|
|
package cast
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"errors"
|
|||
|
|
"fmt"
|
|||
|
|
"reflect"
|
|||
|
|
"strconv"
|
|||
|
|
"strings"
|
|||
|
|
"sync"
|
|||
|
|
"unicode"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
var (
|
|||
|
|
structFieldMapCache sync.Map
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
type decoder struct {
|
|||
|
|
data []byte
|
|||
|
|
pos int
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (d *decoder) skipWhitespace() {
|
|||
|
|
for d.pos < len(d.data) {
|
|||
|
|
char := d.data[d.pos]
|
|||
|
|
if char == ' ' || char == '\t' || char == '\r' || char == '\n' {
|
|||
|
|
d.pos++
|
|||
|
|
} else {
|
|||
|
|
break
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (d *decoder) decode(value any) error {
|
|||
|
|
reflectValue := reflect.ValueOf(value)
|
|||
|
|
if reflectValue.Kind() != reflect.Ptr || reflectValue.IsNil() {
|
|||
|
|
return errors.New("destination must be a non-nil pointer")
|
|||
|
|
}
|
|||
|
|
d.skipWhitespace()
|
|||
|
|
return d.decodeValue(reflectValue.Elem())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (d *decoder) decodeValue(reflectValue reflect.Value) error {
|
|||
|
|
d.skipWhitespace()
|
|||
|
|
if d.pos >= len(d.data) {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
char := d.data[d.pos]
|
|||
|
|
if char == 'n' { // null
|
|||
|
|
if string(d.data[d.pos:min(d.pos+4, len(d.data))]) == "null" {
|
|||
|
|
d.pos += 4
|
|||
|
|
reflectValue.Set(reflect.Zero(reflectValue.Type()))
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 自动初始化指针
|
|||
|
|
for reflectValue.Kind() == reflect.Ptr {
|
|||
|
|
if reflectValue.IsNil() {
|
|||
|
|
reflectValue.Set(reflect.New(reflectValue.Type().Elem()))
|
|||
|
|
}
|
|||
|
|
reflectValue = reflectValue.Elem()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
switch char {
|
|||
|
|
case '{':
|
|||
|
|
return d.decodeObject(reflectValue)
|
|||
|
|
case '[':
|
|||
|
|
return d.decodeArray(reflectValue)
|
|||
|
|
case '"':
|
|||
|
|
str, err := d.parseString()
|
|||
|
|
if err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
// 使用 cast 的基础转换能力
|
|||
|
|
switch reflectValue.Kind() {
|
|||
|
|
case reflect.String:
|
|||
|
|
reflectValue.SetString(str)
|
|||
|
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
|||
|
|
reflectValue.SetInt(Int64(str))
|
|||
|
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
|||
|
|
reflectValue.SetUint(Uint64(str))
|
|||
|
|
case reflect.Float32, reflect.Float64:
|
|||
|
|
reflectValue.SetFloat(Float64(str))
|
|||
|
|
case reflect.Bool:
|
|||
|
|
reflectValue.SetBool(Bool(str))
|
|||
|
|
default:
|
|||
|
|
// 尝试将字符串解析为具体对象(比如内部又是 JSON)
|
|||
|
|
if strings.HasPrefix(str, "{") || strings.HasPrefix(str, "[") {
|
|||
|
|
subDec := &decoder{data: []byte(str)}
|
|||
|
|
return subDec.decodeValue(reflectValue)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
default:
|
|||
|
|
// 数字或布尔值
|
|||
|
|
literal, err := d.parseLiteral()
|
|||
|
|
if err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
switch reflectValue.Kind() {
|
|||
|
|
case reflect.Bool:
|
|||
|
|
reflectValue.SetBool(Bool(literal))
|
|||
|
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
|||
|
|
reflectValue.SetInt(Int64(literal))
|
|||
|
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
|||
|
|
reflectValue.SetUint(Uint64(literal))
|
|||
|
|
case reflect.Float32, reflect.Float64:
|
|||
|
|
reflectValue.SetFloat(Float64(literal))
|
|||
|
|
case reflect.String:
|
|||
|
|
reflectValue.SetString(literal)
|
|||
|
|
case reflect.Interface:
|
|||
|
|
reflectValue.Set(reflect.ValueOf(literal))
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (d *decoder) decodeObject(reflectValue reflect.Value) error {
|
|||
|
|
if d.data[d.pos] != '{' {
|
|||
|
|
return fmt.Errorf("expected '{' at pos %d", d.pos)
|
|||
|
|
}
|
|||
|
|
d.pos++
|
|||
|
|
d.skipWhitespace()
|
|||
|
|
|
|||
|
|
if d.pos < len(d.data) && d.data[d.pos] == '}' {
|
|||
|
|
d.pos++
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
isMap := reflectValue.Kind() == reflect.Map
|
|||
|
|
isStruct := reflectValue.Kind() == reflect.Struct
|
|||
|
|
|
|||
|
|
if isMap && reflectValue.IsNil() {
|
|||
|
|
reflectValue.Set(reflect.MakeMap(reflectValue.Type()))
|
|||
|
|
}
|
|||
|
|
var fieldMap map[string]int
|
|||
|
|
if isStruct {
|
|||
|
|
fieldMap = getDecoderFieldMap(reflectValue.Type())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for {
|
|||
|
|
d.skipWhitespace()
|
|||
|
|
key, err := d.parseString()
|
|||
|
|
if err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
d.skipWhitespace()
|
|||
|
|
if d.pos >= len(d.data) || d.data[d.pos] != ':' {
|
|||
|
|
return fmt.Errorf("expected ':' after key at pos %d", d.pos)
|
|||
|
|
}
|
|||
|
|
d.pos++
|
|||
|
|
|
|||
|
|
if isStruct {
|
|||
|
|
// Frictionless 匹配
|
|||
|
|
fieldIndex, ok := matchField(key, fieldMap)
|
|||
|
|
if ok {
|
|||
|
|
if err := d.decodeValue(reflectValue.Field(fieldIndex)); err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
} else {
|
|||
|
|
if err := d.skipValue(); err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
} else if isMap {
|
|||
|
|
keyType := reflectValue.Type().Key()
|
|||
|
|
valueType := reflectValue.Type().Elem()
|
|||
|
|
keyValue := reflect.New(keyType).Elem()
|
|||
|
|
// Key 总是尝试转化
|
|||
|
|
keyValue.Set(reflect.ValueOf(key).Convert(keyType))
|
|||
|
|
|
|||
|
|
valValue := reflect.New(valueType).Elem()
|
|||
|
|
if err := d.decodeValue(valValue); err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
reflectValue.SetMapIndex(keyValue, valValue)
|
|||
|
|
} else {
|
|||
|
|
if err := d.skipValue(); err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
d.skipWhitespace()
|
|||
|
|
if d.pos >= len(d.data) {
|
|||
|
|
return errors.New("unexpected end of object")
|
|||
|
|
}
|
|||
|
|
if d.data[d.pos] == '}' {
|
|||
|
|
d.pos++
|
|||
|
|
break
|
|||
|
|
}
|
|||
|
|
if d.data[d.pos] != ',' {
|
|||
|
|
return fmt.Errorf("expected ',' or '}' at pos %d", d.pos)
|
|||
|
|
}
|
|||
|
|
d.pos++
|
|||
|
|
}
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (d *decoder) decodeArray(reflectValue reflect.Value) error {
|
|||
|
|
if d.data[d.pos] != '[' {
|
|||
|
|
return fmt.Errorf("expected '[' at pos %d", d.pos)
|
|||
|
|
}
|
|||
|
|
d.pos++
|
|||
|
|
d.skipWhitespace()
|
|||
|
|
|
|||
|
|
if d.pos < len(d.data) && d.data[d.pos] == ']' {
|
|||
|
|
d.pos++
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
isSlice := reflectValue.Kind() == reflect.Slice
|
|||
|
|
for index := 0; ; index++ {
|
|||
|
|
if isSlice {
|
|||
|
|
if index >= reflectValue.Len() {
|
|||
|
|
newCap := reflectValue.Cap() * 2
|
|||
|
|
if newCap < 4 {
|
|||
|
|
newCap = 4
|
|||
|
|
}
|
|||
|
|
newSlice := reflect.MakeSlice(reflectValue.Type(), index+1, newCap)
|
|||
|
|
reflect.Copy(newSlice, reflectValue)
|
|||
|
|
reflectValue.Set(newSlice)
|
|||
|
|
} else {
|
|||
|
|
reflectValue.SetLen(index + 1)
|
|||
|
|
}
|
|||
|
|
if err := d.decodeValue(reflectValue.Index(index)); err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
} else {
|
|||
|
|
if err := d.skipValue(); err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
d.skipWhitespace()
|
|||
|
|
if d.pos >= len(d.data) {
|
|||
|
|
return errors.New("unexpected end of array")
|
|||
|
|
}
|
|||
|
|
if d.data[d.pos] == ']' {
|
|||
|
|
d.pos++
|
|||
|
|
break
|
|||
|
|
}
|
|||
|
|
if d.data[d.pos] != ',' {
|
|||
|
|
return fmt.Errorf("expected ',' or ']' at pos %d", d.pos)
|
|||
|
|
}
|
|||
|
|
d.pos++
|
|||
|
|
}
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (d *decoder) parseString() (string, error) {
|
|||
|
|
d.skipWhitespace()
|
|||
|
|
if d.pos >= len(d.data) || d.data[d.pos] != '"' {
|
|||
|
|
return "", fmt.Errorf("expected '\"' at pos %d", d.pos)
|
|||
|
|
}
|
|||
|
|
d.pos++
|
|||
|
|
start := d.pos
|
|||
|
|
for d.pos < len(d.data) {
|
|||
|
|
if d.data[d.pos] == '\\' {
|
|||
|
|
d.pos += 2
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
if d.data[d.pos] == '"' {
|
|||
|
|
s := string(d.data[start:d.pos])
|
|||
|
|
d.pos++
|
|||
|
|
// 处理转义
|
|||
|
|
if strings.Contains(s, "\\") {
|
|||
|
|
s, _ = strconv.Unquote("\"" + s + "\"")
|
|||
|
|
}
|
|||
|
|
return s, nil
|
|||
|
|
}
|
|||
|
|
d.pos++
|
|||
|
|
}
|
|||
|
|
return "", errors.New("unterminated string")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (d *decoder) parseLiteral() (string, error) {
|
|||
|
|
start := d.pos
|
|||
|
|
for d.pos < len(d.data) {
|
|||
|
|
char := d.data[d.pos]
|
|||
|
|
if char == ' ' || char == '\t' || char == '\r' || char == '\n' || char == ',' || char == '}' || char == ']' || char == ':' {
|
|||
|
|
break
|
|||
|
|
}
|
|||
|
|
d.pos++
|
|||
|
|
}
|
|||
|
|
return string(d.data[start:d.pos]), nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (d *decoder) skipValue() error {
|
|||
|
|
d.skipWhitespace()
|
|||
|
|
if d.pos >= len(d.data) {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
char := d.data[d.pos]
|
|||
|
|
switch char {
|
|||
|
|
case '{':
|
|||
|
|
d.pos++
|
|||
|
|
for d.pos < len(d.data) {
|
|||
|
|
d.skipWhitespace()
|
|||
|
|
if d.data[d.pos] == '}' {
|
|||
|
|
d.pos++
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
if err := d.skipValue(); err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
d.skipWhitespace()
|
|||
|
|
if d.data[d.pos] == ':' {
|
|||
|
|
d.pos++
|
|||
|
|
if err := d.skipValue(); err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
d.skipWhitespace()
|
|||
|
|
if d.data[d.pos] == ',' {
|
|||
|
|
d.pos++
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
case '[':
|
|||
|
|
d.pos++
|
|||
|
|
for d.pos < len(d.data) {
|
|||
|
|
d.skipWhitespace()
|
|||
|
|
if d.data[d.pos] == ']' {
|
|||
|
|
d.pos++
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
if err := d.skipValue(); err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
d.skipWhitespace()
|
|||
|
|
if d.data[d.pos] == ',' {
|
|||
|
|
d.pos++
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
case '"':
|
|||
|
|
_, err := d.parseString()
|
|||
|
|
return err
|
|||
|
|
default:
|
|||
|
|
_, err := d.parseLiteral()
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Frictionless Logic
|
|||
|
|
|
|||
|
|
func getDecoderFieldMap(reflectType reflect.Type) map[string]int {
|
|||
|
|
if val, ok := structFieldMapCache.Load(reflectType); ok {
|
|||
|
|
return val.(map[string]int)
|
|||
|
|
}
|
|||
|
|
m := make(map[string]int)
|
|||
|
|
for index := 0; index < reflectType.NumField(); index++ {
|
|||
|
|
field := reflectType.Field(index)
|
|||
|
|
if !field.IsExported() {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
// 1. Tag
|
|||
|
|
tag := field.Tag.Get("json")
|
|||
|
|
if tag != "" && tag != "-" {
|
|||
|
|
m[strings.Split(tag, ",")[0]] = index
|
|||
|
|
}
|
|||
|
|
// 2. 原名
|
|||
|
|
m[field.Name] = index
|
|||
|
|
// 3. 归一化名
|
|||
|
|
m[normalizeKey(field.Name)] = index
|
|||
|
|
}
|
|||
|
|
structFieldMapCache.Store(reflectType, m)
|
|||
|
|
return m
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func matchField(key string, fieldMap map[string]int) (int, bool) {
|
|||
|
|
// 1. 精确匹配
|
|||
|
|
if index, ok := fieldMap[key]; ok {
|
|||
|
|
return index, true
|
|||
|
|
}
|
|||
|
|
// 2. 归一化匹配 (忽略大小写、下划线等)
|
|||
|
|
if index, ok := fieldMap[normalizeKey(key)]; ok {
|
|||
|
|
return index, true
|
|||
|
|
}
|
|||
|
|
return 0, false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func normalizeKey(str string) string {
|
|||
|
|
var builder strings.Builder
|
|||
|
|
builder.Grow(len(str))
|
|||
|
|
for _, char := range str {
|
|||
|
|
if unicode.IsLetter(char) || unicode.IsDigit(char) {
|
|||
|
|
builder.WriteRune(unicode.ToLower(char))
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return builder.String()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func fastUnmarshalJSONBytes(data []byte, value any) error {
|
|||
|
|
d := &decoder{data: data}
|
|||
|
|
return d.decode(value)
|
|||
|
|
}
|