feat: implement global version control and optimistic locking
This commit is contained in:
parent
e7592b669e
commit
8d75cf7be5
76
Base.go
76
Base.go
@ -100,8 +100,22 @@ func quotes(quoteTag string, texts []string) string {
|
|||||||
return strings.Join(texts, ",")
|
return strings.Join(texts, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeInsertSql(quoteTag string, table string, data any, useReplace bool) (string, []any) {
|
func makeInsertSql(quoteTag string, table string, data any, useReplace bool, versionField string, nextVer int64) (string, []any) {
|
||||||
keys, vars, values := MakeKeysVarsValues(data)
|
keys, vars, values := MakeKeysVarsValues(data)
|
||||||
|
if versionField != "" {
|
||||||
|
found := false
|
||||||
|
for _, k := range keys {
|
||||||
|
if k == versionField {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
keys = append(keys, versionField)
|
||||||
|
vars = append(vars, "?")
|
||||||
|
values = append(values, nextVer)
|
||||||
|
}
|
||||||
|
}
|
||||||
operation := "insert"
|
operation := "insert"
|
||||||
if useReplace {
|
if useReplace {
|
||||||
operation = "replace"
|
operation = "replace"
|
||||||
@ -110,34 +124,76 @@ func makeInsertSql(quoteTag string, table string, data any, useReplace bool) (st
|
|||||||
return query, values
|
return query, values
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeUpdateSql(quoteTag string, table string, data any, conditions string, args ...any) (string, []any) {
|
func makeUpdateSql(quoteTag string, table string, data any, conditions string, versionField string, nextVer int64, args ...any) (string, []any) {
|
||||||
args = flatArgs(args)
|
args = flatArgs(args)
|
||||||
keys, vars, values := MakeKeysVarsValues(data)
|
keys, vars, values := MakeKeysVarsValues(data)
|
||||||
|
newKeys := make([]string, 0, len(keys))
|
||||||
|
newValues := make([]any, 0, len(values))
|
||||||
|
var oldVersion any
|
||||||
for i, k := range keys {
|
for i, k := range keys {
|
||||||
keys[i] = fmt.Sprintf("%s=%s", quote(quoteTag, k), vars[i])
|
if k == versionField {
|
||||||
|
oldVersion = values[i]
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
values = append(values, args...)
|
newKeys = append(newKeys, fmt.Sprintf("%s=%s", quote(quoteTag, k), vars[i]))
|
||||||
|
newValues = append(newValues, values[i])
|
||||||
|
}
|
||||||
|
if versionField != "" {
|
||||||
|
newKeys = append(newKeys, fmt.Sprintf("%s=?", quote(quoteTag, versionField)))
|
||||||
|
newValues = append(newValues, nextVer)
|
||||||
|
}
|
||||||
|
|
||||||
|
if oldVersion != nil {
|
||||||
|
if conditions != "" {
|
||||||
|
conditions = fmt.Sprintf("(%s) and %s=?", conditions, quote(quoteTag, versionField))
|
||||||
|
} else {
|
||||||
|
conditions = fmt.Sprintf("%s=?", quote(quoteTag, versionField))
|
||||||
|
}
|
||||||
|
args = append(args, oldVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
newValues = append(newValues, args...)
|
||||||
if conditions != "" {
|
if conditions != "" {
|
||||||
conditions = " where " + conditions
|
conditions = " where " + conditions
|
||||||
}
|
}
|
||||||
query := fmt.Sprintf("update %s set %s%s", quote(quoteTag, table), strings.Join(keys, ","), conditions)
|
query := fmt.Sprintf("update %s set %s%s", quote(quoteTag, table), strings.Join(newKeys, ","), conditions)
|
||||||
return query, values
|
return query, newValues
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) MakeInsertSql(table string, data any, useReplace bool) (string, []any) {
|
func (db *DB) MakeInsertSql(table string, data any, useReplace bool) (string, []any) {
|
||||||
return makeInsertSql(db.QuoteTag, table, data, useReplace)
|
ts := db.getTable(table)
|
||||||
|
nextVer := int64(0)
|
||||||
|
if ts.VersionField != "" {
|
||||||
|
nextVer = db.NextVersion(table)
|
||||||
|
}
|
||||||
|
return makeInsertSql(db.QuoteTag, table, data, useReplace, ts.VersionField, nextVer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) MakeUpdateSql(table string, data any, conditions string, args ...any) (string, []any) {
|
func (db *DB) MakeUpdateSql(table string, data any, conditions string, args ...any) (string, []any) {
|
||||||
return makeUpdateSql(db.QuoteTag, table, data, conditions, args...)
|
ts := db.getTable(table)
|
||||||
|
nextVer := int64(0)
|
||||||
|
if ts.VersionField != "" {
|
||||||
|
nextVer = db.NextVersion(table)
|
||||||
|
}
|
||||||
|
return makeUpdateSql(db.QuoteTag, table, data, conditions, ts.VersionField, nextVer, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tx *Tx) MakeInsertSql(table string, data any, useReplace bool) (string, []any) {
|
func (tx *Tx) MakeInsertSql(table string, data any, useReplace bool) (string, []any) {
|
||||||
return makeInsertSql(tx.QuoteTag, table, data, useReplace)
|
ts := tx.db.getTable(table)
|
||||||
|
nextVer := int64(0)
|
||||||
|
if ts.VersionField != "" {
|
||||||
|
nextVer = tx.db.NextVersion(table)
|
||||||
|
}
|
||||||
|
return makeInsertSql(tx.QuoteTag, table, data, useReplace, ts.VersionField, nextVer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tx *Tx) MakeUpdateSql(table string, data any, conditions string, args ...any) (string, []any) {
|
func (tx *Tx) MakeUpdateSql(table string, data any, conditions string, args ...any) (string, []any) {
|
||||||
return makeUpdateSql(tx.QuoteTag, table, data, conditions, args...)
|
ts := tx.db.getTable(table)
|
||||||
|
nextVer := int64(0)
|
||||||
|
if ts.VersionField != "" {
|
||||||
|
nextVer = tx.db.NextVersion(table)
|
||||||
|
}
|
||||||
|
return makeUpdateSql(tx.QuoteTag, table, data, conditions, ts.VersionField, nextVer, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getFlatFields(fields map[string]reflect.Value, fieldKeys *[]string, value reflect.Value) {
|
func getFlatFields(fields map[string]reflect.Value, fieldKeys *[]string, value reflect.Value) {
|
||||||
|
|||||||
25
DB.go
25
DB.go
@ -11,6 +11,7 @@ import (
|
|||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"apigo.cc/go/cast"
|
"apigo.cc/go/cast"
|
||||||
@ -19,6 +20,7 @@ import (
|
|||||||
"apigo.cc/go/id"
|
"apigo.cc/go/id"
|
||||||
"apigo.cc/go/log"
|
"apigo.cc/go/log"
|
||||||
"apigo.cc/go/rand"
|
"apigo.cc/go/rand"
|
||||||
|
"apigo.cc/go/redis"
|
||||||
"apigo.cc/go/safe"
|
"apigo.cc/go/safe"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -37,6 +39,7 @@ type Config struct {
|
|||||||
MaxIdles int
|
MaxIdles int
|
||||||
MaxLifeTime int
|
MaxLifeTime int
|
||||||
LogSlow config.Duration
|
LogSlow config.Duration
|
||||||
|
VersionRedis string
|
||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -140,6 +143,7 @@ func (dbInfo *Config) ConfigureBy(setting string) {
|
|||||||
dbInfo.MaxLifeTime = cast.Int(q.Get("maxLifeTime"))
|
dbInfo.MaxLifeTime = cast.Int(q.Get("maxLifeTime"))
|
||||||
dbInfo.MaxOpens = cast.Int(q.Get("maxOpens"))
|
dbInfo.MaxOpens = cast.Int(q.Get("maxOpens"))
|
||||||
dbInfo.LogSlow = config.Duration(cast.Duration(q.Get("logSlow")))
|
dbInfo.LogSlow = config.Duration(cast.Duration(q.Get("logSlow")))
|
||||||
|
dbInfo.VersionRedis = q.Get("versionRedis")
|
||||||
dbInfo.SSL = q.Get("tls")
|
dbInfo.SSL = q.Get("tls")
|
||||||
|
|
||||||
sslCa := q.Get("sslCA")
|
sslCa := q.Get("sslCA")
|
||||||
@ -171,7 +175,7 @@ func (dbInfo *Config) ConfigureBy(setting string) {
|
|||||||
|
|
||||||
args := make([]string, 0)
|
args := make([]string, 0)
|
||||||
for k := range q {
|
for k := range q {
|
||||||
if k != "maxIdles" && k != "maxLifeTime" && k != "maxOpens" && k != "logSlow" && k != "tls" {
|
if k != "maxIdles" && k != "maxLifeTime" && k != "maxOpens" && k != "logSlow" && k != "tls" && k != "versionRedis" {
|
||||||
args = append(args, k+"="+q.Get(k))
|
args = append(args, k+"="+q.Get(k))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -229,8 +233,21 @@ var dbConfigsLock = sync.RWMutex{}
|
|||||||
var dbSSLs = make(map[string]*SSL)
|
var dbSSLs = make(map[string]*SSL)
|
||||||
var dbInstances = make(map[string]*DB)
|
var dbInstances = make(map[string]*DB)
|
||||||
var dbInstancesLock = sync.RWMutex{}
|
var dbInstancesLock = sync.RWMutex{}
|
||||||
|
var globalVersionMap = sync.Map{}
|
||||||
var once sync.Once
|
var once sync.Once
|
||||||
|
|
||||||
|
func (db *DB) NextVersion(key string) int64 {
|
||||||
|
if db.Config.VersionRedis != "" {
|
||||||
|
r := redis.GetRedis(db.Config.VersionRedis, db.logger.logger)
|
||||||
|
if r != nil {
|
||||||
|
return r.INCR("db_ver_" + key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
v, _ := globalVersionMap.LoadOrStore(key, new(int64))
|
||||||
|
return atomic.AddInt64(v.(*int64), 1)
|
||||||
|
}
|
||||||
|
|
||||||
func GetDBWithoutCache(name string, logger *log.Logger) *DB {
|
func GetDBWithoutCache(name string, logger *log.Logger) *DB {
|
||||||
return getDB(name, logger, false)
|
return getDB(name, logger, false)
|
||||||
}
|
}
|
||||||
@ -503,14 +520,14 @@ func (db *DB) Quotes(texts []string) string {
|
|||||||
|
|
||||||
func (db *DB) Begin() *Tx {
|
func (db *DB) Begin() *Tx {
|
||||||
if db.conn == nil {
|
if db.conn == nil {
|
||||||
return &Tx{QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: errors.New("operate on a bad connection"), logger: db.logger}
|
return &Tx{db: db, QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: errors.New("operate on a bad connection"), logger: db.logger}
|
||||||
}
|
}
|
||||||
sqlTx, err := db.conn.Begin()
|
sqlTx, err := db.conn.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
db.logger.LogError(err.Error())
|
db.logger.LogError(err.Error())
|
||||||
return &Tx{QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: err, logger: db.logger}
|
return &Tx{db: db, QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: err, logger: db.logger}
|
||||||
}
|
}
|
||||||
return &Tx{QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), conn: sqlTx, logger: db.logger}
|
return &Tx{db: db, QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), conn: sqlTx, logger: db.logger}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) Exec(query string, args ...any) *ExecResult {
|
func (db *DB) Exec(query string, args ...any) *ExecResult {
|
||||||
|
|||||||
1
Tx.go
1
Tx.go
@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
type Tx struct {
|
type Tx struct {
|
||||||
conn *sql.Tx
|
conn *sql.Tx
|
||||||
|
db *DB
|
||||||
lastSql *string
|
lastSql *string
|
||||||
lastArgs []any
|
lastArgs []any
|
||||||
Error error
|
Error error
|
||||||
|
|||||||
57
version_test.go
Normal file
57
version_test.go
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
package db_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"apigo.cc/go/db"
|
||||||
|
_ "modernc.org/sqlite"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestVersionControl(t *testing.T) {
|
||||||
|
dbInst := db.GetDB("sqlite://:memory:", nil)
|
||||||
|
|
||||||
|
// Create table with autoVersion
|
||||||
|
dbInst.Exec("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, autoVersion BIGINT UNSIGNED)")
|
||||||
|
|
||||||
|
t.Run("InsertAutoVersion", func(t *testing.T) {
|
||||||
|
data := map[string]any{"id": 1, "name": "Alice"}
|
||||||
|
res := dbInst.Insert("users", data)
|
||||||
|
if res.Error != nil {
|
||||||
|
t.Fatalf("Insert failed: %v", res.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify version was injected
|
||||||
|
var ver int64
|
||||||
|
qr := dbInst.Query("SELECT autoVersion FROM users WHERE id = 1")
|
||||||
|
ver, _ = db.ToValue[int64](qr)
|
||||||
|
if ver != 1 {
|
||||||
|
t.Errorf("Expected version 1, got %d", ver)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("UpdateOptimisticLock", func(t *testing.T) {
|
||||||
|
// First update
|
||||||
|
data := map[string]any{"name": "Alice Updated", "autoVersion": int64(1)}
|
||||||
|
res := dbInst.Update("users", data, "id = 1")
|
||||||
|
if res.Error != nil {
|
||||||
|
t.Fatalf("Update failed: %v", res.Error)
|
||||||
|
}
|
||||||
|
if res.Changes() != 1 {
|
||||||
|
t.Errorf("Expected 1 change, got %d", res.Changes())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify version incremented
|
||||||
|
var ver int64
|
||||||
|
qr := dbInst.Query("SELECT autoVersion FROM users WHERE id = 1")
|
||||||
|
ver, _ = db.ToValue[int64](qr)
|
||||||
|
if ver != 2 {
|
||||||
|
t.Errorf("Expected version 2, got %d", ver)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try update with old version (should fail to update any rows)
|
||||||
|
dataConflict := map[string]any{"name": "Conflict", "autoVersion": int64(1)}
|
||||||
|
resConflict := dbInst.Update("users", dataConflict, "id = 1")
|
||||||
|
if resConflict.Changes() != 0 {
|
||||||
|
t.Errorf("Expected 0 changes due to optimistic lock, got %d", resConflict.Changes())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user