feat: implement global version control and optimistic locking
This commit is contained in:
parent
e7592b669e
commit
8d75cf7be5
76
Base.go
76
Base.go
@ -100,8 +100,22 @@ func quotes(quoteTag string, texts []string) string {
|
||||
return strings.Join(texts, ",")
|
||||
}
|
||||
|
||||
func makeInsertSql(quoteTag string, table string, data any, useReplace bool) (string, []any) {
|
||||
func makeInsertSql(quoteTag string, table string, data any, useReplace bool, versionField string, nextVer int64) (string, []any) {
|
||||
keys, vars, values := MakeKeysVarsValues(data)
|
||||
if versionField != "" {
|
||||
found := false
|
||||
for _, k := range keys {
|
||||
if k == versionField {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
keys = append(keys, versionField)
|
||||
vars = append(vars, "?")
|
||||
values = append(values, nextVer)
|
||||
}
|
||||
}
|
||||
operation := "insert"
|
||||
if useReplace {
|
||||
operation = "replace"
|
||||
@ -110,34 +124,76 @@ func makeInsertSql(quoteTag string, table string, data any, useReplace bool) (st
|
||||
return query, values
|
||||
}
|
||||
|
||||
func makeUpdateSql(quoteTag string, table string, data any, conditions string, args ...any) (string, []any) {
|
||||
func makeUpdateSql(quoteTag string, table string, data any, conditions string, versionField string, nextVer int64, args ...any) (string, []any) {
|
||||
args = flatArgs(args)
|
||||
keys, vars, values := MakeKeysVarsValues(data)
|
||||
newKeys := make([]string, 0, len(keys))
|
||||
newValues := make([]any, 0, len(values))
|
||||
var oldVersion any
|
||||
for i, k := range keys {
|
||||
keys[i] = fmt.Sprintf("%s=%s", quote(quoteTag, k), vars[i])
|
||||
if k == versionField {
|
||||
oldVersion = values[i]
|
||||
continue
|
||||
}
|
||||
values = append(values, args...)
|
||||
newKeys = append(newKeys, fmt.Sprintf("%s=%s", quote(quoteTag, k), vars[i]))
|
||||
newValues = append(newValues, values[i])
|
||||
}
|
||||
if versionField != "" {
|
||||
newKeys = append(newKeys, fmt.Sprintf("%s=?", quote(quoteTag, versionField)))
|
||||
newValues = append(newValues, nextVer)
|
||||
}
|
||||
|
||||
if oldVersion != nil {
|
||||
if conditions != "" {
|
||||
conditions = fmt.Sprintf("(%s) and %s=?", conditions, quote(quoteTag, versionField))
|
||||
} else {
|
||||
conditions = fmt.Sprintf("%s=?", quote(quoteTag, versionField))
|
||||
}
|
||||
args = append(args, oldVersion)
|
||||
}
|
||||
|
||||
newValues = append(newValues, args...)
|
||||
if conditions != "" {
|
||||
conditions = " where " + conditions
|
||||
}
|
||||
query := fmt.Sprintf("update %s set %s%s", quote(quoteTag, table), strings.Join(keys, ","), conditions)
|
||||
return query, values
|
||||
query := fmt.Sprintf("update %s set %s%s", quote(quoteTag, table), strings.Join(newKeys, ","), conditions)
|
||||
return query, newValues
|
||||
}
|
||||
|
||||
func (db *DB) MakeInsertSql(table string, data any, useReplace bool) (string, []any) {
|
||||
return makeInsertSql(db.QuoteTag, table, data, useReplace)
|
||||
ts := db.getTable(table)
|
||||
nextVer := int64(0)
|
||||
if ts.VersionField != "" {
|
||||
nextVer = db.NextVersion(table)
|
||||
}
|
||||
return makeInsertSql(db.QuoteTag, table, data, useReplace, ts.VersionField, nextVer)
|
||||
}
|
||||
|
||||
func (db *DB) MakeUpdateSql(table string, data any, conditions string, args ...any) (string, []any) {
|
||||
return makeUpdateSql(db.QuoteTag, table, data, conditions, args...)
|
||||
ts := db.getTable(table)
|
||||
nextVer := int64(0)
|
||||
if ts.VersionField != "" {
|
||||
nextVer = db.NextVersion(table)
|
||||
}
|
||||
return makeUpdateSql(db.QuoteTag, table, data, conditions, ts.VersionField, nextVer, args...)
|
||||
}
|
||||
|
||||
func (tx *Tx) MakeInsertSql(table string, data any, useReplace bool) (string, []any) {
|
||||
return makeInsertSql(tx.QuoteTag, table, data, useReplace)
|
||||
ts := tx.db.getTable(table)
|
||||
nextVer := int64(0)
|
||||
if ts.VersionField != "" {
|
||||
nextVer = tx.db.NextVersion(table)
|
||||
}
|
||||
return makeInsertSql(tx.QuoteTag, table, data, useReplace, ts.VersionField, nextVer)
|
||||
}
|
||||
|
||||
func (tx *Tx) MakeUpdateSql(table string, data any, conditions string, args ...any) (string, []any) {
|
||||
return makeUpdateSql(tx.QuoteTag, table, data, conditions, args...)
|
||||
ts := tx.db.getTable(table)
|
||||
nextVer := int64(0)
|
||||
if ts.VersionField != "" {
|
||||
nextVer = tx.db.NextVersion(table)
|
||||
}
|
||||
return makeUpdateSql(tx.QuoteTag, table, data, conditions, ts.VersionField, nextVer, args...)
|
||||
}
|
||||
|
||||
func getFlatFields(fields map[string]reflect.Value, fieldKeys *[]string, value reflect.Value) {
|
||||
|
||||
25
DB.go
25
DB.go
@ -11,6 +11,7 @@ import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"apigo.cc/go/cast"
|
||||
@ -19,6 +20,7 @@ import (
|
||||
"apigo.cc/go/id"
|
||||
"apigo.cc/go/log"
|
||||
"apigo.cc/go/rand"
|
||||
"apigo.cc/go/redis"
|
||||
"apigo.cc/go/safe"
|
||||
)
|
||||
|
||||
@ -37,6 +39,7 @@ type Config struct {
|
||||
MaxIdles int
|
||||
MaxLifeTime int
|
||||
LogSlow config.Duration
|
||||
VersionRedis string
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
@ -140,6 +143,7 @@ func (dbInfo *Config) ConfigureBy(setting string) {
|
||||
dbInfo.MaxLifeTime = cast.Int(q.Get("maxLifeTime"))
|
||||
dbInfo.MaxOpens = cast.Int(q.Get("maxOpens"))
|
||||
dbInfo.LogSlow = config.Duration(cast.Duration(q.Get("logSlow")))
|
||||
dbInfo.VersionRedis = q.Get("versionRedis")
|
||||
dbInfo.SSL = q.Get("tls")
|
||||
|
||||
sslCa := q.Get("sslCA")
|
||||
@ -171,7 +175,7 @@ func (dbInfo *Config) ConfigureBy(setting string) {
|
||||
|
||||
args := make([]string, 0)
|
||||
for k := range q {
|
||||
if k != "maxIdles" && k != "maxLifeTime" && k != "maxOpens" && k != "logSlow" && k != "tls" {
|
||||
if k != "maxIdles" && k != "maxLifeTime" && k != "maxOpens" && k != "logSlow" && k != "tls" && k != "versionRedis" {
|
||||
args = append(args, k+"="+q.Get(k))
|
||||
}
|
||||
}
|
||||
@ -229,8 +233,21 @@ var dbConfigsLock = sync.RWMutex{}
|
||||
var dbSSLs = make(map[string]*SSL)
|
||||
var dbInstances = make(map[string]*DB)
|
||||
var dbInstancesLock = sync.RWMutex{}
|
||||
var globalVersionMap = sync.Map{}
|
||||
var once sync.Once
|
||||
|
||||
func (db *DB) NextVersion(key string) int64 {
|
||||
if db.Config.VersionRedis != "" {
|
||||
r := redis.GetRedis(db.Config.VersionRedis, db.logger.logger)
|
||||
if r != nil {
|
||||
return r.INCR("db_ver_" + key)
|
||||
}
|
||||
}
|
||||
|
||||
v, _ := globalVersionMap.LoadOrStore(key, new(int64))
|
||||
return atomic.AddInt64(v.(*int64), 1)
|
||||
}
|
||||
|
||||
func GetDBWithoutCache(name string, logger *log.Logger) *DB {
|
||||
return getDB(name, logger, false)
|
||||
}
|
||||
@ -503,14 +520,14 @@ func (db *DB) Quotes(texts []string) string {
|
||||
|
||||
func (db *DB) Begin() *Tx {
|
||||
if db.conn == nil {
|
||||
return &Tx{QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: errors.New("operate on a bad connection"), logger: db.logger}
|
||||
return &Tx{db: db, QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: errors.New("operate on a bad connection"), logger: db.logger}
|
||||
}
|
||||
sqlTx, err := db.conn.Begin()
|
||||
if err != nil {
|
||||
db.logger.LogError(err.Error())
|
||||
return &Tx{QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: err, logger: db.logger}
|
||||
return &Tx{db: db, QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: err, logger: db.logger}
|
||||
}
|
||||
return &Tx{QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), conn: sqlTx, logger: db.logger}
|
||||
return &Tx{db: db, QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), conn: sqlTx, logger: db.logger}
|
||||
}
|
||||
|
||||
func (db *DB) Exec(query string, args ...any) *ExecResult {
|
||||
|
||||
1
Tx.go
1
Tx.go
@ -9,6 +9,7 @@ import (
|
||||
|
||||
type Tx struct {
|
||||
conn *sql.Tx
|
||||
db *DB
|
||||
lastSql *string
|
||||
lastArgs []any
|
||||
Error error
|
||||
|
||||
57
version_test.go
Normal file
57
version_test.go
Normal file
@ -0,0 +1,57 @@
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"apigo.cc/go/db"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func TestVersionControl(t *testing.T) {
|
||||
dbInst := db.GetDB("sqlite://:memory:", nil)
|
||||
|
||||
// Create table with autoVersion
|
||||
dbInst.Exec("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, autoVersion BIGINT UNSIGNED)")
|
||||
|
||||
t.Run("InsertAutoVersion", func(t *testing.T) {
|
||||
data := map[string]any{"id": 1, "name": "Alice"}
|
||||
res := dbInst.Insert("users", data)
|
||||
if res.Error != nil {
|
||||
t.Fatalf("Insert failed: %v", res.Error)
|
||||
}
|
||||
|
||||
// Verify version was injected
|
||||
var ver int64
|
||||
qr := dbInst.Query("SELECT autoVersion FROM users WHERE id = 1")
|
||||
ver, _ = db.ToValue[int64](qr)
|
||||
if ver != 1 {
|
||||
t.Errorf("Expected version 1, got %d", ver)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UpdateOptimisticLock", func(t *testing.T) {
|
||||
// First update
|
||||
data := map[string]any{"name": "Alice Updated", "autoVersion": int64(1)}
|
||||
res := dbInst.Update("users", data, "id = 1")
|
||||
if res.Error != nil {
|
||||
t.Fatalf("Update failed: %v", res.Error)
|
||||
}
|
||||
if res.Changes() != 1 {
|
||||
t.Errorf("Expected 1 change, got %d", res.Changes())
|
||||
}
|
||||
|
||||
// Verify version incremented
|
||||
var ver int64
|
||||
qr := dbInst.Query("SELECT autoVersion FROM users WHERE id = 1")
|
||||
ver, _ = db.ToValue[int64](qr)
|
||||
if ver != 2 {
|
||||
t.Errorf("Expected version 2, got %d", ver)
|
||||
}
|
||||
|
||||
// Try update with old version (should fail to update any rows)
|
||||
dataConflict := map[string]any{"name": "Conflict", "autoVersion": int64(1)}
|
||||
resConflict := dbInst.Update("users", dataConflict, "id = 1")
|
||||
if resConflict.Changes() != 0 {
|
||||
t.Errorf("Expected 0 changes due to optimistic lock, got %d", resConflict.Changes())
|
||||
}
|
||||
})
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user