cast/json_decoder.go
2026-05-14 15:32:10 +08:00

527 lines
12 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package cast
import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"sync"
"unicode"
"unicode/utf8"
)
var (
structFieldMapCache sync.Map
)
type decoderFieldDescriptor struct {
index int
isTime bool
timeFormat string
normalized string
}
type decoderStructDescriptor struct {
exactMatches map[string]decoderFieldDescriptor
fields []decoderFieldDescriptor
}
type decoder struct {
data []byte
pos int
}
func (d *decoder) skipWhitespace() {
for d.pos < len(d.data) {
char := d.data[d.pos]
if char == ' ' || char == '\t' || char == '\r' || char == '\n' {
d.pos++
} else {
break
}
}
}
func (d *decoder) decode(value any) error {
reflectValue := reflect.ValueOf(value)
if reflectValue.Kind() != reflect.Ptr || reflectValue.IsNil() {
return errors.New("destination must be a non-nil pointer")
}
d.skipWhitespace()
return d.decodeValue(reflectValue.Elem(), "")
}
func (d *decoder) decodeValue(reflectValue reflect.Value, timeFormat string) error {
d.skipWhitespace()
if d.pos >= len(d.data) {
return nil
}
char := d.data[d.pos]
if char == 'n' { // null
if string(d.data[d.pos:min(d.pos+4, len(d.data))]) == "null" {
d.pos += 4
reflectValue.Set(reflect.Zero(reflectValue.Type()))
return nil
}
}
// 自动初始化指针
for reflectValue.Kind() == reflect.Ptr {
if reflectValue.IsNil() {
reflectValue.Set(reflect.New(reflectValue.Type().Elem()))
}
reflectValue = reflectValue.Elem()
}
// 处理 time.Time
if reflectValue.Type() == timeType {
var v any
if char == '"' {
str, err := d.parseString()
if err != nil {
return err
}
v = str
} else if char == 'n' {
if string(d.data[d.pos:min(d.pos+4, len(d.data))]) == "null" {
d.pos += 4
reflectValue.Set(reflect.Zero(reflectValue.Type()))
return nil
}
} else {
literal, err := d.parseLiteral()
if err != nil {
return err
}
v = literal
}
reflectValue.Set(reflect.ValueOf(ToTime(v, timeFormat)))
return nil
}
switch char {
case '{':
if reflectValue.Kind() == reflect.Interface {
m := make(map[string]any)
if err := d.decodeObject(reflect.ValueOf(&m).Elem()); err != nil {
return err
}
reflectValue.Set(reflect.ValueOf(m))
return nil
}
return d.decodeObject(reflectValue)
case '[':
if reflectValue.Kind() == reflect.Interface {
var s []any
if err := d.decodeArray(reflect.ValueOf(&s).Elem()); err != nil {
return err
}
reflectValue.Set(reflect.ValueOf(s))
return nil
}
return d.decodeArray(reflectValue)
case '"':
str, err := d.parseString()
if err != nil {
return err
}
// 使用 cast 的基础转换能力
switch reflectValue.Kind() {
case reflect.String:
reflectValue.SetString(str)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
reflectValue.SetInt(Int64(str))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
reflectValue.SetUint(Uint64(str))
case reflect.Float32, reflect.Float64:
reflectValue.SetFloat(Float64(str))
case reflect.Bool:
reflectValue.SetBool(Bool(str))
case reflect.Interface:
reflectValue.Set(reflect.ValueOf(str))
default:
// 尝试将字符串解析为具体对象(比如内部又是 JSON
if strings.HasPrefix(str, "{") || strings.HasPrefix(str, "[") {
subDec := &decoder{data: []byte(str)}
return subDec.decodeValue(reflectValue, "")
}
}
default:
// 数字或布尔值
literal, err := d.parseLiteral()
if err != nil {
return err
}
switch reflectValue.Kind() {
case reflect.Bool:
reflectValue.SetBool(Bool(literal))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
reflectValue.SetInt(Int64(literal))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
reflectValue.SetUint(Uint64(literal))
case reflect.Float32, reflect.Float64:
reflectValue.SetFloat(Float64(literal))
case reflect.String:
reflectValue.SetString(literal)
case reflect.Interface:
// 优先作为数字处理以保留精度 (int64)
if literal == "true" {
reflectValue.Set(reflect.ValueOf(true))
} else if literal == "false" {
reflectValue.Set(reflect.ValueOf(false))
} else if literal == "null" {
reflectValue.Set(reflect.Zero(reflectValue.Type()))
} else if strings.Contains(literal, ".") || strings.Contains(literal, "e") || strings.Contains(literal, "E") {
reflectValue.Set(reflect.ValueOf(Float64(literal)))
} else if i, err := strconv.ParseInt(literal, 10, 64); err == nil {
reflectValue.Set(reflect.ValueOf(i))
} else {
reflectValue.Set(reflect.ValueOf(literal))
}
}
}
return nil
}
func (d *decoder) decodeObject(reflectValue reflect.Value) error {
if d.data[d.pos] != '{' {
return fmt.Errorf("expected '{' at pos %d", d.pos)
}
d.pos++
d.skipWhitespace()
if d.pos < len(d.data) && d.data[d.pos] == '}' {
d.pos++
return nil
}
isMap := reflectValue.Kind() == reflect.Map
isStruct := reflectValue.Kind() == reflect.Struct
if isMap && reflectValue.IsNil() {
reflectValue.Set(reflect.MakeMap(reflectValue.Type()))
}
var descriptor *decoderStructDescriptor
if isStruct {
descriptor = getDecoderFieldMap(reflectValue.Type())
}
for {
d.skipWhitespace()
key, err := d.parseString()
if err != nil {
return err
}
d.skipWhitespace()
if d.pos >= len(d.data) || d.data[d.pos] != ':' {
return fmt.Errorf("expected ':' after key at pos %d", d.pos)
}
d.pos++
if isStruct {
// Frictionless 匹配
fieldIndex, isTime, format, ok := matchField(key, descriptor)
if ok {
if isTime {
if err := d.decodeValue(reflectValue.Field(fieldIndex), format); err != nil {
return err
}
} else {
if err := d.decodeValue(reflectValue.Field(fieldIndex), ""); err != nil {
return err
}
}
} else {
if err := d.skipValue(); err != nil {
return err
}
}
} else if isMap {
keyType := reflectValue.Type().Key()
valueType := reflectValue.Type().Elem()
keyValue := reflect.New(keyType).Elem()
// Key 总是尝试转化
keyValue.Set(reflect.ValueOf(key).Convert(keyType))
valValue := reflect.New(valueType).Elem()
if err := d.decodeValue(valValue, ""); err != nil {
return err
}
reflectValue.SetMapIndex(keyValue, valValue)
} else {
if err := d.skipValue(); err != nil {
return err
}
}
d.skipWhitespace()
if d.pos >= len(d.data) {
return errors.New("unexpected end of object")
}
if d.data[d.pos] == '}' {
d.pos++
break
}
if d.data[d.pos] != ',' {
return fmt.Errorf("expected ',' or '}' at pos %d", d.pos)
}
d.pos++
}
return nil
}
func (d *decoder) decodeArray(reflectValue reflect.Value) error {
if d.data[d.pos] != '[' {
return fmt.Errorf("expected '[' at pos %d", d.pos)
}
d.pos++
d.skipWhitespace()
if d.pos < len(d.data) && d.data[d.pos] == ']' {
d.pos++
return nil
}
isSlice := reflectValue.Kind() == reflect.Slice
for index := 0; ; index++ {
if isSlice {
if index >= reflectValue.Cap() {
newCap := reflectValue.Cap() * 2
if newCap < 4 {
newCap = 4
}
newSlice := reflect.MakeSlice(reflectValue.Type(), index+1, newCap)
reflect.Copy(newSlice, reflectValue)
reflectValue.Set(newSlice)
} else {
if index >= reflectValue.Len() {
reflectValue.SetLen(index + 1)
}
}
if err := d.decodeValue(reflectValue.Index(index), ""); err != nil {
return err
}
} else {
if err := d.skipValue(); err != nil {
return err
}
}
d.skipWhitespace()
if d.pos >= len(d.data) {
return errors.New("unexpected end of array")
}
if d.data[d.pos] == ']' {
d.pos++
break
}
if d.data[d.pos] != ',' {
return fmt.Errorf("expected ',' or ']' at pos %d", d.pos)
}
d.pos++
}
return nil
}
func (d *decoder) parseString() (string, error) {
d.skipWhitespace()
if d.pos >= len(d.data) || d.data[d.pos] != '"' {
return "", fmt.Errorf("expected '\"' at pos %d", d.pos)
}
d.pos++
start := d.pos
for d.pos < len(d.data) {
if d.data[d.pos] == '\\' {
d.pos += 2
continue
}
if d.data[d.pos] == '"' {
s := string(d.data[start:d.pos])
d.pos++
// 处理转义
if strings.Contains(s, "\\") {
s, _ = strconv.Unquote("\"" + s + "\"")
}
return s, nil
}
d.pos++
}
return "", errors.New("unterminated string")
}
func (d *decoder) parseLiteral() (string, error) {
start := d.pos
for d.pos < len(d.data) {
char := d.data[d.pos]
if char == ' ' || char == '\t' || char == '\r' || char == '\n' || char == ',' || char == '}' || char == ']' || char == ':' {
break
}
d.pos++
}
return string(d.data[start:d.pos]), nil
}
func (d *decoder) skipValue() error {
d.skipWhitespace()
if d.pos >= len(d.data) {
return nil
}
char := d.data[d.pos]
switch char {
case '{':
d.pos++
for d.pos < len(d.data) {
d.skipWhitespace()
if d.data[d.pos] == '}' {
d.pos++
return nil
}
if err := d.skipValue(); err != nil {
return err
}
d.skipWhitespace()
if d.data[d.pos] == ':' {
d.pos++
if err := d.skipValue(); err != nil {
return err
}
}
d.skipWhitespace()
if d.data[d.pos] == ',' {
d.pos++
}
}
case '[':
d.pos++
for d.pos < len(d.data) {
d.skipWhitespace()
if d.data[d.pos] == ']' {
d.pos++
return nil
}
if err := d.skipValue(); err != nil {
return err
}
d.skipWhitespace()
if d.data[d.pos] == ',' {
d.pos++
}
}
case '"':
_, err := d.parseString()
return err
default:
_, err := d.parseLiteral()
return err
}
return nil
}
// Frictionless Logic
func getDecoderFieldMap(reflectType reflect.Type) *decoderStructDescriptor {
if val, ok := structFieldMapCache.Load(reflectType); ok {
return val.(*decoderStructDescriptor)
}
descriptor := &decoderStructDescriptor{
exactMatches: make(map[string]decoderFieldDescriptor),
}
for index := 0; index < reflectType.NumField(); index++ {
field := reflectType.Field(index)
if !field.IsExported() {
continue
}
fieldDesc := decoderFieldDescriptor{
index: index,
}
if field.Type == timeType || (field.Type.Kind() == reflect.Pointer && field.Type.Elem() == timeType) {
fieldDesc.isTime = true
fieldDesc.timeFormat = "2006-01-02 15:04:05.000"
}
// 1. Tag
tag := field.Tag.Get("json")
if tag == "-" {
continue
}
if tag != "" {
parts := strings.Split(tag, ",")
tagName := parts[0]
for _, part := range parts {
if strings.HasPrefix(part, "format=") {
fieldDesc.timeFormat = strings.TrimPrefix(part, "format=")
}
}
if tagName != "" {
descriptor.exactMatches[tagName] = fieldDesc
}
}
// 2. 原名
descriptor.exactMatches[field.Name] = fieldDesc
// 3. 归一化名
fieldDesc.normalized = normalizeKey(field.Name)
descriptor.fields = append(descriptor.fields, fieldDesc)
}
structFieldMapCache.Store(reflectType, descriptor)
return descriptor
}
func matchField(key string, descriptor *decoderStructDescriptor) (int, bool, string, bool) {
// 1. 精确匹配
if f, ok := descriptor.exactMatches[key]; ok {
return f.index, f.isTime, f.timeFormat, true
}
// 2. 归一化匹配 (忽略大小写、下划线等) - 零分配比对
for _, f := range descriptor.fields {
if normalizeEqual(key, f.normalized) {
return f.index, f.isTime, f.timeFormat, true
}
}
return 0, false, "", false
}
func normalizeEqual(raw string, normalized string) bool {
// normalized 已经是小写且只包含字母数字的字符串
// raw 是原始输入的 Key
j := 0
for _, r := range raw {
if !unicode.IsLetter(r) && !unicode.IsDigit(r) {
continue
}
if j >= len(normalized) {
return false
}
// 比较 normalized 的下一个字符
// 由于 normalized 是由 normalizeKey 生成的,我们可以安全地假设它只包含字母数字且已转小写
// 为了完全正确且无分配地获取 normalized 的下一个 rune
nr, size := utf8.DecodeRuneInString(normalized[j:])
if unicode.ToLower(r) != nr {
return false
}
j += size
}
return j == len(normalized)
}
func normalizeKey(str string) string {
var builder strings.Builder
builder.Grow(len(str))
for _, char := range str {
if unicode.IsLetter(char) || unicode.IsDigit(char) {
builder.WriteRune(unicode.ToLower(char))
}
}
return builder.String()
}
func fastUnmarshalJSONBytes(data []byte, value any) error {
d := &decoder{data: data}
return d.decode(value)
}