// cast/json_encoder.go
//
// 353 lines
// 8.8 KiB
// Go

package cast
import (
	"bytes"
	"fmt"
	"math"
	"reflect"
	"strconv"
	"strings"
	"sync"
	"time"
)
// Package-level state shared by all encoder invocations.
var (
	// bufferPool supplies reusable *bytes.Buffer scratch buffers; a new empty
	// buffer is created whenever the pool has none to hand out.
	bufferPool = sync.Pool{
		New: func() any { return &bytes.Buffer{} },
	}

	// encoderStructCache memoizes struct metadata, keyed by reflect.Type and
	// holding *encoderStructDescriptor values.
	encoderStructCache sync.Map

	// timeType is the reflect.Type of time.Time, used to recognize time values.
	timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
)
// encoderFieldDescriptor holds the precomputed encoding metadata for one
// exported struct field, as built by getEncoderStructDescriptor.
type encoderFieldDescriptor struct {
// index is the field's position in its struct, used with reflect.Value.Field.
index int
// name is the key written to the JSON output for this field.
name string
// isAnonymous reports whether the field is embedded; embedded structs are
// flattened into the enclosing object by encodeStructFields.
isAnonymous bool
// isTime reports whether the field type is time.Time or *time.Time.
isTime bool
// timeFormat is the time layout used when isTime is set; it defaults to
// "2006-01-02 15:04:05.000" and can be overridden by a "format=" tag part.
timeFormat string
// keepKey is set when the raw struct tag contains the text "keepKey"; it
// stops an untagged field's name from being passed through GetLowerName.
keepKey bool
}
// encoderStructDescriptor is the cached encoding plan for one struct type.
// Instances are stored in encoderStructCache keyed by reflect.Type.
type encoderStructDescriptor struct {
// fields lists the exported fields of the struct in declaration order.
fields []encoderFieldDescriptor
}
// fastEncoder writes the JSON form of Go values into a byte buffer using
// reflection, optionally masking selected values.
type fastEncoder struct {
// buffer receives the encoded output.
buffer *bytes.Buffer
// desensitizeKeys is the set of dot-separated key paths whose values are
// replaced by "***"; a nil map disables masking.
desensitizeKeys map[string]bool
}
// encode writes the JSON form of value into the encoder's buffer.
//
// An untyped nil is written as null; every other value is handed to
// encodeValue with an empty key path.
func (encoder *fastEncoder) encode(value any) error {
	if value != nil {
		return encoder.encodeValue(reflect.ValueOf(value), "")
	}
	encoder.buffer.WriteString("null")
	return nil
}
// encodeValue writes the JSON representation of reflectValue to the buffer.
//
// path is the dot-separated key path of the value inside the document; it is
// only used to look the value up in desensitizeKeys. The returned error is
// whatever a nested encodeSlice/encodeMap/encodeStruct call reports.
//
// Encoding rules:
//   - invalid values (after RealValue) are written as null;
//   - values whose path is registered in desensitizeKeys become "***";
//   - time.Time uses the layout "2006-01-02 15:04:05.000";
//   - floats are formatted at their own precision, and NaN/±Inf, which JSON
//     cannot represent, are written as null;
//   - byte slices and byte arrays are written as a JSON string of their bytes;
//   - any other kind falls back to fmt.Sprint, written as a string.
func (encoder *fastEncoder) encodeValue(reflectValue reflect.Value, path string) error {
	// RealValue is defined elsewhere in the package; presumably it unwraps
	// pointers and interfaces — TODO confirm.
	reflectValue = RealValue(reflectValue)
	if !reflectValue.IsValid() {
		encoder.buffer.WriteString("null")
		return nil
	}
	// Mask the value when its path was registered for desensitization.
	if encoder.desensitizeKeys != nil && encoder.desensitizeKeys[path] {
		encoder.buffer.WriteString(`"***"`)
		return nil
	}
	// time.Time is handled before the kind switch so it is not treated as a
	// plain struct.
	if reflectValue.Type() == timeType {
		encoder.writeTime(reflectValue.Interface().(time.Time), "2006-01-02 15:04:05.000")
		return nil
	}
	switch reflectValue.Kind() {
	case reflect.Bool:
		if reflectValue.Bool() {
			encoder.buffer.WriteString("true")
		} else {
			encoder.buffer.WriteString("false")
		}
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		encoder.buffer.WriteString(strconv.FormatInt(reflectValue.Int(), 10))
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		encoder.buffer.WriteString(strconv.FormatUint(reflectValue.Uint(), 10))
	case reflect.Float32, reflect.Float64:
		number := reflectValue.Float()
		// NaN and infinities have no JSON literal; emit null instead of an
		// invalid token.
		if math.IsNaN(number) || math.IsInf(number, 0) {
			encoder.buffer.WriteString("null")
			return nil
		}
		// Format with the value's own bit size so a float32 such as 0.1 is
		// not printed with float64 rounding noise.
		encoder.buffer.WriteString(strconv.FormatFloat(number, 'f', -1, reflectValue.Type().Bits()))
	case reflect.String:
		encoder.writeString(reflectValue.String())
	case reflect.Slice, reflect.Array:
		if reflectValue.Type().Elem().Kind() == reflect.Uint8 {
			var raw []byte
			if reflectValue.Kind() == reflect.Slice {
				raw = reflectValue.Bytes()
			} else {
				// reflect.Value.Bytes panics on unaddressable arrays, so
				// copy the elements one by one.
				raw = make([]byte, reflectValue.Len())
				for index := range raw {
					raw[index] = byte(reflectValue.Index(index).Uint())
				}
			}
			encoder.writeString(string(raw))
			return nil
		}
		return encoder.encodeSlice(reflectValue, path)
	case reflect.Map:
		return encoder.encodeMap(reflectValue, path)
	case reflect.Struct:
		return encoder.encodeStruct(reflectValue, path)
	case reflect.Interface, reflect.Pointer:
		return encoder.encodeValue(reflectValue.Elem(), path)
	default:
		encoder.writeString(fmt.Sprint(reflectValue.Interface()))
	}
	return nil
}
// writeTime appends t, rendered with the given time layout, as a
// double-quoted JSON string. The formatted text is written verbatim,
// without JSON escaping.
func (encoder *fastEncoder) writeTime(t time.Time, format string) {
	encoder.buffer.WriteString(`"` + t.Format(format) + `"`)
}
// encodeSlice writes a slice or array as a JSON array.
//
// A nil slice is written as []. Every element is encoded with the same path
// as the container, so a desensitized key masks each element individually.
// The first element error aborts encoding and is returned.
func (encoder *fastEncoder) encodeSlice(reflectValue reflect.Value, path string) error {
	// The kind must be tested first: reflect.Value.IsNil panics on arrays.
	if reflectValue.Kind() == reflect.Slice && reflectValue.IsNil() {
		encoder.buffer.WriteString("[]")
		return nil
	}
	encoder.buffer.WriteByte('[')
	for index, length := 0, reflectValue.Len(); index < length; index++ {
		if index > 0 {
			encoder.buffer.WriteByte(',')
		}
		if err := encoder.encodeValue(reflectValue.Index(index), path); err != nil {
			return err
		}
	}
	encoder.buffer.WriteByte(']')
	return nil
}
// encodeMap writes a map as a JSON object, or as a JSON array when an
// interface-keyed map is really an array in disguise (as produced by Goja).
//
// A nil map is written as {}. Keys are converted with String and entries are
// emitted in map iteration order, which is not deterministic. For object
// entries the desensitization path is extended with the key name; array-like
// elements keep the map's own path.
func (encoder *fastEncoder) encodeMap(reflectValue reflect.Value, path string) error {
	if reflectValue.IsNil() {
		encoder.buffer.WriteString("{}")
		return nil
	}
	// Interface-keyed maps whose keys are exactly 0..len-1 are arrays in
	// disguise (Goja).
	if reflectValue.Type().Key().Kind() == reflect.Interface {
		if elements, isArray := arrayLikeMapElements(reflectValue); isArray {
			encoder.buffer.WriteByte('[')
			for index, element := range elements {
				if index > 0 {
					encoder.buffer.WriteByte(',')
				}
				if err := encoder.encodeValue(element, path); err != nil {
					return err
				}
			}
			encoder.buffer.WriteByte(']')
			return nil
		}
	}
	encoder.buffer.WriteByte('{')
	iter := reflectValue.MapRange()
	isFirst := true
	for iter.Next() {
		if !isFirst {
			encoder.buffer.WriteByte(',')
		}
		isFirst = false
		// String is a conversion helper defined elsewhere in the package.
		keyName := String(iter.Key().Interface())
		encoder.writeString(keyName)
		encoder.buffer.WriteByte(':')
		// The path is only tracked when desensitization is active.
		newPath := ""
		if encoder.desensitizeKeys != nil {
			newPath = keyName
			if path != "" {
				newPath = path + "." + keyName
			}
		}
		if err := encoder.encodeValue(iter.Value(), newPath); err != nil {
			return err
		}
	}
	encoder.buffer.WriteByte('}')
	return nil
}

