353 lines
8.8 KiB
Go
353 lines
8.8 KiB
Go
package cast
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// bufferPool recycles the scratch buffers used by fastToJSONBytes.
var bufferPool = sync.Pool{
	New: func() any { return &bytes.Buffer{} },
}

// encoderStructCache memoises struct layouts, mapping a reflect.Type to its
// *encoderStructDescriptor.
var encoderStructCache sync.Map

// timeType identifies time.Time values, which are encoded as formatted strings.
var timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
|
|
|
|
type (
	// encoderFieldDescriptor is the precomputed encoding plan for one struct field.
	encoderFieldDescriptor struct {
		index       int    // position of the field, as used with reflect.Value.Field
		name        string // JSON key written for the field
		isAnonymous bool   // embedded field; struct values are flattened into the parent object
		isTime      bool   // field type is time.Time or *time.Time
		timeFormat  string // layout passed to time.Time.Format for time fields
		keepKey     bool   // raw tag mentions "keepKey": an untagged field keeps its Go name
	}

	// encoderStructDescriptor lists the encodable fields of one struct type,
	// in declaration order.
	encoderStructDescriptor struct {
		fields []encoderFieldDescriptor
	}

	// fastEncoder accumulates JSON output for a single encoding call.
	fastEncoder struct {
		buffer          *bytes.Buffer   // destination for the encoded bytes
		desensitizeKeys map[string]bool // dotted key paths written as "***"; nil disables masking
	}
)
|
|
|
|
func (encoder *fastEncoder) encode(value any) error {
|
|
if value == nil {
|
|
encoder.buffer.WriteString("null")
|
|
return nil
|
|
}
|
|
return encoder.encodeValue(reflect.ValueOf(value), "")
|
|
}
|
|
|
|
func (encoder *fastEncoder) encodeValue(reflectValue reflect.Value, path string) error {
|
|
reflectValue = RealValue(reflectValue)
|
|
if !reflectValue.IsValid() {
|
|
encoder.buffer.WriteString("null")
|
|
return nil
|
|
}
|
|
|
|
// 检查是否需要脱敏
|
|
if encoder.desensitizeKeys != nil && encoder.desensitizeKeys[path] {
|
|
encoder.buffer.WriteString(`"***"`)
|
|
return nil
|
|
}
|
|
|
|
// 处理 time.Time
|
|
if reflectValue.Type() == timeType {
|
|
encoder.writeTime(reflectValue.Interface().(time.Time), "2006-01-02 15:04:05.000")
|
|
return nil
|
|
}
|
|
|
|
switch reflectValue.Kind() {
|
|
case reflect.Bool:
|
|
if reflectValue.Bool() {
|
|
encoder.buffer.WriteString("true")
|
|
} else {
|
|
encoder.buffer.WriteString("false")
|
|
}
|
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
|
encoder.buffer.WriteString(strconv.FormatInt(reflectValue.Int(), 10))
|
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
|
encoder.buffer.WriteString(strconv.FormatUint(reflectValue.Uint(), 10))
|
|
case reflect.Float32, reflect.Float64:
|
|
encoder.buffer.WriteString(strconv.FormatFloat(reflectValue.Float(), 'f', -1, 64))
|
|
case reflect.String:
|
|
encoder.writeString(reflectValue.String())
|
|
case reflect.Slice, reflect.Array:
|
|
if reflectValue.Type().Elem().Kind() == reflect.Uint8 {
|
|
encoder.writeString(string(reflectValue.Bytes()))
|
|
return nil
|
|
}
|
|
return encoder.encodeSlice(reflectValue, path)
|
|
case reflect.Map:
|
|
return encoder.encodeMap(reflectValue, path)
|
|
case reflect.Struct:
|
|
return encoder.encodeStruct(reflectValue, path)
|
|
case reflect.Interface, reflect.Pointer:
|
|
return encoder.encodeValue(reflectValue.Elem(), path)
|
|
default:
|
|
encoder.writeString(fmt.Sprint(reflectValue.Interface()))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (encoder *fastEncoder) writeTime(t time.Time, format string) {
|
|
encoder.buffer.WriteByte('"')
|
|
encoder.buffer.WriteString(t.Format(format))
|
|
encoder.buffer.WriteByte('"')
|
|
}
|
|
|
|
func (encoder *fastEncoder) encodeSlice(reflectValue reflect.Value, path string) error {
|
|
if reflectValue.IsNil() && reflectValue.Kind() == reflect.Slice {
|
|
encoder.buffer.WriteString("[]")
|
|
return nil
|
|
}
|
|
encoder.buffer.WriteByte('[')
|
|
for index := 0; index < reflectValue.Len(); index++ {
|
|
if index > 0 {
|
|
encoder.buffer.WriteByte(',')
|
|
}
|
|
if err := encoder.encodeValue(reflectValue.Index(index), path); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
encoder.buffer.WriteByte(']')
|
|
return nil
|
|
}
|
|
|
|
func (encoder *fastEncoder) encodeMap(reflectValue reflect.Value, path string) error {
|
|
if reflectValue.IsNil() {
|
|
encoder.buffer.WriteString("{}")
|
|
return nil
|
|
}
|
|
|
|
// 处理 map[any]any 伪装的数组 (Goja)
|
|
if reflectValue.Type().Key().Kind() == reflect.Interface {
|
|
isArr := true
|
|
length := reflectValue.Len()
|
|
for index := 0; index < length; index++ {
|
|
if !reflectValue.MapIndex(reflect.ValueOf(index)).IsValid() &&
|
|
!reflectValue.MapIndex(reflect.ValueOf(float64(index))).IsValid() &&
|
|
!reflectValue.MapIndex(reflect.ValueOf(strconv.Itoa(index))).IsValid() {
|
|
isArr = false
|
|
break
|
|
}
|
|
}
|
|
if isArr {
|
|
encoder.buffer.WriteByte('[')
|
|
for index := 0; index < length; index++ {
|
|
if index > 0 {
|
|
encoder.buffer.WriteByte(',')
|
|
}
|
|
mapValue := reflectValue.MapIndex(reflect.ValueOf(index))
|
|
if !mapValue.IsValid() {
|
|
mapValue = reflectValue.MapIndex(reflect.ValueOf(float64(index)))
|
|
}
|
|
if !mapValue.IsValid() {
|
|
mapValue = reflectValue.MapIndex(reflect.ValueOf(strconv.Itoa(index)))
|
|
}
|
|
if err := encoder.encodeValue(mapValue, path); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
encoder.buffer.WriteByte(']')
|
|
return nil
|
|
}
|
|
}
|
|
|
|
encoder.buffer.WriteByte('{')
|
|
iter := reflectValue.MapRange()
|
|
isFirst := true
|
|
for iter.Next() {
|
|
if !isFirst {
|
|
encoder.buffer.WriteByte(',')
|
|
}
|
|
isFirst = false
|
|
|
|
key := iter.Key()
|
|
keyName := String(key.Interface())
|
|
encoder.writeString(keyName)
|
|
encoder.buffer.WriteByte(':')
|
|
|
|
newPath := ""
|
|
if encoder.desensitizeKeys != nil {
|
|
newPath = keyName
|
|
if path != "" {
|
|
newPath = path + "." + keyName
|
|
}
|
|
}
|
|
if err := encoder.encodeValue(iter.Value(), newPath); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
encoder.buffer.WriteByte('}')
|
|
return nil
|
|
}
|
|
|
|
func getEncoderStructDescriptor(reflectType reflect.Type) *encoderStructDescriptor {
|
|
if val, ok := encoderStructCache.Load(reflectType); ok {
|
|
return val.(*encoderStructDescriptor)
|
|
}
|
|
|
|
descriptor := &encoderStructDescriptor{}
|
|
for index := 0; index < reflectType.NumField(); index++ {
|
|
field := reflectType.Field(index)
|
|
if !field.IsExported() {
|
|
continue
|
|
}
|
|
|
|
fieldDesc := encoderFieldDescriptor{
|
|
index: index,
|
|
name: field.Name,
|
|
isAnonymous: field.Anonymous,
|
|
}
|
|
|
|
if field.Type == timeType || (field.Type.Kind() == reflect.Pointer && field.Type.Elem() == timeType) {
|
|
fieldDesc.isTime = true
|
|
fieldDesc.timeFormat = "2006-01-02 15:04:05.000"
|
|
}
|
|
|
|
tag := field.Tag.Get("json")
|
|
fieldDesc.keepKey = strings.Contains(string(field.Tag), "keepKey")
|
|
|
|
if tag != "" && tag != "-" {
|
|
parts := strings.Split(tag, ",")
|
|
for _, part := range parts {
|
|
if strings.HasPrefix(part, "format=") {
|
|
fieldDesc.timeFormat = strings.TrimPrefix(part, "format=")
|
|
} else if fieldDesc.name == field.Name && part != "" { // 防止空 tag 抹掉字段名
|
|
fieldDesc.name = part
|
|
}
|
|
}
|
|
}
|
|
|
|
if tag == "" && !fieldDesc.keepKey && !field.Anonymous {
|
|
fieldDesc.name = GetLowerName(field.Name)
|
|
}
|
|
|
|
descriptor.fields = append(descriptor.fields, fieldDesc)
|
|
}
|
|
|
|
encoderStructCache.Store(reflectType, descriptor)
|
|
return descriptor
|
|
}
|
|
|
|
func (encoder *fastEncoder) encodeStruct(reflectValue reflect.Value, path string) error {
|
|
encoder.buffer.WriteByte('{')
|
|
first := true
|
|
err := encoder.encodeStructFields(reflectValue, path, &first)
|
|
encoder.buffer.WriteByte('}')
|
|
return err
|
|
}
|
|
|
|
func (encoder *fastEncoder) encodeStructFields(reflectValue reflect.Value, path string, first *bool) error {
|
|
descriptor := getEncoderStructDescriptor(reflectValue.Type())
|
|
for _, fieldDesc := range descriptor.fields {
|
|
fieldValue := reflectValue.Field(fieldDesc.index)
|
|
|
|
if fieldDesc.isAnonymous {
|
|
fieldValue = RealValue(fieldValue)
|
|
if fieldValue.Kind() == reflect.Struct {
|
|
if err := encoder.encodeStructFields(fieldValue, path, first); err != nil {
|
|
return err
|
|
}
|
|
continue
|
|
}
|
|
}
|
|
|
|
if !*first {
|
|
encoder.buffer.WriteByte(',')
|
|
}
|
|
*first = false
|
|
|
|
encoder.writeString(fieldDesc.name)
|
|
encoder.buffer.WriteByte(':')
|
|
|
|
newPath := ""
|
|
if encoder.desensitizeKeys != nil {
|
|
newPath = fieldDesc.name
|
|
if path != "" {
|
|
newPath = path + "." + fieldDesc.name
|
|
}
|
|
}
|
|
|
|
if fieldDesc.isTime {
|
|
v := RealValue(fieldValue)
|
|
if v.IsValid() && v.Type() == timeType {
|
|
encoder.writeTime(v.Interface().(time.Time), fieldDesc.timeFormat)
|
|
continue
|
|
}
|
|
}
|
|
|
|
if err := encoder.encodeValue(fieldValue, newPath); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (encoder *fastEncoder) writeString(str string) {
|
|
encoder.buffer.WriteByte('"')
|
|
start := 0
|
|
for i := 0; i < len(str); i++ {
|
|
char := str[i]
|
|
if char < 0x20 || char == '"' || char == '\\' {
|
|
if start < i {
|
|
encoder.buffer.WriteString(str[start:i])
|
|
}
|
|
switch char {
|
|
case '"':
|
|
encoder.buffer.WriteString(`\"`)
|
|
case '\\':
|
|
encoder.buffer.WriteString(`\\`)
|
|
case '\n':
|
|
encoder.buffer.WriteString(`\n`)
|
|
case '\r':
|
|
encoder.buffer.WriteString(`\r`)
|
|
case '\t':
|
|
encoder.buffer.WriteString(`\t`)
|
|
default:
|
|
// 其他不可见字符
|
|
encoder.buffer.WriteString(`\u00`)
|
|
encoder.buffer.WriteByte("0123456789abcdef"[char>>4])
|
|
encoder.buffer.WriteByte("0123456789abcdef"[char&0xf])
|
|
}
|
|
start = i + 1
|
|
}
|
|
}
|
|
if start < len(str) {
|
|
encoder.buffer.WriteString(str[start:])
|
|
}
|
|
encoder.buffer.WriteByte('"')
|
|
}
|
|
|
|
// 导出函数
|
|
func fastToJSONBytes(value any, desensitizeKeys ...string) ([]byte, error) {
|
|
buffer := bufferPool.Get().(*bytes.Buffer)
|
|
buffer.Reset()
|
|
defer bufferPool.Put(buffer)
|
|
|
|
encoder := &fastEncoder{buffer: buffer}
|
|
if len(desensitizeKeys) > 0 {
|
|
encoder.desensitizeKeys = make(map[string]bool)
|
|
for _, key := range desensitizeKeys {
|
|
encoder.desensitizeKeys[key] = true
|
|
}
|
|
}
|
|
|
|
if err := encoder.encode(value); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result := make([]byte, buffer.Len())
|
|
copy(result, buffer.Bytes())
|
|
return result, nil
|
|
}
|