cast/json_decoder.go

496 lines
11 KiB
Go
Raw Permalink Normal View History

package cast
import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"sync"
"time"
"unicode"
"unicode/utf8"
)
var (
structFieldMapCache sync.Map
)
type decoderFieldDescriptor struct {
index int
isTime bool
timeFormat string
normalized string
}
type decoderStructDescriptor struct {
exactMatches map[string]int
fields []decoderFieldDescriptor
}
type decoder struct {
data []byte
pos int
}
func (d *decoder) skipWhitespace() {
for d.pos < len(d.data) {
char := d.data[d.pos]
if char == ' ' || char == '\t' || char == '\r' || char == '\n' {
d.pos++
} else {
break
}
}
}
func (d *decoder) decode(value any) error {
reflectValue := reflect.ValueOf(value)
if reflectValue.Kind() != reflect.Ptr || reflectValue.IsNil() {
return errors.New("destination must be a non-nil pointer")
}
d.skipWhitespace()
return d.decodeValue(reflectValue.Elem(), "")
}
func (d *decoder) decodeValue(reflectValue reflect.Value, timeFormat string) error {
d.skipWhitespace()
if d.pos >= len(d.data) {
return nil
}
char := d.data[d.pos]
if char == 'n' { // null
if string(d.data[d.pos:min(d.pos+4, len(d.data))]) == "null" {
d.pos += 4
reflectValue.Set(reflect.Zero(reflectValue.Type()))
return nil
}
}
// 自动初始化指针
for reflectValue.Kind() == reflect.Ptr {
if reflectValue.IsNil() {
reflectValue.Set(reflect.New(reflectValue.Type().Elem()))
}
reflectValue = reflectValue.Elem()
}
// 处理 time.Time
if reflectValue.Type() == timeType {
if char == '"' {
str, err := d.parseString()
if err != nil {
return err
}
if timeFormat == "" {
timeFormat = "2006-01-02 15:04:05.000"
}
t, err := time.ParseInLocation(timeFormat, str, time.Local)
if err != nil {
// 尝试其他常见格式
if t, err = time.Parse(time.RFC3339, str); err != nil {
t = time.Time{}
}
}
reflectValue.Set(reflect.ValueOf(t))
return nil
}
}
switch char {
case '{':
return d.decodeObject(reflectValue)
case '[':
return d.decodeArray(reflectValue)
case '"':
str, err := d.parseString()
if err != nil {
return err
}
// 使用 cast 的基础转换能力
switch reflectValue.Kind() {
case reflect.String:
reflectValue.SetString(str)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
reflectValue.SetInt(Int64(str))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
reflectValue.SetUint(Uint64(str))
case reflect.Float32, reflect.Float64:
reflectValue.SetFloat(Float64(str))
case reflect.Bool:
reflectValue.SetBool(Bool(str))
default:
// 尝试将字符串解析为具体对象(比如内部又是 JSON
if strings.HasPrefix(str, "{") || strings.HasPrefix(str, "[") {
subDec := &decoder{data: []byte(str)}
return subDec.decodeValue(reflectValue, "")
}
}
default:
// 数字或布尔值
literal, err := d.parseLiteral()
if err != nil {
return err
}
switch reflectValue.Kind() {
case reflect.Bool:
reflectValue.SetBool(Bool(literal))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
reflectValue.SetInt(Int64(literal))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
reflectValue.SetUint(Uint64(literal))
case reflect.Float32, reflect.Float64:
reflectValue.SetFloat(Float64(literal))
case reflect.String:
reflectValue.SetString(literal)
case reflect.Interface:
reflectValue.Set(reflect.ValueOf(literal))
}
}
return nil
}
func (d *decoder) decodeObject(reflectValue reflect.Value) error {
if d.data[d.pos] != '{' {
return fmt.Errorf("expected '{' at pos %d", d.pos)
}
d.pos++
d.skipWhitespace()
if d.pos < len(d.data) && d.data[d.pos] == '}' {
d.pos++
return nil
}
isMap := reflectValue.Kind() == reflect.Map
isStruct := reflectValue.Kind() == reflect.Struct
if isMap && reflectValue.IsNil() {
reflectValue.Set(reflect.MakeMap(reflectValue.Type()))
}
var descriptor *decoderStructDescriptor
if isStruct {
descriptor = getDecoderFieldMap(reflectValue.Type())
}
for {
d.skipWhitespace()
key, err := d.parseString()
if err != nil {
return err
}
d.skipWhitespace()
if d.pos >= len(d.data) || d.data[d.pos] != ':' {
return fmt.Errorf("expected ':' after key at pos %d", d.pos)
}
d.pos++
if isStruct {
// Frictionless 匹配
fieldIndex, isTime, format, ok := matchField(key, descriptor)
if ok {
if isTime {
if err := d.decodeValue(reflectValue.Field(fieldIndex), format); err != nil {
return err
}
} else {
if err := d.decodeValue(reflectValue.Field(fieldIndex), ""); err != nil {
return err
}
}
} else {
if err := d.skipValue(); err != nil {
return err
}
}
} else if isMap {
keyType := reflectValue.Type().Key()
valueType := reflectValue.Type().Elem()
keyValue := reflect.New(keyType).Elem()
// Key 总是尝试转化
keyValue.Set(reflect.ValueOf(key).Convert(keyType))
valValue := reflect.New(valueType).Elem()
if err := d.decodeValue(valValue, ""); err != nil {
return err
}
reflectValue.SetMapIndex(keyValue, valValue)
} else {
if err := d.skipValue(); err != nil {
return err
}
}
d.skipWhitespace()
if d.pos >= len(d.data) {
return errors.New("unexpected end of object")
}
if d.data[d.pos] == '}' {
d.pos++
break
}
if d.data[d.pos] != ',' {
return fmt.Errorf("expected ',' or '}' at pos %d", d.pos)
}
d.pos++
}
return nil
}
func (d *decoder) decodeArray(reflectValue reflect.Value) error {
if d.data[d.pos] != '[' {
return fmt.Errorf("expected '[' at pos %d", d.pos)
}
d.pos++
d.skipWhitespace()
if d.pos < len(d.data) && d.data[d.pos] == ']' {
d.pos++
return nil
}
isSlice := reflectValue.Kind() == reflect.Slice
for index := 0; ; index++ {
if isSlice {
if index >= reflectValue.Cap() {
newCap := reflectValue.Cap() * 2
if newCap < 4 {
newCap = 4
}
newSlice := reflect.MakeSlice(reflectValue.Type(), index+1, newCap)
reflect.Copy(newSlice, reflectValue)
reflectValue.Set(newSlice)
} else {
if index >= reflectValue.Len() {
reflectValue.SetLen(index + 1)
}
}
if err := d.decodeValue(reflectValue.Index(index), ""); err != nil {
return err
}
} else {
if err := d.skipValue(); err != nil {
return err
}
}
d.skipWhitespace()
if d.pos >= len(d.data) {
return errors.New("unexpected end of array")
}
if d.data[d.pos] == ']' {
d.pos++
break
}
if d.data[d.pos] != ',' {
return fmt.Errorf("expected ',' or ']' at pos %d", d.pos)
}
d.pos++
}
return nil
}
func (d *decoder) parseString() (string, error) {
d.skipWhitespace()
if d.pos >= len(d.data) || d.data[d.pos] != '"' {
return "", fmt.Errorf("expected '\"' at pos %d", d.pos)
}
d.pos++
start := d.pos
for d.pos < len(d.data) {
if d.data[d.pos] == '\\' {
d.pos += 2
continue
}
if d.data[d.pos] == '"' {
s := string(d.data[start:d.pos])
d.pos++
// 处理转义
if strings.Contains(s, "\\") {
s, _ = strconv.Unquote("\"" + s + "\"")
}
return s, nil
}
d.pos++
}
return "", errors.New("unterminated string")
}
func (d *decoder) parseLiteral() (string, error) {
start := d.pos
for d.pos < len(d.data) {
char := d.data[d.pos]
if char == ' ' || char == '\t' || char == '\r' || char == '\n' || char == ',' || char == '}' || char == ']' || char == ':' {
break
}
d.pos++
}
return string(d.data[start:d.pos]), nil
}
func (d *decoder) skipValue() error {
d.skipWhitespace()
if d.pos >= len(d.data) {
return nil
}
char := d.data[d.pos]
switch char {
case '{':
d.pos++
for d.pos < len(d.data) {
d.skipWhitespace()
if d.data[d.pos] == '}' {
d.pos++
return nil
}
if err := d.skipValue(); err != nil {
return err
}
d.skipWhitespace()
if d.data[d.pos] == ':' {
d.pos++
if err := d.skipValue(); err != nil {
return err
}
}
d.skipWhitespace()
if d.data[d.pos] == ',' {
d.pos++
}
}
case '[':
d.pos++
for d.pos < len(d.data) {
d.skipWhitespace()
if d.data[d.pos] == ']' {
d.pos++
return nil
}
if err := d.skipValue(); err != nil {
return err
}
d.skipWhitespace()
if d.data[d.pos] == ',' {
d.pos++
}
}
case '"':
_, err := d.parseString()
return err
default:
_, err := d.parseLiteral()
return err
}
return nil
}
// Frictionless Logic
func getDecoderFieldMap(reflectType reflect.Type) *decoderStructDescriptor {
if val, ok := structFieldMapCache.Load(reflectType); ok {
return val.(*decoderStructDescriptor)
}
descriptor := &decoderStructDescriptor{
exactMatches: make(map[string]int),
}
for index := 0; index < reflectType.NumField(); index++ {
field := reflectType.Field(index)
if !field.IsExported() {
continue
}
fieldDesc := decoderFieldDescriptor{
index: index,
}
if field.Type == timeType || (field.Type.Kind() == reflect.Pointer && field.Type.Elem() == timeType) {
fieldDesc.isTime = true
fieldDesc.timeFormat = "2006-01-02 15:04:05.000"
}
// 1. Tag
tag := field.Tag.Get("json")
if tag != "" && tag != "-" {
parts := strings.Split(tag, ",")
tagName := parts[0]
if tagName != "" {
descriptor.exactMatches[tagName] = index
}
for _, part := range parts {
if strings.HasPrefix(part, "format=") {
fieldDesc.timeFormat = strings.TrimPrefix(part, "format=")
}
}
}
// 2. 原名
descriptor.exactMatches[field.Name] = index
// 3. 归一化名
fieldDesc.normalized = normalizeKey(field.Name)
descriptor.fields = append(descriptor.fields, fieldDesc)
}
structFieldMapCache.Store(reflectType, descriptor)
return descriptor
}
func matchField(key string, descriptor *decoderStructDescriptor) (int, bool, string, bool) {
// 1. 精确匹配
if index, ok := descriptor.exactMatches[key]; ok {
// 找到字段后还需要获取其 timeFormat 信息
for _, f := range descriptor.fields {
if f.index == index {
return index, f.isTime, f.timeFormat, true
}
}
return index, false, "", true
}
// 2. 归一化匹配 (忽略大小写、下划线等) - 零分配比对
for _, f := range descriptor.fields {
if normalizeEqual(key, f.normalized) {
return f.index, f.isTime, f.timeFormat, true
}
}
return 0, false, "", false
}
func normalizeEqual(raw string, normalized string) bool {
// normalized 已经是小写且只包含字母数字的字符串
// raw 是原始输入的 Key
j := 0
for _, r := range raw {
if !unicode.IsLetter(r) && !unicode.IsDigit(r) {
continue
}
if j >= len(normalized) {
return false
}
// 比较 normalized 的下一个字符
// 由于 normalized 是由 normalizeKey 生成的,我们可以安全地假设它只包含字母数字且已转小写
// 为了完全正确且无分配地获取 normalized 的下一个 rune
nr, size := utf8.DecodeRuneInString(normalized[j:])
if unicode.ToLower(r) != nr {
return false
}
j += size
}
return j == len(normalized)
}
func normalizeKey(str string) string {
var builder strings.Builder
builder.Grow(len(str))
for _, char := range str {
if unicode.IsLetter(char) || unicode.IsDigit(char) {
builder.WriteRune(unicode.ToLower(char))
}
}
return builder.String()
}
func fastUnmarshalJSONBytes(data []byte, value any) error {
d := &decoder{data: data}
return d.decode(value)
}