diff --git a/CHANGELOG.md b/CHANGELOG.md index 68a62aa..18eee13 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # CHANGELOG +## [v1.2.6] - 2026-05-04 +### Fixed +- **Map 深度合并修复**: 修复了在 `Convert` 或 `ToMap` 过程中,如果目标 Map 已存在该 Key,其原有结构体/Map 值会被直接覆盖而非深度合并的问题。通过引入 `dst.MapIndex` 预读取与临时寻址变量,现已完美支持 Map 下非指针结构体的局部字段覆盖。 + +### Added +- **深度合并测试集**: 新增 `merge_test.go` 与 `complex_conversion_test.go`,覆盖了 Struct-to-Map, Map-to-Map, Slice-to-Map 等多种深度嵌套合并场景,确保配置覆盖逻辑的健壮性。 + ## [v1.2.3] - 2026-05-04 ### Added - **强大的时间解析引擎**: 移植 `time` 模块的核心算法至 `cast.ParseTime`,支持时间戳(秒至纳秒)、RFC3339、JS 格式、紧凑格式(20060102150405)及中文日期解析。 diff --git a/TEST.md b/TEST.md index 53bab10..444e0c9 100644 --- a/TEST.md +++ b/TEST.md @@ -14,6 +14,7 @@ - **零分配摩擦匹配**: `normalizeEqual` 算法实现 0 内存分配的归一化 Key 匹配,支持 UTF-8。 - **智能 Slice 扩容**: 尊重预设 Capacity,减少反序列化时的内存重分配。 - **FastDecoder**: 实现单路径流式解析,支持嵌套指针、Slice 和 Map 的智能初始化。 +- **深度合并支持**: 支持在 `Convert` 过程中对 Map 和 Struct 进行深度递归合并。即使目标 Map 存储的是非指针结构体,也能通过反射寻址实现局部字段覆盖,而不丢失未提及的默认值。 - **指针与接口**: `RealValue` 处理多级指针与接口解包。 - **实用工具**: `UniqueAppend` ($O(n)$ 去重),`If` (泛型三元),`SplitArgs` (支持引用格式)。 diff --git a/cast.go b/cast.go index c0e88ce..cf031a1 100644 --- a/cast.go +++ b/cast.go @@ -342,13 +342,30 @@ func recursiveMapToMap(src, dst reflect.Value) { kt := dst.Type().Key() vt := dst.Type().Elem() newKey := reflect.New(kt).Elem() - newVal := reflect.New(vt).Elem() + iter := src.MapRange() for iter.Next() { newKey.Set(reflect.Zero(kt)) performRecursiveTo(iter.Key(), newKey) - newVal.Set(reflect.Zero(vt)) + + // 1. 创建一个【可寻址】的临时变量 + newVal := reflect.New(vt).Elem() + + // 2. 尝试从目标 Map 中获取已经存在的老值 + existingVal := dst.MapIndex(newKey) + if existingVal.IsValid() { + // 如果老值存在,把它拷贝到我们刚才创建的临时变量里作“底本” + // 这样就能保留原本没被覆盖的字段 + newVal.Set(existingVal) + } else { + // 如果老值不存在,就用零值兜底 + newVal.Set(reflect.Zero(vt)) + } + + // 3. 对这个可寻址的临时变量进行深度合并 performRecursiveTo(iter.Value(), newVal) + + // 4. 将合并后的全新值,重新塞回目标 Map dst.SetMapIndex(newKey, newVal) } } @@ -357,7 +374,7 @@ func recursiveStructToMap(src, dst reflect.Value) { kt := dst.Type().Key() vt := dst.Type().Elem() newKey := reflect.New(kt).Elem() - newVal := reflect.New(vt).Elem() + srcType := src.Type() for i := 0; i < src.NumField(); i++ { field := srcType.Field(i) @@ -370,8 +387,22 @@ func recursiveStructToMap(src, dst reflect.Value) { } newKey.Set(reflect.Zero(kt)) performRecursiveTo(reflect.ValueOf(GetLowerName(field.Name)), newKey) - newVal.Set(reflect.Zero(vt)) + + // 1. 创建一个【可寻址】的临时变量 + newVal := reflect.New(vt).Elem() + + // 2. 尝试获取老值 + existingVal := dst.MapIndex(newKey) + if existingVal.IsValid() { + newVal.Set(existingVal) + } else { + newVal.Set(reflect.Zero(vt)) + } + + // 3. 合并 performRecursiveTo(src.Field(i), newVal) + + // 4. 塞回 dst.SetMapIndex(newKey, newVal) } } @@ -380,11 +411,18 @@ func recursiveSliceToMap(src, dst reflect.Value) { kt := dst.Type().Key() vt := dst.Type().Elem() newKey := reflect.New(kt).Elem() - newVal := reflect.New(vt).Elem() for i := 0; i < src.Len(); i += 2 { newKey.Set(reflect.Zero(kt)) performRecursiveTo(src.Index(i), newKey) - newVal.Set(reflect.Zero(vt)) + + newVal := reflect.New(vt).Elem() + existingVal := dst.MapIndex(newKey) + if existingVal.IsValid() { + newVal.Set(existingVal) + } else { + newVal.Set(reflect.Zero(vt)) + } + if i+1 < src.Len() { performRecursiveTo(src.Index(i+1), newVal) } diff --git a/complex_conversion_test.go b/complex_conversion_test.go new file mode 100644 index 0000000..c79680d --- /dev/null +++ b/complex_conversion_test.go @@ -0,0 +1,102 @@ +package cast_test + +import ( + "testing" + "apigo.cc/go/cast" +) + +type SubConfig struct { + Level int + Tag string +} + +type MainConfig struct { + Name string + Sub SubConfig + Items []string + Options map[string]int +} + +func TestDeepMergeComplex(t *testing.T) { + dst := MainConfig{ + Name: "Base", + Sub: SubConfig{ + Level: 1, + Tag: "original", + }, + Items: []string{"a", "b"}, + Options: map[string]int{ + "debug": 1, + "trace": 0, + }, + } + + src := map[string]any{ + "Sub": map[string]any{ + "Level": 2, + }, + "Options": map[string]any{ + "trace": 1, + "new": 100, + }, + } + + cast.Convert(&dst, src) + + if dst.Name != "Base" { + t.Errorf("Expected Name Base, got %s", dst.Name) + } + if dst.Sub.Level != 2 { + t.Errorf("Expected Sub.Level 2, got %d", dst.Sub.Level) + } + if dst.Sub.Tag != "original" { + t.Errorf("Expected Sub.Tag original, got %s", dst.Sub.Tag) + } + if len(dst.Items) != 2 { + t.Errorf("Expected Items length 2, got %d", len(dst.Items)) + } + if dst.Options["debug"] != 1 { + t.Errorf("Expected Options.debug 1, got %d", dst.Options["debug"]) + } + if dst.Options["trace"] != 1 { + t.Errorf("Expected Options.trace 1, got %d", dst.Options["trace"]) + } + if dst.Options["new"] != 100 { + t.Errorf("Expected Options.new 100, got %d", dst.Options["new"]) + } +} + +func TestMapToMapMergeComplex(t *testing.T) { + dst := map[string]MainConfig{ + "c1": { + Name: "Config1", + Sub: SubConfig{Level: 10}, + }, + } + + src := map[string]any{ + "c1": map[string]any{ + "Sub": map[string]any{ + "Tag": "updated", + }, + }, + "c2": map[string]any{ + "Name": "Config2", + }, + } + + cast.Convert(&dst, src) + + if dst["c1"].Name != "Config1" { + t.Errorf("Expected c1.Name Config1, got %s", dst["c1"].Name) + } + if dst["c1"].Sub.Level != 10 { + t.Errorf("Expected c1.Sub.Level 10, got %d", dst["c1"].Sub.Level) + } + if dst["c1"].Sub.Tag != "updated" { + t.Errorf("Expected c1.Sub.Tag updated, got %s", dst["c1"].Sub.Tag) + } + if dst["c2"].Name != "Config2" { + t.Errorf("Expected c2.Name Config2, got %s", dst["c2"].Name) + } +} diff --git a/merge_test.go b/merge_test.go new file mode 100644 index 0000000..f726f27 --- /dev/null +++ b/merge_test.go @@ -0,0 +1,119 @@ +package cast_test + +import ( + "testing" + "apigo.cc/go/cast" +) + +type DBConfig struct { + Host string + Port int +} + +func TestMapStructMerge(t *testing.T) { + // Initial configuration with default values + dst := map[string]DBConfig{ + "mysql": {Host: "localhost", Port: 3306}, + } + + // New data (e.g., from environment variables) that only overrides Host + src := map[string]any{ + "mysql": map[string]any{ + "Host": "127.0.0.1", + }, + } + + // Perform conversion/merge + cast.Convert(&dst, src) + + // Verify results + mysql, ok := dst["mysql"] + if !ok { + t.Fatal("mysql config not found") + } + + if mysql.Host != "127.0.0.1" { + t.Errorf("Expected Host 127.0.0.1, got %s", mysql.Host) + } + + if mysql.Port != 3306 { + t.Errorf("Expected Port 3306, got %d", mysql.Port) + } +} + +func TestMapStructPointerMerge(t *testing.T) { + // Initial configuration with default values (using pointers) + dst := map[string]*DBConfig{ + "mysql": {Host: "localhost", Port: 3306}, + } + + // New data that only overrides Host + src := map[string]any{ + "mysql": map[string]any{ + "Host": "127.0.0.1", + }, + } + + // Perform conversion/merge + cast.Convert(&dst, src) + + // Verify results + mysql, ok := dst["mysql"] + if !ok { + t.Fatal("mysql config not found") + } + + if mysql.Host != "127.0.0.1" { + t.Errorf("Expected Host 127.0.0.1, got %s", mysql.Host) + } + + if mysql.Port != 3306 { + t.Errorf("Expected Port 3306, got %d", mysql.Port) + } +} + +func TestSliceToMapMerge(t *testing.T) { + dst := map[string]int{ + "a": 1, + "b": 2, + } + src := []any{"b", 20, "c", 30} + + cast.Convert(&dst, src) + + if dst["a"] != 1 { + t.Errorf("Expected a=1, got %d", dst["a"]) + } + if dst["b"] != 20 { + t.Errorf("Expected b=20, got %d", dst["b"]) + } + if dst["c"] != 30 { + t.Errorf("Expected c=30, got %d", dst["c"]) + } +} + +func TestStructToMapMerge(t *testing.T) { + type Config struct { + Host string + Port int + } + dst := map[string]any{ + "host": "localhost", + "port": 3306, + "user": "root", + } + // Note: Struct fields are always "present", so Port:0 will overwrite port:3306 + src := Config{Host: "127.0.0.1", Port: 8080} + + cast.Convert(&dst, src) + + if dst["host"] != "127.0.0.1" { + t.Errorf("Expected host 127.0.0.1, got %v", dst["host"]) + } + if cast.Int(dst["port"]) != 8080 { + t.Errorf("Expected port 8080, got %v", dst["port"]) + } + if dst["user"] != "root" { + t.Errorf("Expected user root, got %v", dst["user"]) + } +}