cast/json_encoder.go

267 lines
6.9 KiB
Go

package cast
import (
"bytes"
"fmt"
"reflect"
"sort"
"strconv"
"strings"
"sync"
)
// bufferPool recycles the scratch buffers used by fastToJSONBytes so that
// repeated encodes can reuse a bytes.Buffer instead of allocating a new one.
var bufferPool = sync.Pool{
	New: func() any {
		return &bytes.Buffer{}
	},
}
// fastEncoder is a reflection-based JSON writer that appends its output to
// buffer. When desensitizeKeys is non-nil, a value whose dot-separated key
// path maps to true in it is written as the masked string "***".
type fastEncoder struct {
	// buffer receives the encoded JSON text.
	buffer *bytes.Buffer
	// desensitizeKeys holds the key paths to mask; nil disables masking.
	desensitizeKeys map[string]bool
}
// encode writes the JSON form of value to the encoder's buffer. A nil
// interface becomes null; anything else is handed to encodeValue with the
// root key path "".
func (encoder *fastEncoder) encode(value any) error {
	if value != nil {
		return encoder.encodeValue(reflect.ValueOf(value), "")
	}
	encoder.buffer.WriteString("null")
	return nil
}
// encodeValue writes the JSON form of reflectValue to the encoder's buffer.
//
// path is the dot-separated key path from the root value ("" at the root).
// If desensitizeKeys marks path, the value is written as the string "***"
// whatever its type.
//
// It returns an error for NaN and ±Inf floats, which have no JSON
// representation, and propagates errors from nested values.
func (encoder *fastEncoder) encodeValue(reflectValue reflect.Value, path string) error {
	reflectValue = RealValue(reflectValue)
	if !reflectValue.IsValid() {
		encoder.buffer.WriteString("null")
		return nil
	}
	// Mask before dispatching on the kind so an entire subtree can be hidden.
	if encoder.desensitizeKeys != nil && encoder.desensitizeKeys[path] {
		encoder.buffer.WriteString(`"***"`)
		return nil
	}
	switch reflectValue.Kind() {
	case reflect.Bool:
		encoder.buffer.WriteString(strconv.FormatBool(reflectValue.Bool()))
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		encoder.buffer.WriteString(strconv.FormatInt(reflectValue.Int(), 10))
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		encoder.buffer.WriteString(strconv.FormatUint(reflectValue.Uint(), 10))
	case reflect.Float32, reflect.Float64:
		// Format with the value's own precision: formatting a float32 with
		// 64 bits prints its widened binary value (0.1 -> 0.10000000149011612).
		bitSize := 64
		if reflectValue.Kind() == reflect.Float32 {
			bitSize = 32
		}
		formatted := strconv.FormatFloat(reflectValue.Float(), 'f', -1, bitSize)
		switch formatted {
		case "NaN", "+Inf", "-Inf":
			// Writing these verbatim would produce invalid JSON.
			return fmt.Errorf("cast: unsupported float value %s at path %q", formatted, path)
		}
		encoder.buffer.WriteString(formatted)
	case reflect.String:
		encoder.writeString(reflectValue.String())
	case reflect.Slice, reflect.Array:
		if reflectValue.Type().Elem().Kind() == reflect.Uint8 {
			// Byte sequences are emitted as a JSON string of their raw contents.
			if reflectValue.Kind() == reflect.Slice {
				encoder.writeString(string(reflectValue.Bytes()))
				return nil
			}
			// Value.Bytes panics on byte arrays that are not addressable
			// (e.g. a [16]byte field of a struct passed by value), so copy
			// the elements out one by one.
			raw := make([]byte, reflectValue.Len())
			for index := range raw {
				raw[index] = byte(reflectValue.Index(index).Uint())
			}
			encoder.writeString(string(raw))
			return nil
		}
		return encoder.encodeSlice(reflectValue, path)
	case reflect.Map:
		return encoder.encodeMap(reflectValue, path)
	case reflect.Struct:
		return encoder.encodeStruct(reflectValue, path)
	case reflect.Interface, reflect.Pointer:
		return encoder.encodeValue(reflectValue.Elem(), path)
	default:
		// Remaining kinds (uintptr, complex, chan, func, ...) are written as
		// a JSON string of their fmt.Sprint text.
		encoder.writeString(fmt.Sprint(reflectValue.Interface()))
	}
	return nil
}
// encodeSlice writes a slice or array as a JSON array. Each element is
// encoded with the same key path as the container itself. A nil slice is
// written as [] rather than null.
func (encoder *fastEncoder) encodeSlice(reflectValue reflect.Value, path string) error {
	// Check the kind first: Value.IsNil panics when called on an Array.
	if reflectValue.Kind() == reflect.Slice && reflectValue.IsNil() {
		encoder.buffer.WriteString("[]")
		return nil
	}
	encoder.buffer.WriteByte('[')
	for index, length := 0, reflectValue.Len(); index < length; index++ {
		if index > 0 {
			encoder.buffer.WriteByte(',')
		}
		if err := encoder.encodeValue(reflectValue.Index(index), path); err != nil {
			return err
		}
	}
	encoder.buffer.WriteByte(']')
	return nil
}
// encodeMap writes a map as a JSON object. Keys are converted with String and
// emitted in ascending order so the output is deterministic. Each value is
// encoded with the key path path+"."+key (just key at the root). A nil map is
// written as {}.
//
// A non-empty map with an empty-interface key type (map[any]any) whose keys
// are exactly 0..len-1, each stored as int, float64 or its decimal string, is
// written as a JSON array instead. Such maps are how some script runtimes
// (e.g. Goja) hand over arrays. An empty map[any]any stays {}.
func (encoder *fastEncoder) encodeMap(reflectValue reflect.Value, path string) error {
	if reflectValue.IsNil() {
		encoder.buffer.WriteString("{}")
		return nil
	}
	keyType := reflectValue.Type().Key()
	// Only an empty-interface key type accepts the int/float64/string probe
	// keys below; MapIndex panics for any other interface key type.
	if keyType.Kind() == reflect.Interface && keyType.NumMethod() == 0 {
		// lookupIndex returns the value stored under index in any of its
		// three accepted key forms, or an invalid Value if none is present.
		lookupIndex := func(index int) reflect.Value {
			if value := reflectValue.MapIndex(reflect.ValueOf(index)); value.IsValid() {
				return value
			}
			if value := reflectValue.MapIndex(reflect.ValueOf(float64(index))); value.IsValid() {
				return value
			}
			return reflectValue.MapIndex(reflect.ValueOf(strconv.Itoa(index)))
		}
		length := reflectValue.Len()
		isArray := length > 0
		for index := 0; isArray && index < length; index++ {
			isArray = lookupIndex(index).IsValid()
		}
		if isArray {
			encoder.buffer.WriteByte('[')
			for index := 0; index < length; index++ {
				if index > 0 {
					encoder.buffer.WriteByte(',')
				}
				if err := encoder.encodeValue(lookupIndex(index), path); err != nil {
					return err
				}
			}
			encoder.buffer.WriteByte(']')
			return nil
		}
	}

	type mapEntry struct {
		name  string
		value reflect.Value
	}
	entries := make([]mapEntry, 0, reflectValue.Len())
	iter := reflectValue.MapRange()
	for iter.Next() {
		entries = append(entries, mapEntry{name: String(iter.Key().Interface()), value: iter.Value()})
	}
	// Sort on the precomputed names so String runs once per key rather than
	// twice per comparison.
	sort.Slice(entries, func(i, j int) bool {
		return entries[i].name < entries[j].name
	})

	encoder.buffer.WriteByte('{')
	for index, entry := range entries {
		if index > 0 {
			encoder.buffer.WriteByte(',')
		}
		encoder.writeString(entry.name)
		encoder.buffer.WriteByte(':')
		childPath := entry.name
		if path != "" {
			childPath = path + "." + entry.name
		}
		if err := encoder.encodeValue(entry.value, childPath); err != nil {
			return err
		}
	}
	encoder.buffer.WriteByte('}')
	return nil
}
// encodeStruct writes the exported, non-embedded fields of a struct as a JSON
// object in declaration order. Each field value is encoded with the key path
// path+"."+key (just key at the root).
//
// Key names follow the encoding/json tag rules for naming and omission:
//   - `json:"-"` omits the field entirely;
//   - `json:"name,..."` uses name as the key;
//   - with no tag, or a tag whose name part is empty (e.g. ",omitempty"), the
//     Go field name is used. Its first letter is lowered when the name also
//     contains a lowercase ASCII letter (UserID -> userID, ID stays ID). This
//     mirrors FixUpperCase. The conversion is skipped when the raw tag
//     contains "keepKey".
//
// Tag options such as omitempty are not interpreted.
func (encoder *fastEncoder) encodeStruct(reflectValue reflect.Value, path string) error {
	encoder.buffer.WriteByte('{')
	reflectType := reflectValue.Type()
	first := true
	for index := 0; index < reflectType.NumField(); index++ {
		field := reflectType.Field(index)
		// NOTE(review): anonymous (embedded) fields are skipped, not
		// flattened as encoding/json would do. Their promoted fields do not
		// appear in the output.
		if !field.IsExported() || field.Anonymous {
			continue
		}
		tag := field.Tag.Get("json")
		if tag == "-" {
			continue
		}
		keyName, _, _ := strings.Cut(tag, ",")
		if keyName == "" {
			keyName = field.Name
			keepKey := strings.Contains(string(field.Tag), "keepKey")
			if !keepKey && keyName[0] >= 'A' && keyName[0] <= 'Z' &&
				strings.ContainsAny(keyName, "abcdefghijklmnopqrstuvwxyz") {
				keyName = strings.ToLower(keyName[:1]) + keyName[1:]
			}
		}
		// The separator is written only once the field is known to be emitted.
		if !first {
			encoder.buffer.WriteByte(',')
		}
		first = false
		encoder.writeString(keyName)
		encoder.buffer.WriteByte(':')
		childPath := keyName
		if path != "" {
			childPath = path + "." + keyName
		}
		if err := encoder.encodeValue(reflectValue.Field(index), childPath); err != nil {
			return err
		}
	}
	encoder.buffer.WriteByte('}')
	return nil
}
// writeString writes str to the buffer as a quoted JSON string.
//
// Quotes and backslashes are backslash-escaped, and \n, \r, \t use their short
// escapes. Every other control character below U+0020 is written as \u00XX,
// as JSON requires. Invalid UTF-8 bytes are replaced with U+FFFD, so the
// output is always valid UTF-8 even for arbitrary []byte input.
func (encoder *fastEncoder) writeString(str string) {
	const hexDigits = "0123456789abcdef"
	buffer := encoder.buffer
	buffer.WriteByte('"')
	// Ranging over a string decodes UTF-8 and yields U+FFFD for each invalid
	// byte; WriteRune then emits that replacement as valid UTF-8.
	for _, char := range str {
		switch char {
		case '"', '\\':
			buffer.WriteByte('\\')
			buffer.WriteByte(byte(char))
		case '\n':
			buffer.WriteString(`\n`)
		case '\r':
			buffer.WriteString(`\r`)
		case '\t':
			buffer.WriteString(`\t`)
		default:
			if char < 0x20 {
				buffer.WriteString(`\u00`)
				buffer.WriteByte(hexDigits[char>>4])
				buffer.WriteByte(hexDigits[char&0xF])
			} else {
				buffer.WriteRune(char)
			}
		}
	}
	buffer.WriteByte('"')
}
// fastToJSONBytes is the entry point of the fast encoder. It encodes value to
// JSON and returns a freshly allocated copy of the bytes.
//
// desensitizeKeys lists dot-separated key paths (e.g. "user.password") whose
// values are written as "***" instead of their real contents.
//
// Scratch buffers come from bufferPool. A buffer that grew beyond
// maxPooledBufferSize is dropped instead of returned, so one very large
// document does not leave a large allocation sitting in the pool for reuse.
func fastToJSONBytes(value any, desensitizeKeys ...string) ([]byte, error) {
	const maxPooledBufferSize = 64 << 10
	buffer := bufferPool.Get().(*bytes.Buffer)
	buffer.Reset()
	defer func() {
		if buffer.Cap() <= maxPooledBufferSize {
			bufferPool.Put(buffer)
		}
	}()
	encoder := &fastEncoder{buffer: buffer}
	if len(desensitizeKeys) > 0 {
		encoder.desensitizeKeys = make(map[string]bool, len(desensitizeKeys))
		for _, key := range desensitizeKeys {
			encoder.desensitizeKeys[key] = true
		}
	}
	if err := encoder.encode(value); err != nil {
		return nil, err
	}
	// Copy out: the buffer's storage is reused by later calls once pooled.
	result := make([]byte, buffer.Len())
	copy(result, buffer.Bytes())
	return result, nil
}