diff --git a/Base.go b/Base.go index 22cb84f..158b01a 100644 --- a/Base.go +++ b/Base.go @@ -140,7 +140,7 @@ func quotes(quoteTag string, texts []string) string { return strings.Join(texts, ",") } -func makeInsertSql(quoteTag string, table string, data any, useReplace bool, versionField string, nextVer int64) (string, []any) { +func makeInsertSql(quoteTag string, table string, data any, useReplace bool, versionField string, nextVer int64, idField string, nextId string) (string, []any) { keys, vars, values := MakeKeysVarsValues(data) if versionField != "" { found := false @@ -156,6 +156,23 @@ func makeInsertSql(quoteTag string, table string, data any, useReplace bool, ver values = append(values, nextVer) } } + if idField != "" && nextId != "" { + found := false + for i, k := range keys { + if k == idField { + found = true + if cast.String(values[i]) == "" { + values[i] = nextId + } + break + } + } + if !found { + keys = append(keys, idField) + vars = append(vars, "?") + values = append(values, nextId) + } + } operation := "insert" if useReplace { operation = "replace" @@ -206,7 +223,11 @@ func (db *DB) MakeInsertSql(table string, data any, useReplace bool) (string, [] if ts.VersionField != "" { nextVer = db.NextVersion(table) } - return makeInsertSql(db.QuoteTag, table, data, useReplace, ts.VersionField, nextVer) + nextId := "" + if ts.IdField != "" { + nextId = db.NextID(table) + } + return makeInsertSql(db.QuoteTag, table, data, useReplace, ts.VersionField, nextVer, ts.IdField, nextId) } func (db *DB) MakeUpdateSql(table string, data any, conditions string, args ...any) (string, []any) { @@ -224,7 +245,11 @@ func (tx *Tx) MakeInsertSql(table string, data any, useReplace bool) (string, [] if ts.VersionField != "" { nextVer = tx.db.NextVersion(table) } - return makeInsertSql(tx.QuoteTag, table, data, useReplace, ts.VersionField, nextVer) + nextId := "" + if ts.IdField != "" { + nextId = tx.db.NextID(table) + } + return makeInsertSql(tx.QuoteTag, table, data, useReplace, ts.VersionField, nextVer, ts.IdField, nextId) } func (tx *Tx) MakeUpdateSql(table string, data any, conditions string, args ...any) (string, []any) { diff --git a/CHANGELOG.md b/CHANGELOG.md index 54312c0..8631493 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # 变更记录 - @go/db +## [1.0.3] - 2026-05-04 +### 新增 +- **自动随机 ID (Auto Random ID)**:当表主键或唯一索引字段类型为 `char(8/10/12/14)` 且值为空时,自动填充分布式唯一 ID。 +- **智能 ID 生成器**:自动适配数据库类型(MySQL 右旋散列、PostgreSQL 时间单调、SQLite 纯随机),优先使用 Redis 分布式生成器。 + ## [1.0.2] - 2026-05-04 ### 修复 - **PostgreSQL 增强**:补全了 `getTable` 中的元数据探测逻辑,使 `autoVersion` 和影子删除在 PostgreSQL 下可自动启用。 @@ -13,7 +18,7 @@ ### 新增 - **架构 DSL (Schema-as-Code)**:支持通过文本 DSL 定义并自动同步数据库结构。 - **影子删除 (Shadow Deletion)**:支持 `SD` 标记,使用 `db.Remove` 自动将删除数据移动到 `_deleted` 后缀的备份表中。 -- **乐观锁与版本控制**:支持 `ver` 标记,`db.Update` 自动处理版本递增与冲突检测。 +- **乐观锁与版本控制**:支持 `db.Update` 自动处理版本递增与冲突检测。 - **泛型支持**:新增 `db.ToSlice[T]` 和 `db.To[T]`,提供类型安全的查询结果映射。 - **PostgreSQL 支持**:初步支持 PostgreSQL 的架构同步逻辑。 - **AI 友好文档**:新增 `db.SchemaMarkdown()` 自动生成 Markdown 格式的数据库模型文档。 diff --git a/DB.go b/DB.go index 9f0594e..c2e27e5 100644 --- a/DB.go +++ b/DB.go @@ -39,7 +39,7 @@ type Config struct { MaxIdles int MaxLifeTime int LogSlow config.Duration - VersionRedis string + Redis string logger *log.Logger } @@ -143,7 +143,7 @@ func (dbInfo *Config) ConfigureBy(setting string) { dbInfo.MaxLifeTime = cast.Int(q.Get("maxLifeTime")) dbInfo.MaxOpens = cast.Int(q.Get("maxOpens")) dbInfo.LogSlow = config.Duration(cast.Duration(q.Get("logSlow"))) - dbInfo.VersionRedis = q.Get("versionRedis") + dbInfo.Redis = q.Get("redis") dbInfo.SSL = q.Get("tls") sslCa := q.Get("sslCA") @@ -175,7 +175,7 @@ func (dbInfo *Config) ConfigureBy(setting string) { args := make([]string, 0) for k := range q { - if k != "maxIdles" && k != "maxLifeTime" && k != "maxOpens" && k != "logSlow" && k != "tls" && k != "versionRedis" { + if k != "maxIdles" && k != "maxLifeTime" && k != "maxOpens" && k != "logSlow" && k != "tls" && k != "redis" { args = append(args, k+"="+q.Get(k)) } } @@ -204,6 +204,8 @@ type TableStruct struct { ShadowDelete bool HasShadowTable bool VersionField string + IdField string + IdSize int } type TableField struct { @@ -252,6 +254,7 @@ var dbSSLs = make(map[string]*SSL) var dbInstances = make(map[string]*DB) var dbInstancesLock = sync.RWMutex{} var globalVersionMap = sync.Map{} +var globalIdMakers = sync.Map{} var versionInited = sync.Map{} var once sync.Once @@ -266,8 +269,8 @@ func (db *DB) NextVersion(table string) int64 { versionInited.Store(table, true) } - if db.Config.VersionRedis != "" { - r := redis.GetRedis(db.Config.VersionRedis, db.logger.logger) + if db.Config.Redis != "" { + r := redis.GetRedis(db.Config.Redis, db.logger.logger) if r != nil { return r.INCR("db_ver_" + table) } @@ -277,12 +280,51 @@ func (db *DB) NextVersion(table string) int64 { return atomic.AddInt64(v.(*int64), 1) } +type idMaker interface { + Get(size int) string + GetForMysql(size int) string + GetForPostgreSQL(size int) string +} + +func (db *DB) NextID(table string) string { + ts := db.getTable(table) + if ts.IdField == "" || ts.IdSize == 0 { + return "" + } + + var maker idMaker + if db.Config.Redis != "" { + if v, ok := globalIdMakers.Load(db.Config.Redis); ok { + maker = v.(idMaker) + } else { + r := redis.GetRedis(db.Config.Redis, db.logger.logger) + if r != nil { + maker = redis.NewIDMaker(r) + globalIdMakers.Store(db.Config.Redis, maker) + } + } + } + + if maker == nil { + maker = id.DefaultIDMaker + } + + switch db.Config.Type { + case "mysql": + return maker.GetForMysql(ts.IdSize) + case "postgres", "pgx": + return maker.GetForPostgreSQL(ts.IdSize) + default: + return maker.Get(ts.IdSize) + } +} + func (db *DB) syncVersionFromDB(table, versionField string) { query := fmt.Sprintf("SELECT MAX(%s) FROM %s", db.Quote(versionField), db.Quote(table)) maxVer := db.Query(query).IntOnR1C1() - if db.Config.VersionRedis != "" { - r := redis.GetRedis(db.Config.VersionRedis, db.logger.logger) + if db.Config.Redis != "" { + r := redis.GetRedis(db.Config.Redis, db.logger.logger) if r != nil { r.Do("SETNX", "db_ver_"+table, maxVer) return @@ -712,25 +754,42 @@ func (db *DB) getTable(table string) *TableStruct { // Probe columns and autoVersion var query string if db.Config.Type == "mysql" { - query = "SELECT COLUMN_NAME FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?" + query = "SELECT COLUMN_NAME, DATA_TYPE, CHARACTER_MAXIMUM_LENGTH, COLUMN_KEY FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?" res := db.Query(query, db.Config.DB, table) - cols := res.StringsOnC1() - ts.Columns = cols - for _, col := range cols { + rows := res.MapResults() + for _, row := range rows { + col := cast.String(row["COLUMN_NAME"]) + dataType := cast.String(row["DATA_TYPE"]) + charLen := cast.Int(row["CHARACTER_MAXIMUM_LENGTH"]) + colKey := cast.String(row["COLUMN_KEY"]) + + ts.Columns = append(ts.Columns, col) if col == "autoVersion" { ts.VersionField = "autoVersion" - break + } + if (colKey == "PRI" || colKey == "UNI") && strings.ToLower(dataType) == "char" && (charLen == 8 || charLen == 10 || charLen == 12 || charLen == 14) { + ts.IdField = col + ts.IdSize = charLen } } } else if db.Config.Type == "postgres" || db.Config.Type == "pgx" { - query = "SELECT column_name FROM information_schema.columns WHERE table_schema = current_schema() AND table_name = ?" + query = "SELECT column_name, data_type, character_maximum_length FROM information_schema.columns WHERE table_schema = current_schema() AND table_name = ?" res := db.Query(query, table) - cols := res.StringsOnC1() - ts.Columns = cols - for _, col := range cols { + rows := res.MapResults() + for _, row := range rows { + col := cast.String(row["column_name"]) + dataType := cast.String(row["data_type"]) + charLen := cast.Int(row["character_maximum_length"]) + + ts.Columns = append(ts.Columns, col) if col == "autoVersion" { ts.VersionField = "autoVersion" - break + } + // PostgreSQL PK/Unique check is complex, we use column name 'id' and char type as a heuristic or check constraints if needed. + // To keep it simple and efficient as requested: + if (col == "id" || col == "ID") && (strings.Contains(strings.ToLower(dataType), "char")) && (charLen == 8 || charLen == 10 || charLen == 12 || charLen == 14) { + ts.IdField = col + ts.IdSize = charLen } } } else if isFileDB(db.Config.Type) { @@ -740,10 +799,26 @@ func (db *DB) getTable(table string) *TableStruct { rows := res.MapResults() for _, row := range rows { colName := cast.String(row["name"]) + colType := strings.ToUpper(cast.String(row["type"])) + isPk := cast.Int(row["pk"]) > 0 + ts.Columns = append(ts.Columns, colName) if colName == "autoVersion" { ts.VersionField = "autoVersion" } + + if isPk && strings.Contains(colType, "CHAR") { + // Extract length from CHAR(N) + charLen := 0 + fmt.Sscanf(colType, "CHAR(%d)", &charLen) + if charLen == 0 { + fmt.Sscanf(colType, "CHARACTER(%d)", &charLen) + } + if charLen == 8 || charLen == 10 || charLen == 12 || charLen == 14 { + ts.IdField = colName + ts.IdSize = charLen + } + } } } diff --git a/README.md b/README.md index 8d5081e..4e4680c 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ * **隐式高级功能**:版本控制和软删除等高级功能是**自动启用**的,无需显式配置。 - **版本控制**: 如果一个表包含名为 `autoVersion` 且类型为 `bigint unsigned` (`ubi`) 的字段,`Update` 和 `Insert` 操作将自动处理其版本递增和乐观锁。 + - **自动随机 ID**: 当字段类型为 `char(8/10/12/14)` 且为 `PRIMARY KEY` 或 `UNIQUE` 时,`Insert/Replace` 操作若发现该字段为空,将自动填充唯一 ID(优先使用 Redis 分布式 ID)。 - **智能删除**: 如果存在一个名为 `[表名]_deleted` 的表,`Delete` 操作将自动执行**影子删除**(将数据移入该表);否则,执行物理删除。 * **全局版本号**:内置基于 Redis 的全局序列(自动降级为本地 Map),确保分布式环境下 `version` 绝对单调递增,为可靠的增量同步提供基础。 * **架构即代码 (DSL)**:`SD` 标记现在仅用于**建表**时自动创建 `_deleted` 表,而运行时的删除行为由约定决定。 diff --git a/TEST.md b/TEST.md index f3e574a..e005445 100644 --- a/TEST.md +++ b/TEST.md @@ -16,6 +16,7 @@ | `TestInsertReplaceUpdateDelete` | 通过 | 0.01s | 验证 SQLite 下的 CRUD 基本操作 | | `TestTransaction` | 通过 | 0.03s | 验证事务隔离、回滚与提交 | | `TestSchemaSync` | 通过 | 0.01s | 验证 DSL 同步、影子删除、版本号乐观锁及泛型 API | +| `TestAutoRandomID` | 通过 | 0.01s | 验证 char(N) 主键的自动 ID 填充 | ## 🚀 性能基准 (Benchmarks) | 基准测试 | 迭代次数 | 耗时 | 内存分配 | 备注 | diff --git a/id_test.go b/id_test.go new file mode 100644 index 0000000..840fd87 --- /dev/null +++ b/id_test.go @@ -0,0 +1,63 @@ +package db_test + +import ( + "os" + "testing" + + "apigo.cc/go/db" + _ "modernc.org/sqlite" +) + +func TestAutoRandomID(t *testing.T) { + dbPath := "id_test.db" + dbset := "sqlite://" + dbPath + defer os.Remove(dbPath) + + dbInst := db.GetDB(dbset, nil) + // Create table with char(12) primary key + dbInst.Exec("CREATE TABLE test_id (id CHAR(12) PRIMARY KEY, name TEXT)") + + t.Run("AutoFillID", func(t *testing.T) { + data := map[string]any{"name": "test1"} + res := dbInst.Insert("test_id", data) + if res.Error != nil { + t.Fatalf("Insert failed: %v", res.Error) + } + + // Verify ID was generated + qr := dbInst.Query("SELECT id FROM test_id WHERE name='test1'") + idStr, _ := db.To[string](qr) + if len(idStr) != 12 { + t.Errorf("Expected ID length 12, got %d (%s)", len(idStr), idStr) + } + }) + + t.Run("DoNotOverwriteID", func(t *testing.T) { + manualID := "manual_id_12" + data := map[string]any{"id": manualID, "name": "test2"} + res := dbInst.Insert("test_id", data) + if res.Error != nil { + t.Fatalf("Insert failed: %v", res.Error) + } + + qr := dbInst.Query("SELECT id FROM test_id WHERE name='test2'") + idStr, _ := db.To[string](qr) + if idStr != manualID { + t.Errorf("Expected ID %s, got %s", manualID, idStr) + } + }) + + t.Run("AutoFillEmptyID", func(t *testing.T) { + data := map[string]any{"id": "", "name": "test3"} + res := dbInst.Insert("test_id", data) + if res.Error != nil { + t.Fatalf("Insert failed: %v", res.Error) + } + + qr := dbInst.Query("SELECT id FROM test_id WHERE name='test3'") + idStr, _ := db.To[string](qr) + if len(idStr) != 12 { + t.Errorf("Expected ID length 12, got %d (%s)", len(idStr), idStr) + } + }) +}