feat: implement global version control and optimistic locking

This commit is contained in:
AI Engineer 2026-05-03 23:01:31 +08:00
parent e7592b669e
commit 8d75cf7be5
4 changed files with 145 additions and 14 deletions

76
Base.go
View File

@ -100,8 +100,22 @@ func quotes(quoteTag string, texts []string) string {
return strings.Join(texts, ",") return strings.Join(texts, ",")
} }
func makeInsertSql(quoteTag string, table string, data any, useReplace bool) (string, []any) { func makeInsertSql(quoteTag string, table string, data any, useReplace bool, versionField string, nextVer int64) (string, []any) {
keys, vars, values := MakeKeysVarsValues(data) keys, vars, values := MakeKeysVarsValues(data)
if versionField != "" {
found := false
for _, k := range keys {
if k == versionField {
found = true
break
}
}
if !found {
keys = append(keys, versionField)
vars = append(vars, "?")
values = append(values, nextVer)
}
}
operation := "insert" operation := "insert"
if useReplace { if useReplace {
operation = "replace" operation = "replace"
@ -110,34 +124,76 @@ func makeInsertSql(quoteTag string, table string, data any, useReplace bool) (st
return query, values return query, values
} }
func makeUpdateSql(quoteTag string, table string, data any, conditions string, args ...any) (string, []any) { func makeUpdateSql(quoteTag string, table string, data any, conditions string, versionField string, nextVer int64, args ...any) (string, []any) {
args = flatArgs(args) args = flatArgs(args)
keys, vars, values := MakeKeysVarsValues(data) keys, vars, values := MakeKeysVarsValues(data)
newKeys := make([]string, 0, len(keys))
newValues := make([]any, 0, len(values))
var oldVersion any
for i, k := range keys { for i, k := range keys {
keys[i] = fmt.Sprintf("%s=%s", quote(quoteTag, k), vars[i]) if k == versionField {
oldVersion = values[i]
continue
}
newKeys = append(newKeys, fmt.Sprintf("%s=%s", quote(quoteTag, k), vars[i]))
newValues = append(newValues, values[i])
} }
values = append(values, args...) if versionField != "" {
newKeys = append(newKeys, fmt.Sprintf("%s=?", quote(quoteTag, versionField)))
newValues = append(newValues, nextVer)
}
if oldVersion != nil {
if conditions != "" {
conditions = fmt.Sprintf("(%s) and %s=?", conditions, quote(quoteTag, versionField))
} else {
conditions = fmt.Sprintf("%s=?", quote(quoteTag, versionField))
}
args = append(args, oldVersion)
}
newValues = append(newValues, args...)
if conditions != "" { if conditions != "" {
conditions = " where " + conditions conditions = " where " + conditions
} }
query := fmt.Sprintf("update %s set %s%s", quote(quoteTag, table), strings.Join(keys, ","), conditions) query := fmt.Sprintf("update %s set %s%s", quote(quoteTag, table), strings.Join(newKeys, ","), conditions)
return query, values return query, newValues
} }
func (db *DB) MakeInsertSql(table string, data any, useReplace bool) (string, []any) { func (db *DB) MakeInsertSql(table string, data any, useReplace bool) (string, []any) {
return makeInsertSql(db.QuoteTag, table, data, useReplace) ts := db.getTable(table)
nextVer := int64(0)
if ts.VersionField != "" {
nextVer = db.NextVersion(table)
}
return makeInsertSql(db.QuoteTag, table, data, useReplace, ts.VersionField, nextVer)
} }
func (db *DB) MakeUpdateSql(table string, data any, conditions string, args ...any) (string, []any) { func (db *DB) MakeUpdateSql(table string, data any, conditions string, args ...any) (string, []any) {
return makeUpdateSql(db.QuoteTag, table, data, conditions, args...) ts := db.getTable(table)
nextVer := int64(0)
if ts.VersionField != "" {
nextVer = db.NextVersion(table)
}
return makeUpdateSql(db.QuoteTag, table, data, conditions, ts.VersionField, nextVer, args...)
} }
func (tx *Tx) MakeInsertSql(table string, data any, useReplace bool) (string, []any) { func (tx *Tx) MakeInsertSql(table string, data any, useReplace bool) (string, []any) {
return makeInsertSql(tx.QuoteTag, table, data, useReplace) ts := tx.db.getTable(table)
nextVer := int64(0)
if ts.VersionField != "" {
nextVer = tx.db.NextVersion(table)
}
return makeInsertSql(tx.QuoteTag, table, data, useReplace, ts.VersionField, nextVer)
} }
func (tx *Tx) MakeUpdateSql(table string, data any, conditions string, args ...any) (string, []any) { func (tx *Tx) MakeUpdateSql(table string, data any, conditions string, args ...any) (string, []any) {
return makeUpdateSql(tx.QuoteTag, table, data, conditions, args...) ts := tx.db.getTable(table)
nextVer := int64(0)
if ts.VersionField != "" {
nextVer = tx.db.NextVersion(table)
}
return makeUpdateSql(tx.QuoteTag, table, data, conditions, ts.VersionField, nextVer, args...)
} }
func getFlatFields(fields map[string]reflect.Value, fieldKeys *[]string, value reflect.Value) { func getFlatFields(fields map[string]reflect.Value, fieldKeys *[]string, value reflect.Value) {

25
DB.go
View File

@ -11,6 +11,7 @@ import (
"regexp" "regexp"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"apigo.cc/go/cast" "apigo.cc/go/cast"
@ -19,6 +20,7 @@ import (
"apigo.cc/go/id" "apigo.cc/go/id"
"apigo.cc/go/log" "apigo.cc/go/log"
"apigo.cc/go/rand" "apigo.cc/go/rand"
"apigo.cc/go/redis"
"apigo.cc/go/safe" "apigo.cc/go/safe"
) )
@ -37,6 +39,7 @@ type Config struct {
MaxIdles int MaxIdles int
MaxLifeTime int MaxLifeTime int
LogSlow config.Duration LogSlow config.Duration
VersionRedis string
logger *log.Logger logger *log.Logger
} }
@ -140,6 +143,7 @@ func (dbInfo *Config) ConfigureBy(setting string) {
dbInfo.MaxLifeTime = cast.Int(q.Get("maxLifeTime")) dbInfo.MaxLifeTime = cast.Int(q.Get("maxLifeTime"))
dbInfo.MaxOpens = cast.Int(q.Get("maxOpens")) dbInfo.MaxOpens = cast.Int(q.Get("maxOpens"))
dbInfo.LogSlow = config.Duration(cast.Duration(q.Get("logSlow"))) dbInfo.LogSlow = config.Duration(cast.Duration(q.Get("logSlow")))
dbInfo.VersionRedis = q.Get("versionRedis")
dbInfo.SSL = q.Get("tls") dbInfo.SSL = q.Get("tls")
sslCa := q.Get("sslCA") sslCa := q.Get("sslCA")
@ -171,7 +175,7 @@ func (dbInfo *Config) ConfigureBy(setting string) {
args := make([]string, 0) args := make([]string, 0)
for k := range q { for k := range q {
if k != "maxIdles" && k != "maxLifeTime" && k != "maxOpens" && k != "logSlow" && k != "tls" { if k != "maxIdles" && k != "maxLifeTime" && k != "maxOpens" && k != "logSlow" && k != "tls" && k != "versionRedis" {
args = append(args, k+"="+q.Get(k)) args = append(args, k+"="+q.Get(k))
} }
} }
@ -229,8 +233,21 @@ var dbConfigsLock = sync.RWMutex{}
var dbSSLs = make(map[string]*SSL) var dbSSLs = make(map[string]*SSL)
var dbInstances = make(map[string]*DB) var dbInstances = make(map[string]*DB)
var dbInstancesLock = sync.RWMutex{} var dbInstancesLock = sync.RWMutex{}
var globalVersionMap = sync.Map{}
var once sync.Once var once sync.Once
func (db *DB) NextVersion(key string) int64 {
if db.Config.VersionRedis != "" {
r := redis.GetRedis(db.Config.VersionRedis, db.logger.logger)
if r != nil {
return r.INCR("db_ver_" + key)
}
}
v, _ := globalVersionMap.LoadOrStore(key, new(int64))
return atomic.AddInt64(v.(*int64), 1)
}
func GetDBWithoutCache(name string, logger *log.Logger) *DB { func GetDBWithoutCache(name string, logger *log.Logger) *DB {
return getDB(name, logger, false) return getDB(name, logger, false)
} }
@ -503,14 +520,14 @@ func (db *DB) Quotes(texts []string) string {
func (db *DB) Begin() *Tx { func (db *DB) Begin() *Tx {
if db.conn == nil { if db.conn == nil {
return &Tx{QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: errors.New("operate on a bad connection"), logger: db.logger} return &Tx{db: db, QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: errors.New("operate on a bad connection"), logger: db.logger}
} }
sqlTx, err := db.conn.Begin() sqlTx, err := db.conn.Begin()
if err != nil { if err != nil {
db.logger.LogError(err.Error()) db.logger.LogError(err.Error())
return &Tx{QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: err, logger: db.logger} return &Tx{db: db, QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: err, logger: db.logger}
} }
return &Tx{QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), conn: sqlTx, logger: db.logger} return &Tx{db: db, QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), conn: sqlTx, logger: db.logger}
} }
func (db *DB) Exec(query string, args ...any) *ExecResult { func (db *DB) Exec(query string, args ...any) *ExecResult {

1
Tx.go
View File

@ -9,6 +9,7 @@ import (
type Tx struct { type Tx struct {
conn *sql.Tx conn *sql.Tx
db *DB
lastSql *string lastSql *string
lastArgs []any lastArgs []any
Error error Error error

57
version_test.go Normal file
View File

@ -0,0 +1,57 @@
package db_test
import (
"testing"
"apigo.cc/go/db"
_ "modernc.org/sqlite"
)
func TestVersionControl(t *testing.T) {
dbInst := db.GetDB("sqlite://:memory:", nil)
// Create table with autoVersion
dbInst.Exec("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, autoVersion BIGINT UNSIGNED)")
t.Run("InsertAutoVersion", func(t *testing.T) {
data := map[string]any{"id": 1, "name": "Alice"}
res := dbInst.Insert("users", data)
if res.Error != nil {
t.Fatalf("Insert failed: %v", res.Error)
}
// Verify version was injected
var ver int64
qr := dbInst.Query("SELECT autoVersion FROM users WHERE id = 1")
ver, _ = db.ToValue[int64](qr)
if ver != 1 {
t.Errorf("Expected version 1, got %d", ver)
}
})
t.Run("UpdateOptimisticLock", func(t *testing.T) {
// First update
data := map[string]any{"name": "Alice Updated", "autoVersion": int64(1)}
res := dbInst.Update("users", data, "id = 1")
if res.Error != nil {
t.Fatalf("Update failed: %v", res.Error)
}
if res.Changes() != 1 {
t.Errorf("Expected 1 change, got %d", res.Changes())
}
// Verify version incremented
var ver int64
qr := dbInst.Query("SELECT autoVersion FROM users WHERE id = 1")
ver, _ = db.ToValue[int64](qr)
if ver != 2 {
t.Errorf("Expected version 2, got %d", ver)
}
// Try update with old version (should fail to update any rows)
dataConflict := map[string]any{"name": "Conflict", "autoVersion": int64(1)}
resConflict := dbInst.Update("users", dataConflict, "id = 1")
if resConflict.Changes() != 0 {
t.Errorf("Expected 0 changes due to optimistic lock, got %d", resConflict.Changes())
}
})
}