// arrayLikeMapElements reports whether the interface-keyed map mapValue has
// an entry for every index in [0, len), where index i may be stored under the
// int, float64 or decimal-string form of i (tried in that order). On success
// it returns the elements in index order.
func arrayLikeMapElements(mapValue reflect.Value) ([]reflect.Value, bool) {
	length := mapValue.Len()
	if length == 0 {
		// NOTE(review): an empty interface-keyed map is treated as an empty
		// array, so it encodes as [] rather than {} — confirm this is intended.
		return nil, true
	}
	// The probe keys are only assignable to the empty interface; MapIndex
	// panics when handed a key that is not assignable to the map's key type.
	if mapValue.Type().Key().NumMethod() != 0 {
		return nil, false
	}
	elements := make([]reflect.Value, 0, length)
	for index := 0; index < length; index++ {
		element := mapValue.MapIndex(reflect.ValueOf(index))
		if !element.IsValid() {
			element = mapValue.MapIndex(reflect.ValueOf(float64(index)))
		}
		if !element.IsValid() {
			element = mapValue.MapIndex(reflect.ValueOf(strconv.Itoa(index)))
		}
		if !element.IsValid() {
			return nil, false
		}
		elements = append(elements, element)
	}
	return elements, true
}
// getEncoderStructDescriptor returns the encoding plan for the struct type
// reflectType, building and caching it in encoderStructCache on first use.
//
// Only exported fields are described. Field naming:
//   - a non-empty `json` tag other than "-" supplies the name; "format=<layout>"
//     parts set the time layout, and the encoding/json options omitempty,
//     omitzero and string after the first part are ignored rather than being
//     mistaken for a name;
//   - a field without a `json` tag is renamed with GetLowerName unless it is
//     embedded or its struct tag contains "keepKey".
//
// NOTE(review): fields tagged `json:"-"` are not skipped; they are emitted
// under their Go field name — confirm this is intended.
func getEncoderStructDescriptor(reflectType reflect.Type) *encoderStructDescriptor {
	if cached, ok := encoderStructCache.Load(reflectType); ok {
		return cached.(*encoderStructDescriptor)
	}
	fieldCount := reflectType.NumField()
	descriptor := &encoderStructDescriptor{
		fields: make([]encoderFieldDescriptor, 0, fieldCount),
	}
	for index := 0; index < fieldCount; index++ {
		field := reflectType.Field(index)
		if !field.IsExported() {
			continue
		}
		fieldDesc := encoderFieldDescriptor{
			index:       index,
			name:        field.Name,
			isAnonymous: field.Anonymous,
		}
		if field.Type == timeType || (field.Type.Kind() == reflect.Pointer && field.Type.Elem() == timeType) {
			fieldDesc.isTime = true
			fieldDesc.timeFormat = "2006-01-02 15:04:05.000"
		}
		tag := field.Tag.Get("json")
		fieldDesc.keepKey = strings.Contains(string(field.Tag), "keepKey")
		if tag != "" && tag != "-" {
			named := false
			for position, part := range strings.Split(tag, ",") {
				switch {
				case strings.HasPrefix(part, "format="):
					fieldDesc.timeFormat = strings.TrimPrefix(part, "format=")
				case part == "":
					// An empty part must not erase the field name.
				case position > 0 && (part == "omitempty" || part == "omitzero" || part == "string"):
					// encoding/json options are not names.
				case !named:
					// The first remaining part is the name; later ones are ignored.
					fieldDesc.name = part
					named = true
				}
			}
		}
		if tag == "" && !fieldDesc.keepKey && !field.Anonymous {
			fieldDesc.name = GetLowerName(field.Name)
		}
		descriptor.fields = append(descriptor.fields, fieldDesc)
	}
	// LoadOrStore keeps a single canonical descriptor when several goroutines
	// build the same type concurrently.
	actual, _ := encoderStructCache.LoadOrStore(reflectType, descriptor)
	return actual.(*encoderStructDescriptor)
}
// encodeStruct writes a struct as a JSON object. The closing brace is written
// even when a field fails to encode; the field error is then returned.
func (encoder *fastEncoder) encodeStruct(reflectValue reflect.Value, path string) error {
	out := encoder.buffer
	out.WriteByte('{')
	// Shared with encodeStructFields so flattened embedded structs place
	// their commas correctly.
	isFirstField := true
	fieldsErr := encoder.encodeStructFields(reflectValue, path, &isFirstField)
	out.WriteByte('}')
	return fieldsErr
}
// encodeStructFields writes the "key":value pairs of a struct, without the
// surrounding braces.
//
// first points to a flag that is true until the first pair has been written;
// it is shared across recursive calls so that embedded structs, whose fields
// are flattened into the enclosing object, get correct comma placement.
// path is the desensitization path of the enclosing object. The first
// encoding error is returned.
func (encoder *fastEncoder) encodeStructFields(reflectValue reflect.Value, path string, first *bool) error {
	descriptor := getEncoderStructDescriptor(reflectValue.Type())
	for position := range descriptor.fields {
		fieldDesc := &descriptor.fields[position]
		fieldValue := reflectValue.Field(fieldDesc.index)
		if fieldDesc.isAnonymous {
			// Embedded structs are flattened; anything else (for example a
			// value RealValue cannot resolve to a struct) is written as a
			// regular named field.
			fieldValue = RealValue(fieldValue)
			if fieldValue.Kind() == reflect.Struct {
				if err := encoder.encodeStructFields(fieldValue, path, first); err != nil {
					return err
				}
				continue
			}
		}
		if !*first {
			encoder.buffer.WriteByte(',')
		}
		*first = false
		encoder.writeString(fieldDesc.name)
		encoder.buffer.WriteByte(':')
		// The path is only tracked when desensitization is active.
		newPath := ""
		if encoder.desensitizeKeys != nil {
			newPath = fieldDesc.name
			if path != "" {
				newPath = path + "." + fieldDesc.name
			}
		}
		// Time fields use their own layout, unless the field is registered
		// for desensitization: then it must go through encodeValue so that
		// it is masked like every other value.
		if fieldDesc.isTime && !encoder.desensitizeKeys[newPath] {
			timeValue := RealValue(fieldValue)
			if timeValue.IsValid() && timeValue.Type() == timeType {
				encoder.writeTime(timeValue.Interface().(time.Time), fieldDesc.timeFormat)
				continue
			}
		}
		if err := encoder.encodeValue(fieldValue, newPath); err != nil {
			return err
		}
	}
	return nil
}
// writeString appends str to the buffer as a double-quoted JSON string.
//
// The double quote, the backslash and all bytes below 0x20 are escaped:
// \n, \r and \t use their short forms, other control bytes use \u00XX.
// All remaining bytes are copied through unchanged.
func (encoder *fastEncoder) writeString(str string) {
	const hexDigits = "0123456789abcdef"
	out := encoder.buffer
	out.WriteByte('"')
	pending := 0 // start of the run of plain bytes not yet flushed
	for pos := 0; pos < len(str); pos++ {
		current := str[pos]
		if current >= 0x20 && current != '"' && current != '\\' {
			continue
		}
		// Flush the plain run that precedes the byte needing an escape.
		if pending < pos {
			out.WriteString(str[pending:pos])
		}
		pending = pos + 1
		switch current {
		case '"':
			out.WriteString(`\"`)
		case '\\':
			out.WriteString(`\\`)
		case '\n':
			out.WriteString(`\n`)
		case '\r':
			out.WriteString(`\r`)
		case '\t':
			out.WriteString(`\t`)
		default:
			// Any other control byte.
			out.WriteString(`\u00`)
			out.WriteByte(hexDigits[current>>4])
			out.WriteByte(hexDigits[current&0xf])
		}
	}
	if pending < len(str) {
		out.WriteString(str[pending:])
	}
	out.WriteByte('"')
}
// fastToJSONBytes is the entry point of the encoder: it encodes value as JSON
// and returns the result in a freshly allocated slice that the caller owns.
//
// desensitizeKeys lists dot-separated key paths whose values are replaced by
// "***" in the output. On failure the encoder's error is returned together
// with a nil slice.
func fastToJSONBytes(value any, desensitizeKeys ...string) ([]byte, error) {
	// Buffers that grew beyond this capacity are dropped instead of pooled,
	// so one huge payload does not keep its memory pinned in the pool.
	const maxPooledBufferCap = 64 << 10

	buffer := bufferPool.Get().(*bytes.Buffer)
	buffer.Reset()
	defer func() {
		if buffer.Cap() <= maxPooledBufferCap {
			bufferPool.Put(buffer)
		}
	}()
	encoder := &fastEncoder{buffer: buffer}
	if len(desensitizeKeys) > 0 {
		encoder.desensitizeKeys = make(map[string]bool, len(desensitizeKeys))
		for _, key := range desensitizeKeys {
			encoder.desensitizeKeys[key] = true
		}
	}
	if err := encoder.encode(value); err != nil {
		return nil, err
	}
	// Copy out of the buffer: its storage goes back to the pool.
	result := make([]byte, buffer.Len())
	copy(result, buffer.Bytes())
	return result, nil
}