cast/json_decoder.go

397 lines
8.6 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package cast
import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"sync"
"unicode"
)
var (
structFieldMapCache sync.Map
)
type decoder struct {
data []byte
pos int
}
func (d *decoder) skipWhitespace() {
for d.pos < len(d.data) {
char := d.data[d.pos]
if char == ' ' || char == '\t' || char == '\r' || char == '\n' {
d.pos++
} else {
break
}
}
}
func (d *decoder) decode(value any) error {
reflectValue := reflect.ValueOf(value)
if reflectValue.Kind() != reflect.Ptr || reflectValue.IsNil() {
return errors.New("destination must be a non-nil pointer")
}
d.skipWhitespace()
return d.decodeValue(reflectValue.Elem())
}
func (d *decoder) decodeValue(reflectValue reflect.Value) error {
d.skipWhitespace()
if d.pos >= len(d.data) {
return nil
}
char := d.data[d.pos]
if char == 'n' { // null
if string(d.data[d.pos:min(d.pos+4, len(d.data))]) == "null" {
d.pos += 4
reflectValue.Set(reflect.Zero(reflectValue.Type()))
return nil
}
}
// 自动初始化指针
for reflectValue.Kind() == reflect.Ptr {
if reflectValue.IsNil() {
reflectValue.Set(reflect.New(reflectValue.Type().Elem()))
}
reflectValue = reflectValue.Elem()
}
switch char {
case '{':
return d.decodeObject(reflectValue)
case '[':
return d.decodeArray(reflectValue)
case '"':
str, err := d.parseString()
if err != nil {
return err
}
// 使用 cast 的基础转换能力
switch reflectValue.Kind() {
case reflect.String:
reflectValue.SetString(str)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
reflectValue.SetInt(Int64(str))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
reflectValue.SetUint(Uint64(str))
case reflect.Float32, reflect.Float64:
reflectValue.SetFloat(Float64(str))
case reflect.Bool:
reflectValue.SetBool(Bool(str))
default:
// 尝试将字符串解析为具体对象(比如内部又是 JSON
if strings.HasPrefix(str, "{") || strings.HasPrefix(str, "[") {
subDec := &decoder{data: []byte(str)}
return subDec.decodeValue(reflectValue)
}
}
default:
// 数字或布尔值
literal, err := d.parseLiteral()
if err != nil {
return err
}
switch reflectValue.Kind() {
case reflect.Bool:
reflectValue.SetBool(Bool(literal))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
reflectValue.SetInt(Int64(literal))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
reflectValue.SetUint(Uint64(literal))
case reflect.Float32, reflect.Float64:
reflectValue.SetFloat(Float64(literal))
case reflect.String:
reflectValue.SetString(literal)
case reflect.Interface:
reflectValue.Set(reflect.ValueOf(literal))
}
}
return nil
}
func (d *decoder) decodeObject(reflectValue reflect.Value) error {
if d.data[d.pos] != '{' {
return fmt.Errorf("expected '{' at pos %d", d.pos)
}
d.pos++
d.skipWhitespace()
if d.pos < len(d.data) && d.data[d.pos] == '}' {
d.pos++
return nil
}
isMap := reflectValue.Kind() == reflect.Map
isStruct := reflectValue.Kind() == reflect.Struct
if isMap && reflectValue.IsNil() {
reflectValue.Set(reflect.MakeMap(reflectValue.Type()))
}
var fieldMap map[string]int
if isStruct {
fieldMap = getDecoderFieldMap(reflectValue.Type())
}
for {
d.skipWhitespace()
key, err := d.parseString()
if err != nil {
return err
}
d.skipWhitespace()
if d.pos >= len(d.data) || d.data[d.pos] != ':' {
return fmt.Errorf("expected ':' after key at pos %d", d.pos)
}
d.pos++
if isStruct {
// Frictionless 匹配
fieldIndex, ok := matchField(key, fieldMap)
if ok {
if err := d.decodeValue(reflectValue.Field(fieldIndex)); err != nil {
return err
}
} else {
if err := d.skipValue(); err != nil {
return err
}
}
} else if isMap {
keyType := reflectValue.Type().Key()
valueType := reflectValue.Type().Elem()
keyValue := reflect.New(keyType).Elem()
// Key 总是尝试转化
keyValue.Set(reflect.ValueOf(key).Convert(keyType))
valValue := reflect.New(valueType).Elem()
if err := d.decodeValue(valValue); err != nil {
return err
}
reflectValue.SetMapIndex(keyValue, valValue)
} else {
if err := d.skipValue(); err != nil {
return err
}
}
d.skipWhitespace()
if d.pos >= len(d.data) {
return errors.New("unexpected end of object")
}
if d.data[d.pos] == '}' {
d.pos++
break
}
if d.data[d.pos] != ',' {
return fmt.Errorf("expected ',' or '}' at pos %d", d.pos)
}
d.pos++
}
return nil
}
func (d *decoder) decodeArray(reflectValue reflect.Value) error {
if d.data[d.pos] != '[' {
return fmt.Errorf("expected '[' at pos %d", d.pos)
}
d.pos++
d.skipWhitespace()
if d.pos < len(d.data) && d.data[d.pos] == ']' {
d.pos++
return nil
}
isSlice := reflectValue.Kind() == reflect.Slice
for index := 0; ; index++ {
if isSlice {
if index >= reflectValue.Len() {
newCap := reflectValue.Cap() * 2
if newCap < 4 {
newCap = 4
}
newSlice := reflect.MakeSlice(reflectValue.Type(), index+1, newCap)
reflect.Copy(newSlice, reflectValue)
reflectValue.Set(newSlice)
} else {
reflectValue.SetLen(index + 1)
}
if err := d.decodeValue(reflectValue.Index(index)); err != nil {
return err
}
} else {
if err := d.skipValue(); err != nil {
return err
}
}
d.skipWhitespace()
if d.pos >= len(d.data) {
return errors.New("unexpected end of array")
}
if d.data[d.pos] == ']' {
d.pos++
break
}
if d.data[d.pos] != ',' {
return fmt.Errorf("expected ',' or ']' at pos %d", d.pos)
}
d.pos++
}
return nil
}
func (d *decoder) parseString() (string, error) {
d.skipWhitespace()
if d.pos >= len(d.data) || d.data[d.pos] != '"' {
return "", fmt.Errorf("expected '\"' at pos %d", d.pos)
}
d.pos++
start := d.pos
for d.pos < len(d.data) {
if d.data[d.pos] == '\\' {
d.pos += 2
continue
}
if d.data[d.pos] == '"' {
s := string(d.data[start:d.pos])
d.pos++
// 处理转义
if strings.Contains(s, "\\") {
s, _ = strconv.Unquote("\"" + s + "\"")
}
return s, nil
}
d.pos++
}
return "", errors.New("unterminated string")
}
func (d *decoder) parseLiteral() (string, error) {
start := d.pos
for d.pos < len(d.data) {
char := d.data[d.pos]
if char == ' ' || char == '\t' || char == '\r' || char == '\n' || char == ',' || char == '}' || char == ']' || char == ':' {
break
}
d.pos++
}
return string(d.data[start:d.pos]), nil
}
func (d *decoder) skipValue() error {
d.skipWhitespace()
if d.pos >= len(d.data) {
return nil
}
char := d.data[d.pos]
switch char {
case '{':
d.pos++
for d.pos < len(d.data) {
d.skipWhitespace()
if d.data[d.pos] == '}' {
d.pos++
return nil
}
if err := d.skipValue(); err != nil {
return err
}
d.skipWhitespace()
if d.data[d.pos] == ':' {
d.pos++
if err := d.skipValue(); err != nil {
return err
}
}
d.skipWhitespace()
if d.data[d.pos] == ',' {
d.pos++
}
}
case '[':
d.pos++
for d.pos < len(d.data) {
d.skipWhitespace()
if d.data[d.pos] == ']' {
d.pos++
return nil
}
if err := d.skipValue(); err != nil {
return err
}
d.skipWhitespace()
if d.data[d.pos] == ',' {
d.pos++
}
}
case '"':
_, err := d.parseString()
return err
default:
_, err := d.parseLiteral()
return err
}
return nil
}
// Frictionless Logic
func getDecoderFieldMap(reflectType reflect.Type) map[string]int {
if val, ok := structFieldMapCache.Load(reflectType); ok {
return val.(map[string]int)
}
m := make(map[string]int)
for index := 0; index < reflectType.NumField(); index++ {
field := reflectType.Field(index)
if !field.IsExported() {
continue
}
// 1. Tag
tag := field.Tag.Get("json")
if tag != "" && tag != "-" {
m[strings.Split(tag, ",")[0]] = index
}
// 2. 原名
m[field.Name] = index
// 3. 归一化名
m[normalizeKey(field.Name)] = index
}
structFieldMapCache.Store(reflectType, m)
return m
}
func matchField(key string, fieldMap map[string]int) (int, bool) {
// 1. 精确匹配
if index, ok := fieldMap[key]; ok {
return index, true
}
// 2. 归一化匹配 (忽略大小写、下划线等)
if index, ok := fieldMap[normalizeKey(key)]; ok {
return index, true
}
return 0, false
}
func normalizeKey(str string) string {
var builder strings.Builder
builder.Grow(len(str))
for _, char := range str {
if unicode.IsLetter(char) || unicode.IsDigit(char) {
builder.WriteRune(unicode.ToLower(char))
}
}
return builder.String()
}
func fastUnmarshalJSONBytes(data []byte, value any) error {
d := &decoder{data: data}
return d.decode(value)
}