From 8d75cf7be516b2c218f1360352d47e7f80b8787a Mon Sep 17 00:00:00 2001 From: AI Engineer Date: Sun, 3 May 2026 23:01:31 +0800 Subject: [PATCH] feat: implement global version control and optimistic locking --- Base.go | 76 ++++++++++++++++++++++++++++++++++++++++++------- DB.go | 25 +++++++++++++--- Tx.go | 1 + version_test.go | 57 +++++++++++++++++++++++++++++++++++++ 4 files changed, 145 insertions(+), 14 deletions(-) create mode 100644 version_test.go diff --git a/Base.go b/Base.go index 3fce087..5209c58 100644 --- a/Base.go +++ b/Base.go @@ -100,8 +100,22 @@ func quotes(quoteTag string, texts []string) string { return strings.Join(texts, ",") } -func makeInsertSql(quoteTag string, table string, data any, useReplace bool) (string, []any) { +func makeInsertSql(quoteTag string, table string, data any, useReplace bool, versionField string, nextVer int64) (string, []any) { keys, vars, values := MakeKeysVarsValues(data) + if versionField != "" { + found := false + for _, k := range keys { + if k == versionField { + found = true + break + } + } + if !found { + keys = append(keys, versionField) + vars = append(vars, "?") + values = append(values, nextVer) + } + } operation := "insert" if useReplace { operation = "replace" @@ -110,34 +124,76 @@ func makeInsertSql(quoteTag string, table string, data any, useReplace bool) (st return query, values } -func makeUpdateSql(quoteTag string, table string, data any, conditions string, args ...any) (string, []any) { +func makeUpdateSql(quoteTag string, table string, data any, conditions string, versionField string, nextVer int64, args ...any) (string, []any) { args = flatArgs(args) keys, vars, values := MakeKeysVarsValues(data) + newKeys := make([]string, 0, len(keys)) + newValues := make([]any, 0, len(values)) + var oldVersion any for i, k := range keys { - keys[i] = fmt.Sprintf("%s=%s", quote(quoteTag, k), vars[i]) + if k == versionField { + oldVersion = values[i] + continue + } + newKeys = append(newKeys, fmt.Sprintf("%s=%s", quote(quoteTag, k), vars[i])) + newValues = append(newValues, values[i]) } - values = append(values, args...) + if versionField != "" { + newKeys = append(newKeys, fmt.Sprintf("%s=?", quote(quoteTag, versionField))) + newValues = append(newValues, nextVer) + } + + if oldVersion != nil { + if conditions != "" { + conditions = fmt.Sprintf("(%s) and %s=?", conditions, quote(quoteTag, versionField)) + } else { + conditions = fmt.Sprintf("%s=?", quote(quoteTag, versionField)) + } + args = append(args, oldVersion) + } + + newValues = append(newValues, args...) if conditions != "" { conditions = " where " + conditions } - query := fmt.Sprintf("update %s set %s%s", quote(quoteTag, table), strings.Join(keys, ","), conditions) - return query, values + query := fmt.Sprintf("update %s set %s%s", quote(quoteTag, table), strings.Join(newKeys, ","), conditions) + return query, newValues } func (db *DB) MakeInsertSql(table string, data any, useReplace bool) (string, []any) { - return makeInsertSql(db.QuoteTag, table, data, useReplace) + ts := db.getTable(table) + nextVer := int64(0) + if ts.VersionField != "" { + nextVer = db.NextVersion(table) + } + return makeInsertSql(db.QuoteTag, table, data, useReplace, ts.VersionField, nextVer) } func (db *DB) MakeUpdateSql(table string, data any, conditions string, args ...any) (string, []any) { - return makeUpdateSql(db.QuoteTag, table, data, conditions, args...) + ts := db.getTable(table) + nextVer := int64(0) + if ts.VersionField != "" { + nextVer = db.NextVersion(table) + } + return makeUpdateSql(db.QuoteTag, table, data, conditions, ts.VersionField, nextVer, args...) } func (tx *Tx) MakeInsertSql(table string, data any, useReplace bool) (string, []any) { - return makeInsertSql(tx.QuoteTag, table, data, useReplace) + ts := tx.db.getTable(table) + nextVer := int64(0) + if ts.VersionField != "" { + nextVer = tx.db.NextVersion(table) + } + return makeInsertSql(tx.QuoteTag, table, data, useReplace, ts.VersionField, nextVer) } func (tx *Tx) MakeUpdateSql(table string, data any, conditions string, args ...any) (string, []any) { - return makeUpdateSql(tx.QuoteTag, table, data, conditions, args...) + ts := tx.db.getTable(table) + nextVer := int64(0) + if ts.VersionField != "" { + nextVer = tx.db.NextVersion(table) + } + return makeUpdateSql(tx.QuoteTag, table, data, conditions, ts.VersionField, nextVer, args...) } func getFlatFields(fields map[string]reflect.Value, fieldKeys *[]string, value reflect.Value) { diff --git a/DB.go b/DB.go index 47d6f1f..3dded27 100644 --- a/DB.go +++ b/DB.go @@ -11,6 +11,7 @@ import ( "regexp" "strings" "sync" + "sync/atomic" "time" "apigo.cc/go/cast" @@ -19,6 +20,7 @@ import ( "apigo.cc/go/id" "apigo.cc/go/log" "apigo.cc/go/rand" + "apigo.cc/go/redis" "apigo.cc/go/safe" ) @@ -37,6 +39,7 @@ type Config struct { MaxIdles int MaxLifeTime int LogSlow config.Duration + VersionRedis string logger *log.Logger } @@ -140,6 +143,7 @@ func (dbInfo *Config) ConfigureBy(setting string) { dbInfo.MaxLifeTime = cast.Int(q.Get("maxLifeTime")) dbInfo.MaxOpens = cast.Int(q.Get("maxOpens")) dbInfo.LogSlow = config.Duration(cast.Duration(q.Get("logSlow"))) + dbInfo.VersionRedis = q.Get("versionRedis") dbInfo.SSL = q.Get("tls") sslCa := q.Get("sslCA") @@ -171,7 +175,7 @@ func (dbInfo *Config) ConfigureBy(setting string) { args := make([]string, 0) for k := range q { - if k != "maxIdles" && k != "maxLifeTime" && k != "maxOpens" && k != "logSlow" && k != "tls" { + if k != "maxIdles" && k != "maxLifeTime" && k != "maxOpens" && k != "logSlow" && k != "tls" && k != "versionRedis" { args = append(args, k+"="+q.Get(k)) } } @@ -229,8 +233,21 @@ var dbConfigsLock = sync.RWMutex{} var dbSSLs = make(map[string]*SSL) var dbInstances = make(map[string]*DB) var dbInstancesLock = sync.RWMutex{} +var globalVersionMap = sync.Map{} var once sync.Once +func (db *DB) NextVersion(key string) int64 { + if db.Config.VersionRedis != "" { + r := redis.GetRedis(db.Config.VersionRedis, db.logger.logger) + if r != nil { + return r.INCR("db_ver_" + key) + } + } + + v, _ := globalVersionMap.LoadOrStore(key, new(int64)) + return atomic.AddInt64(v.(*int64), 1) +} + func GetDBWithoutCache(name string, logger *log.Logger) *DB { return getDB(name, logger, false) } @@ -503,14 +520,14 @@ func (db *DB) Quotes(texts []string) string { func (db *DB) Begin() *Tx { if db.conn == nil { - return &Tx{QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: errors.New("operate on a bad connection"), logger: db.logger} + return &Tx{db: db, QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: errors.New("operate on a bad connection"), logger: db.logger} } sqlTx, err := db.conn.Begin() if err != nil { db.logger.LogError(err.Error()) - return &Tx{QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: err, logger: db.logger} + return &Tx{db: db, QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: err, logger: db.logger} } - return &Tx{QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), conn: sqlTx, logger: db.logger} + return &Tx{db: db, QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), conn: sqlTx, logger: db.logger} } func (db *DB) Exec(query string, args ...any) *ExecResult { diff --git a/Tx.go b/Tx.go index 9fe2da7..decfe00 100644 --- a/Tx.go +++ b/Tx.go @@ -9,6 +9,7 @@ import ( type Tx struct { conn *sql.Tx + db *DB lastSql *string lastArgs []any Error error diff --git a/version_test.go b/version_test.go new file mode 100644 index 0000000..9b11224 --- /dev/null +++ b/version_test.go @@ -0,0 +1,57 @@ +package db_test + +import ( + "testing" + "apigo.cc/go/db" + _ "modernc.org/sqlite" +) + +func TestVersionControl(t *testing.T) { + dbInst := db.GetDB("sqlite://:memory:", nil) + + // Create table with autoVersion + dbInst.Exec("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, autoVersion BIGINT UNSIGNED)") + + t.Run("InsertAutoVersion", func(t *testing.T) { + data := map[string]any{"id": 1, "name": "Alice"} + res := dbInst.Insert("users", data) + if res.Error != nil { + t.Fatalf("Insert failed: %v", res.Error) + } + + // Verify version was injected + var ver int64 + qr := dbInst.Query("SELECT autoVersion FROM users WHERE id = 1") + ver, _ = db.ToValue[int64](qr) + if ver != 1 { + t.Errorf("Expected version 1, got %d", ver) + } + }) + + t.Run("UpdateOptimisticLock", func(t *testing.T) { + // First update + data := map[string]any{"name": "Alice Updated", "autoVersion": int64(1)} + res := dbInst.Update("users", data, "id = 1") + if res.Error != nil { + t.Fatalf("Update failed: %v", res.Error) + } + if res.Changes() != 1 { + t.Errorf("Expected 1 change, got %d", res.Changes()) + } + + // Verify version incremented + var ver int64 + qr := dbInst.Query("SELECT autoVersion FROM users WHERE id = 1") + ver, _ = db.ToValue[int64](qr) + if ver != 2 { + t.Errorf("Expected version 2, got %d", ver) + } + + // Try update with old version (should fail to update any rows) + dataConflict := map[string]any{"name": "Conflict", "autoVersion": int64(1)} + resConflict := dbInst.Update("users", dataConflict, "id = 1") + if resConflict.Changes() != 0 { + t.Errorf("Expected 0 changes due to optimistic lock, got %d", resConflict.Changes()) + } + }) +}