Compare commits

..

No commits in common. "main" and "v1.0.1" have entirely different histories.
main ... v1.0.1

19 changed files with 146 additions and 1740 deletions

16
AI.md Normal file
View File

@ -0,0 +1,16 @@
# AI 指南 - @go/db
## 🤖 AI 调用规则
- **版本**: v1.0.1
- **核心原则**: 优先使用结构化绑定(`To`, `MapResults`),避免手动拼装 SQL 结果。
- **敏感数据**: 必须通过 `SetEncryptKeys` 配置密钥,确保 DSN 中的密码安全。
- **读写分离**: 鼓励在 DSN 中配置多个 Host 以利用内置的读写分离机制。
- **性能优化**:
- 大规模查询应优先绑定到 Struct 切片。
- 频繁执行的 SQL 应使用 `Prepare`
- **事务处理**: 始终使用 `tx.Finish(err == nil)``defer tx.CheckFinished()` 确保事务闭环。
## ⚠️ 注意事项
- 严禁在代码中硬编码数据库凭据。
- 严禁忽略 `Exec``Query` 返回的 `Error`
- SQLite 模式下,时间字段会自动转换,无需手动解析字符串。

165
Base.go
View File

@ -6,52 +6,12 @@ import (
"fmt"
"reflect"
"strings"
"sync"
"time"
"apigo.cc/go/cast"
"apigo.cc/go/log"
)
var structFieldsCache = sync.Map{}
type structFieldInfo struct {
name string
index []int
}
func getStructFields(typ reflect.Type) []structFieldInfo {
if v, ok := structFieldsCache.Load(typ); ok {
return v.([]structFieldInfo)
}
var fields []structFieldInfo
flattenFields(typ, nil, &fields)
structFieldsCache.Store(typ, fields)
return fields
}
func flattenFields(typ reflect.Type, index []int, fields *[]structFieldInfo) {
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
if typ.Kind() != reflect.Struct {
return
}
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
newIndex := make([]int, len(index)+len(f.Index))
copy(newIndex, index)
copy(newIndex[len(index):], f.Index)
if f.Anonymous && f.Type.Kind() == reflect.Struct {
flattenFields(f.Type, newIndex, fields)
} else {
if f.Name[0] >= 'A' && f.Name[0] <= 'Z' {
*fields = append(*fields, structFieldInfo{name: f.Name, index: newIndex})
}
}
}
}
func basePrepare(db *sql.DB, tx *sql.Tx, query string) *Stmt {
var sqlStmt *sql.Stmt
var err error
@ -140,39 +100,8 @@ func quotes(quoteTag string, texts []string) string {
return strings.Join(texts, ",")
}
func makeInsertSql(quoteTag string, table string, data any, useReplace bool, versionField string, nextVer int64, idField string, nextId string) (string, []any) {
func makeInsertSql(quoteTag string, table string, data any, useReplace bool) (string, []any) {
keys, vars, values := MakeKeysVarsValues(data)
if versionField != "" {
found := false
for _, k := range keys {
if k == versionField {
found = true
break
}
}
if !found {
keys = append(keys, versionField)
vars = append(vars, "?")
values = append(values, nextVer)
}
}
if idField != "" && nextId != "" {
found := false
for i, k := range keys {
if k == idField {
found = true
if cast.String(values[i]) == "" {
values[i] = nextId
}
break
}
}
if !found {
keys = append(keys, idField)
vars = append(vars, "?")
values = append(values, nextId)
}
}
operation := "insert"
if useReplace {
operation = "replace"
@ -181,84 +110,47 @@ func makeInsertSql(quoteTag string, table string, data any, useReplace bool, ver
return query, values
}
func makeUpdateSql(quoteTag string, table string, data any, conditions string, versionField string, nextVer int64, args ...any) (string, []any) {
func makeUpdateSql(quoteTag string, table string, data any, conditions string, args ...any) (string, []any) {
args = flatArgs(args)
keys, vars, values := MakeKeysVarsValues(data)
newKeys := make([]string, 0, len(keys))
newValues := make([]any, 0, len(values))
var oldVersion any
for i, k := range keys {
if k == versionField {
oldVersion = values[i]
continue
}
newKeys = append(newKeys, fmt.Sprintf("%s=%s", quote(quoteTag, k), vars[i]))
newValues = append(newValues, values[i])
keys[i] = fmt.Sprintf("%s=%s", quote(quoteTag, k), vars[i])
}
if versionField != "" {
newKeys = append(newKeys, fmt.Sprintf("%s=?", quote(quoteTag, versionField)))
newValues = append(newValues, nextVer)
}
if oldVersion != nil {
if conditions != "" {
conditions = fmt.Sprintf("(%s) and %s=?", conditions, quote(quoteTag, versionField))
} else {
conditions = fmt.Sprintf("%s=?", quote(quoteTag, versionField))
}
args = append(args, oldVersion)
}
newValues = append(newValues, args...)
values = append(values, args...)
if conditions != "" {
conditions = " where " + conditions
}
query := fmt.Sprintf("update %s set %s%s", quote(quoteTag, table), strings.Join(newKeys, ","), conditions)
return query, newValues
query := fmt.Sprintf("update %s set %s%s", quote(quoteTag, table), strings.Join(keys, ","), conditions)
return query, values
}
func (db *DB) MakeInsertSql(table string, data any, useReplace bool) (string, []any) {
ts := db.getTable(table)
nextVer := int64(0)
if ts.VersionField != "" {
nextVer = db.NextVersion(table)
}
nextId := ""
if ts.IdField != "" {
nextId = db.NextID(table)
}
return makeInsertSql(db.QuoteTag, table, data, useReplace, ts.VersionField, nextVer, ts.IdField, nextId)
return makeInsertSql(db.QuoteTag, table, data, useReplace)
}
func (db *DB) MakeUpdateSql(table string, data any, conditions string, args ...any) (string, []any) {
ts := db.getTable(table)
nextVer := int64(0)
if ts.VersionField != "" {
nextVer = db.NextVersion(table)
}
return makeUpdateSql(db.QuoteTag, table, data, conditions, ts.VersionField, nextVer, args...)
return makeUpdateSql(db.QuoteTag, table, data, conditions, args...)
}
func (tx *Tx) MakeInsertSql(table string, data any, useReplace bool) (string, []any) {
ts := tx.db.getTable(table)
nextVer := int64(0)
if ts.VersionField != "" {
nextVer = tx.db.NextVersion(table)
}
nextId := ""
if ts.IdField != "" {
nextId = tx.db.NextID(table)
}
return makeInsertSql(tx.QuoteTag, table, data, useReplace, ts.VersionField, nextVer, ts.IdField, nextId)
return makeInsertSql(tx.QuoteTag, table, data, useReplace)
}
func (tx *Tx) MakeUpdateSql(table string, data any, conditions string, args ...any) (string, []any) {
ts := tx.db.getTable(table)
nextVer := int64(0)
if ts.VersionField != "" {
nextVer = tx.db.NextVersion(table)
return makeUpdateSql(tx.QuoteTag, table, data, conditions, args...)
}
func getFlatFields(fields map[string]reflect.Value, fieldKeys *[]string, value reflect.Value) {
valueType := value.Type()
for i := 0; i < value.NumField(); i++ {
v := value.Field(i)
if valueType.Field(i).Anonymous {
getFlatFields(fields, fieldKeys, v)
} else {
*fieldKeys = append(*fieldKeys, valueType.Field(i).Name)
fields[valueType.Field(i).Name] = v
}
}
return makeUpdateSql(tx.QuoteTag, table, data, conditions, ts.VersionField, nextVer, args...)
}
func MakeKeysVarsValues(data any) ([]string, []string, []any) {
@ -274,13 +166,18 @@ func MakeKeysVarsValues(data any) ([]string, []string, []any) {
}
if dataType.Kind() == reflect.Struct {
fields := getStructFields(dataType)
for _, f := range fields {
v := dataValue.FieldByIndex(f.index)
fields := make(map[string]reflect.Value)
fieldKeys := make([]string, 0)
getFlatFields(fields, &fieldKeys, dataValue)
for _, k := range fieldKeys {
if k[0] >= 'a' && k[0] <= 'z' {
continue
}
v := fields[k]
if v.Kind() == reflect.Interface {
v = v.Elem()
}
keys = append(keys, f.name)
keys = append(keys, k)
if v.Kind() == reflect.String && v.Len() > 0 && v.String()[0] == ':' {
vars = append(vars, v.String()[1:])
} else {

View File

@ -1,37 +1,12 @@
# 变更记录 - @go/db
## [1.0.4] - 2026-05-04
### 优化
- **日志增强**:升级 `apigo.cc/go/log` 至 v1.0.1,并重构数据库日志逻辑,利用新版 `log.DB` API 直接支持错误字段和调用栈捕获,提升排障效率。
## [1.0.3] - 2026-05-04
### 新增
- **自动随机 ID (Auto Random ID)**:当表主键或唯一索引字段类型为 `char(8/10/12/14)` 且值为空时,自动填充分布式唯一 ID。
- **智能 ID 生成器**自动适配数据库类型MySQL 右旋散列、PostgreSQL 时间单调、SQLite 纯随机),优先使用 Redis 分布式生成器。
## [1.0.2] - 2026-05-04
### 修复
- **PostgreSQL 增强**:补全了 `getTable` 中的元数据探测逻辑,使 `autoVersion` 和影子删除在 PostgreSQL 下可自动启用。
- **错误处理一致性**:统一了 `QueryResult``ExecResult` 的错误传播逻辑,确保 `r.Error` 在数据处理阶段也能正确记录。
- **单元测试修复**:修正了 `DB_test.go` 中因 SQLite 时区差异导致的 `TestInsertReplaceUpdateDelete` 偶发失败。
### 优化
- **性能提升**:在 `Base.go` 中引入了 `sync.Map` 缓存结构体反射解析结果,减少 SQL 生成过程中的反射开销。
# CHANGELOG - @go/db
## [1.0.1] - 2026-05-03
### 新增
- **架构 DSL (Schema-as-Code)**:支持通过文本 DSL 定义并自动同步数据库结构。
- **影子删除 (Shadow Deletion)**:支持 `SD` 标记,使用 `db.Remove` 自动将删除数据移动到 `_deleted` 后缀的备份表中。
- **乐观锁与版本控制**:支持 `db.Update` 自动处理版本递增与冲突检测。
- **泛型支持**:新增 `db.ToSlice[T]``db.To[T]`,提供类型安全的查询结果映射。
- **PostgreSQL 支持**:初步支持 PostgreSQL 的架构同步逻辑。
- **AI 友好文档**:新增 `db.SchemaMarkdown()` 自动生成 Markdown 格式的数据库模型文档。
### Optimized
- Refactored `makeResults` to pre-calculate field mappings for structs, significantly improving performance for large result sets.
- Simplified and optimized `makeValue` and `makePublicVarName` functions.
- Optimized time parsing in `makeResults`.
### 优化
- 重构了 `makeResults` 逻辑,预计算 Struct 字段映射,显著提升大数据集下的查询性能。
- 完善了 SQLite 的 `DATETIME` 与 Go `time.Time` 的自动转换逻辑。
- 所有的文档和注释已本地化为中文。
### 修复
- 修复了 `Tx` 结构体中的拼写错误 `isCommitedOrRollbacked``isCommittedOrRollbacked`
- 统一了全模块的参数命名规范:`requestSql` -> `query``wheres` -> `conditions`
### Fixed
- Fixed typo `isCommitedOrRollbacked` to `isCommittedOrRollbacked` in `Tx` struct.
- Standardized parameter naming: renamed `requestSql` to `query` and `wheres` to `conditions` across the module.
- Modernized Go syntax to align with latest standards.

288
DB.go
View File

@ -11,7 +11,6 @@ import (
"regexp"
"strings"
"sync"
"sync/atomic"
"time"
"apigo.cc/go/cast"
@ -20,7 +19,6 @@ import (
"apigo.cc/go/id"
"apigo.cc/go/log"
"apigo.cc/go/rand"
"apigo.cc/go/redis"
"apigo.cc/go/safe"
)
@ -39,7 +37,6 @@ type Config struct {
MaxIdles int
MaxLifeTime int
LogSlow config.Duration
Redis string
logger *log.Logger
}
@ -143,7 +140,6 @@ func (dbInfo *Config) ConfigureBy(setting string) {
dbInfo.MaxLifeTime = cast.Int(q.Get("maxLifeTime"))
dbInfo.MaxOpens = cast.Int(q.Get("maxOpens"))
dbInfo.LogSlow = config.Duration(cast.Duration(q.Get("logSlow")))
dbInfo.Redis = q.Get("redis")
dbInfo.SSL = q.Get("tls")
sslCa := q.Get("sslCA")
@ -175,7 +171,7 @@ func (dbInfo *Config) ConfigureBy(setting string) {
args := make([]string, 0)
for k := range q {
if k != "maxIdles" && k != "maxLifeTime" && k != "maxOpens" && k != "logSlow" && k != "tls" && k != "redis" {
if k != "maxIdles" && k != "maxLifeTime" && k != "maxOpens" && k != "logSlow" && k != "tls" {
args = append(args, k+"="+q.Get(k))
}
}
@ -192,33 +188,6 @@ type DB struct {
logger *dbLogger
Error error
QuoteTag string
tables map[string]*TableStruct
tablesLock *sync.RWMutex
}
type TableStruct struct {
Name string
Comment string
Fields []TableField
Columns []string
ShadowDelete bool
HasShadowTable bool
VersionField string
IdField string
IdSize int
}
type TableField struct {
Name string
Type string
Index string
IndexGroup string
Default string
Comment string
Null string
Extra string
Desc string
IsVersion bool
}
var confAes, _ = crypto.NewAESCBCAndEraseKey([]byte("?GQ$0K0GgLdO=f+~L68PLm$uhKr4'=tV"), []byte("VFs7@sK61cj^f?HZ"))
@ -237,7 +206,7 @@ type dbLogger struct {
}
func (dl *dbLogger) LogError(errStr string) {
dl.logger.DB(dl.config.Type, dl.config.Dsn(), "", nil, 0, errStr)
dl.logger.DBError(errStr, dl.config.Type, dl.config.Dsn(), "", nil, 0)
}
func (dl *dbLogger) LogQuery(query string, args []any, usedTime float32) {
@ -245,7 +214,7 @@ func (dl *dbLogger) LogQuery(query string, args []any, usedTime float32) {
}
func (dl *dbLogger) LogQueryError(errStr string, query string, args []any, usedTime float32) {
dl.logger.DB(dl.config.Type, dl.config.Dsn(), query, args, usedTime, errStr)
dl.logger.DBError(errStr, dl.config.Type, dl.config.Dsn(), query, args, usedTime)
}
var dbConfigs = make(map[string]*Config)
@ -253,97 +222,8 @@ var dbConfigsLock = sync.RWMutex{}
var dbSSLs = make(map[string]*SSL)
var dbInstances = make(map[string]*DB)
var dbInstancesLock = sync.RWMutex{}
var globalVersionMap = sync.Map{}
var globalIdMakers = sync.Map{}
var versionInited = sync.Map{}
var once sync.Once
func (db *DB) NextVersion(table string) int64 {
ts := db.getTable(table)
if ts.VersionField == "" {
return 0
}
if _, inited := versionInited.Load(table); !inited {
db.syncVersionFromDB(table, ts.VersionField)
versionInited.Store(table, true)
}
if db.Config.Redis != "" {
r := redis.GetRedis(db.Config.Redis, db.logger.logger)
if r != nil {
return r.INCR("db_ver_" + table)
}
}
v, _ := globalVersionMap.LoadOrStore(table, new(int64))
return atomic.AddInt64(v.(*int64), 1)
}
type idMaker interface {
Get(size int) string
GetForMysql(size int) string
GetForPostgreSQL(size int) string
}
func (db *DB) NextID(table string) string {
ts := db.getTable(table)
if ts.IdField == "" || ts.IdSize == 0 {
return ""
}
var maker idMaker
if db.Config.Redis != "" {
if v, ok := globalIdMakers.Load(db.Config.Redis); ok {
maker = v.(idMaker)
} else {
r := redis.GetRedis(db.Config.Redis, db.logger.logger)
if r != nil {
maker = redis.NewIDMaker(r)
globalIdMakers.Store(db.Config.Redis, maker)
}
}
}
if maker == nil {
maker = id.DefaultIDMaker
}
switch db.Config.Type {
case "mysql":
return maker.GetForMysql(ts.IdSize)
case "postgres", "pgx":
return maker.GetForPostgreSQL(ts.IdSize)
default:
return maker.Get(ts.IdSize)
}
}
func (db *DB) syncVersionFromDB(table, versionField string) {
query := fmt.Sprintf("SELECT MAX(%s) FROM %s", db.Quote(versionField), db.Quote(table))
maxVer := db.Query(query).IntOnR1C1()
if db.Config.Redis != "" {
r := redis.GetRedis(db.Config.Redis, db.logger.logger)
if r != nil {
r.Do("SETNX", "db_ver_"+table, maxVer)
return
}
}
v, _ := globalVersionMap.LoadOrStore(table, new(int64))
ptr := v.(*int64)
for {
current := atomic.LoadInt64(ptr)
if current >= maxVer {
break
}
if atomic.CompareAndSwapInt64(ptr, current, maxVer) {
break
}
}
}
func GetDBWithoutCache(name string, logger *log.Logger) *DB {
return getDB(name, logger, false)
}
@ -467,7 +347,7 @@ func getDB(name string, logger *log.Logger, useCache bool) *DB {
conn, err := getPool(conf)
if err != nil {
logger.DB(conf.Type, conf.Dsn(), "", nil, 0, err.Error())
logger.DBError(err.Error(), conf.Type, conf.Dsn(), "", nil, 0)
return &DB{conn: nil, QuoteTag: "\"", Error: err}
}
@ -475,15 +355,13 @@ func getDB(name string, logger *log.Logger, useCache bool) *DB {
db.QuoteTag = cast.If(conf.Type == "mysql", "`", "\"")
db.name = name
db.conn = conn
db.tables = make(map[string]*TableStruct)
db.tablesLock = new(sync.RWMutex)
if conf.ReadonlyHosts != nil {
readonlyConnections := make([]*sql.DB, 0)
for _, host := range conf.ReadonlyHosts {
conn, err := getPoolForHost(conf, host)
if err != nil {
logger.DB(conf.Type, conf.Dsn(), "", nil, 0, err.Error())
logger.DBError(err.Error(), conf.Type, conf.Dsn(), "", nil, 0)
} else {
readonlyConnections = append(readonlyConnections, conn)
}
@ -562,8 +440,6 @@ func (db *DB) CopyByLogger(logger *log.Logger) *DB {
newDB.conn = db.conn
newDB.readonlyConnections = db.readonlyConnections
newDB.Config = db.Config
newDB.tables = db.tables
newDB.tablesLock = db.tablesLock
if logger == nil {
logger = log.DefaultLogger
}
@ -616,14 +492,14 @@ func (db *DB) Quotes(texts []string) string {
func (db *DB) Begin() *Tx {
if db.conn == nil {
return &Tx{db: db, QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: errors.New("operate on a bad connection"), logger: db.logger}
return &Tx{QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: errors.New("operate on a bad connection"), logger: db.logger}
}
sqlTx, err := db.conn.Begin()
if err != nil {
db.logger.LogError(err.Error())
return &Tx{db: db, QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: err, logger: db.logger}
return &Tx{QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), Error: err, logger: db.logger}
}
return &Tx{db: db, QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), conn: sqlTx, logger: db.logger}
return &Tx{QuoteTag: db.QuoteTag, logSlow: db.Config.LogSlow.TimeDuration(), conn: sqlTx, logger: db.logger}
}
func (db *DB) Exec(query string, args ...any) *ExecResult {
@ -706,148 +582,22 @@ func (db *DB) Update(table string, data any, conditions string, args ...any) *Ex
}
func (db *DB) Delete(table string, conditions string, args ...any) *ExecResult {
ts := db.getTable(table)
if !ts.HasShadowTable {
if conditions != "" {
conditions = " where " + conditions
}
query := fmt.Sprintf("delete from %s%s", db.Quote(table), conditions)
r := baseExec(db.conn, nil, query, args...)
r.logger = db.logger
if r.Error != nil {
db.logger.LogQueryError(r.Error.Error(), query, args, r.usedTime)
} else {
if db.Config.LogSlow > 0 && r.usedTime >= float32(db.Config.LogSlow.TimeDuration()/time.Millisecond) {
db.logger.LogQuery(query, args, r.usedTime)
}
}
return r
if conditions != "" {
conditions = " where " + conditions
}
// Shadow delete
tx := db.Begin()
defer tx.CheckFinished()
r := tx.Delete(table, conditions, args...)
if r.Error == nil {
tx.Commit()
query := fmt.Sprintf("delete from %s%s", db.Quote(table), conditions)
r := baseExec(db.conn, nil, query, args...)
r.logger = db.logger
if r.Error != nil {
db.logger.LogQueryError(r.Error.Error(), query, args, r.usedTime)
} else {
if db.Config.LogSlow > 0 && r.usedTime >= float32(db.Config.LogSlow.TimeDuration()/time.Millisecond) {
db.logger.LogQuery(query, args, r.usedTime)
}
}
return r
}
func (db *DB) getTable(table string) *TableStruct {
db.tablesLock.RLock()
ts, ok := db.tables[table]
db.tablesLock.RUnlock()
if ok {
return ts
}
db.tablesLock.Lock()
defer db.tablesLock.Unlock()
// Double check
if ts, ok = db.tables[table]; ok {
return ts
}
ts = &TableStruct{Name: table}
// Probe columns and autoVersion
var query string
if db.Config.Type == "mysql" {
query = "SELECT COLUMN_NAME, DATA_TYPE, CHARACTER_MAXIMUM_LENGTH, COLUMN_KEY FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?"
res := db.Query(query, db.Config.DB, table)
rows := res.MapResults()
for _, row := range rows {
col := cast.String(row["COLUMN_NAME"])
dataType := cast.String(row["DATA_TYPE"])
charLen := cast.Int(row["CHARACTER_MAXIMUM_LENGTH"])
colKey := cast.String(row["COLUMN_KEY"])
ts.Columns = append(ts.Columns, col)
if col == "autoVersion" {
ts.VersionField = "autoVersion"
}
if (colKey == "PRI" || colKey == "UNI") && strings.ToLower(dataType) == "char" && (charLen == 8 || charLen == 10 || charLen == 12 || charLen == 14) {
ts.IdField = col
ts.IdSize = charLen
}
}
} else if db.Config.Type == "postgres" || db.Config.Type == "pgx" {
query = "SELECT column_name, data_type, character_maximum_length FROM information_schema.columns WHERE table_schema = current_schema() AND table_name = ?"
res := db.Query(query, table)
rows := res.MapResults()
for _, row := range rows {
col := cast.String(row["column_name"])
dataType := cast.String(row["data_type"])
charLen := cast.Int(row["character_maximum_length"])
ts.Columns = append(ts.Columns, col)
if col == "autoVersion" {
ts.VersionField = "autoVersion"
}
// PostgreSQL PK/Unique check is complex, we use column name 'id' and char type as a heuristic or check constraints if needed.
// To keep it simple and efficient as requested:
if (col == "id" || col == "ID") && (strings.Contains(strings.ToLower(dataType), "char")) && (charLen == 8 || charLen == 10 || charLen == 12 || charLen == 14) {
ts.IdField = col
ts.IdSize = charLen
}
}
} else if isFileDB(db.Config.Type) {
// For SQLite
query = fmt.Sprintf("PRAGMA table_info(%s)", db.Quote(table))
res := db.Query(query)
rows := res.MapResults()
for _, row := range rows {
colName := cast.String(row["name"])
colType := strings.ToUpper(cast.String(row["type"]))
isPk := cast.Int(row["pk"]) > 0
ts.Columns = append(ts.Columns, colName)
if colName == "autoVersion" {
ts.VersionField = "autoVersion"
}
if isPk && strings.Contains(colType, "CHAR") {
// Extract length from CHAR(N)
charLen := 0
fmt.Sscanf(colType, "CHAR(%d)", &charLen)
if charLen == 0 {
fmt.Sscanf(colType, "CHARACTER(%d)", &charLen)
}
if charLen == 8 || charLen == 10 || charLen == 12 || charLen == 14 {
ts.IdField = colName
ts.IdSize = charLen
}
}
}
}
// Probe shadow table
shadowTable := table + "_deleted"
if db.Config.Type == "mysql" {
query = "SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?"
res := db.Query(query, db.Config.DB, shadowTable)
if res.StringOnR1C1() != "" {
ts.HasShadowTable = true
}
} else if db.Config.Type == "postgres" || db.Config.Type == "pgx" {
query = "SELECT table_name FROM information_schema.tables WHERE table_schema = current_schema() AND table_name = ?"
res := db.Query(query, shadowTable)
if res.StringOnR1C1() != "" {
ts.HasShadowTable = true
}
} else if isFileDB(db.Config.Type) {
query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
res := db.Query(query, shadowTable)
if res.StringOnR1C1() != "" {
ts.HasShadowTable = true
}
}
db.tables[table] = ts
return ts
}
func (db *DB) InKeys(numArgs int) string {
return InKeys(numArgs)
}

View File

@ -2,7 +2,6 @@ package db_test
import (
"fmt"
"os"
"regexp"
"strings"
"testing"
@ -16,12 +15,6 @@ import (
_ "modernc.org/sqlite"
)
func TestMain(m *testing.M) {
code := m.Run()
os.Remove("test.db")
os.Exit(code)
}
var dbset = "sqlite://test.db"
type userInfo struct {
@ -80,7 +73,7 @@ func initDB(t *testing.T) *db.DB {
email VARCHAR(45),
parents JSON,
active TINYINT NOT NULL DEFAULT 0,
time DATETIME NOT NULL DEFAULT (strftime('%Y-%m-%d %H:%M:%f', 'now', 'localtime')));`)
time DATETIME NOT NULL DEFAULT (strftime('%Y-%m-%d %H:%M:%f')));`)
}
if er.Error != nil {
t.Fatal("Failed to create table", er)

103
DSL.md
View File

@ -1,103 +0,0 @@
# 数据库架构 DSL (Schema-as-Code)
本模块提供了一种基于文本的 DSL领域专用语言来定义数据库架构。它支持 MySQL、PostgreSQL 和 SQLite旨在实现 AI 友好的“架构即代码”。
## 语法概览
一个架构描述由 **分组 (Groups)**、**数据表 (Tables)** 和 **字段 (Fields)** 组成。
### 分组 (Groups)
分组由以 `==` 开头的行定义,用于逻辑隔离不同的表集合。
```
== 用户系统 ==
```
### 数据表 (Tables)
数据表在分组下顶格定义。可以在 `//` 后添加注释。
在表名后添加 `SD` 标记可启用 **影子删除 (Shadow Deletion)**。启用后,删除的数据会自动移动到 `[表名]_deleted` 表中。
```
users SD // 系统用户表
```
### 字段 (Fields)
字段在数据表下通过缩进(空格或制表符)定义。
格式:`[名称] [类型标记][长度] [索引标记] // [注释]`
```
id AI // 主键,自动递增
username v32 U // Varchar(32),唯一索引
password v64 // Varchar(64)
version ver // Bigint 版本号,用于乐观锁和增量同步
create_time ct // 创建时间 (CURRENT_TIMESTAMP)
update_time ctu // 更新时间 (ON UPDATE CURRENT_TIMESTAMP)
```
## 类型标记 (Type Tags)
| 标记 | 对应数据库类型 (MySQL) | 说明 |
|-----|-----------------|------|
| `i` | `int` | 整型 |
| `ui`| `int unsigned` | 无符号整型 |
| `bi`| `bigint` | 长整型 |
| `ubi`| `bigint unsigned` | 无符号长整型 |
| `ti`| `tinyint` | 短整型 |
| `v` | `varchar` | 默认长度由驱动决定或忽略 |
| `v[N]`| `varchar(N)` | 例如:`v50` -> `varchar(50)` |
| `c[N]`| `char(N)` | 例如:`c32` -> `char(32)` |
| `t` | `text` | 文本 |
| `dt`| `datetime` | 日期时间 |
| `d` | `date` | 日期 |
| `tm`| `time` | 时间 |
| `f` | `float` | 浮点数 |
| `ff`| `double` | 双精度浮点数 |
| `b` | `tinyint unsigned`| 布尔值别名 |
| `bb`| `blob` | 二进制大对象 |
## 索引与特殊标记
| 标记 | 含义 |
|-----|---------|
| `PK` | 主键 (Primary Key) |
| `AI` | 自动递增 + 主键 (Auto Increment) |
| `U` | 唯一索引 (Unique Index) |
| `I` | 普通索引 (Index) |
| `TI` | 全文索引 (Fulltext Index, 仅 MySQL) |
| `ver`| 版本号字段 (用于乐观锁和增量同步) |
| `ct` | 创建时间 (Created Time) |
| `ctu`| 更新时间 (Updated Time) |
| `nn` | 非空 (NOT NULL) |
| `n` | 可为空 (NULL) |
### 复合索引
`I``U` 后添加数字可以将多个字段组合成一个复合索引。
```
first_name v32 I1
last_name v32 I1 // 在 (first_name, last_name) 上创建复合索引 'ik_table_1'
```
## 高级特性
### 1. 影子删除 (SD - Shadow Deletion)
当表标记为 `SD` 时,调用 `db.Remove()` 方法不会真正删除数据,而是将其从原表移动到 `_deleted` 后缀的影子表中。
- **优点**:主表查询不包含已删除数据,效率更高;历史数据可追溯。
- **API**: 使用 `db.Remove(table, conditions, args...)` 触发。
### 2. 乐观锁与增量同步 (ver)
标记为 `ver` 的字段(通常命名为 `version`)具有特殊行为:
- **自动递增**:每次调用 `db.Update()` 时,该字段会自动 `+1`
- **冲突检测**:如果在更新数据中包含了当前版本号,`db.Update()` 会在 `WHERE` 条件中自动加入版本校验。如果版本不一致,更新将失败(影响行数为 0
- **增量同步**:外部系统可以通过 `WHERE version > last_version` 轻松获取自上次同步以来的所有变更。
## 完整示例
```
== 默认分组 ==
users SD // 用户表
id AI // 用户 ID
username v32 U // 登录名
email v64 U // 联系邮箱
password v128 // 加密后的密码
status ti I // 0: 活跃, 1: 禁用
version ver // 行版本号
created_at ct // 创建记录
updated_at ctu // 更新记录
```

122
README.md
View File

@ -1,17 +1,15 @@
# @go/db
> **维护者声明:** 本项目由 AI 维护。代码源自 `github.com/ssgo/db` 的重构,支持现代 Go 特性、内存安全防护、读写分离、全局版本同步及泛型优化。
> **Maintainer Statement:** 本项目由 AI 维护。代码源自 github.com/ssgo/db 的重构,支持内存安全防护、读写分离及泛型优化。
## 🎯 设计哲学:约定优于配置
## 🎯 设计哲学
`@go/db` 遵循“约定优于配置”的设计哲学,旨在通过合理的默认行为和命名约定,简化数据库操作,同时保持强大的功能
`@go/db` 是一个极致精简、意图优先的数据库抽象层。它不试图取代 SQL而是通过智能结果绑定与 SQL 自动化生成,消除数据库操作中的样板代码
* **隐式高级功能**:版本控制和软删除等高级功能是**自动启用**的,无需显式配置。
- **版本控制**: 如果一个表包含名为 `autoVersion` 且类型为 `bigint unsigned` (`ubi`) 的字段,`Update``Insert` 操作将自动处理其版本递增和乐观锁。
- **自动随机 ID**: 当字段类型为 `char(8/10/12/14)` 且为 `PRIMARY KEY``UNIQUE` 时,`Insert/Replace` 操作若发现该字段为空,将自动填充唯一 ID优先使用 Redis 分布式 ID
- **智能删除**: 如果存在一个名为 `[表名]_deleted` 的表,`Delete` 操作将自动执行**影子删除**(将数据移入该表);否则,执行物理删除。
* **全局版本号**:内置基于 Redis 的全局序列(自动降级为本地 Map确保分布式环境下 `version` 绝对单调递增,为可靠的增量同步提供基础。
* **架构即代码 (DSL)**`SD` 标记现在仅用于**建表**时自动创建 `_deleted` 表,而运行时的删除行为由约定决定。
* **智能绑定**根据结果容器类型Struct/Map/Slice/BaseType自动适配查询逻辑无需手动 Scan。
* **内存防御**:集成 `go/safe`,数据库密码在内存中加密存储,使用时物理擦除。
* **读写分离**:内置连接池管理,支持配置多个只读节点实现自动负载均衡。
* **驱动透明**:统一 MySQL、PostgreSQL (pgx) 与 SQLite 的 API 差异。
## 📦 安装
@ -19,73 +17,49 @@
go get apigo.cc/go/db
```
## 💡 快速开始
```go
import "apigo.cc/go/db"
import _ "apigo.cc/go/db/mysql" // 引入驱动
// 初始化连接
d := db.GetDB("mysql://user:pass@host:3306/dbname", nil)
// 1. 查询全部结果到 Struct 切片
var users []User
d.Query("SELECT * FROM users").To(&users)
// 2. 自动化插入
d.Insert("users", User{Name: "Star", Active: true})
// 3. 事务操作
tx := d.Begin()
tx.Exec("UPDATE balance SET amount = amount - 10 WHERE id = ?", 1)
tx.Commit()
```
## 🛠 API 指南
### 1. 核心方法
- **`GetDB(name string, logger *log.Logger) *DB`**
- 获取数据库连接实例。`name` 可以是 `db.json` 中的配置名,也可以是标准 DSN`mysql://user:pwd@host:port/db``sqlite://test.db`)。
- **`Sync(schema string) error`**
- 解析 DSL 并同步数据库表结构。用于创建表(包括 `_deleted` 表)和索引。详见 [架构 DSL 指南](./DSL.md)。
### 核心对象
- **`GetDB(setting string, logger *log.Logger) *DB`**: 通过 DSN 或配置名获取数据库实例。
- **`DB.Insert/Replace/Update/Delete`**: 自动生成 SQL 并执行,支持 Struct 与 Map
- **`QueryResult.To(target any)`**: 将查询结果深度映射到目标容器。
- **`QueryResult.MapResults() []map[string]any`**: 快捷获取通用结果集
### 2. 写操作 (返回 `*ExecResult`)
- **`Insert/Replace(table string, data any)`**: 插入或替换数据。若表包含 `autoVersion` 字段,会自动注入初始版本号。
- **`Update(table string, data any, conditions string, args ...any)`**: 更新数据。若表包含 `autoVersion` 字段,自动递增版本号并应用乐观锁。
- **`Delete(table string, conditions string, args ...any)`**: **智能删除**。根据是否存在 `_deleted` 表自动选择物理删除或影子删除。
### 结果容器适配规则
| 容器类型 | 行为 |
| :--- | :--- |
| `[]Struct` | 返回所有行,按字段名自动映射 |
| `Struct` | 返回第一行,按字段名自动映射 |
| `[]map[string]any` | 返回所有行,保留原始字段名 |
| `[]BaseType` | 返回所有行,仅取第一列 |
| `BaseType` | 返回第一行第一列 |
#### 结果判定 (`ExecResult`)
```go
res := dbInst.Insert("users", newUser)
if res.Error != nil { /* 发生 SQL 错误 */ }
count := res.Changes() // 受影响行数
id := res.Id() // 获取自增 ID
```
### 安全与高级特性
- **`SetEncryptKeys(key, iv []byte)`**: 配置全局敏感数据加密密钥。
- **读写分离**: 在 DSN 中配置 `host1,host2,host3`,首个为主库,其余为随机只读库。
- **SQLite 时间修复**: 自动处理 SQLite 毫秒级 `DATETIME` 格式与标准 `time.Time` 的转换。
### 3. 读操作 (返回 `*QueryResult`)
- **`Query(query string, args ...any)`**: 执行查询。
- **结果处理 (QueryResult)**:
- **泛型绑定 (推荐)**: `db.To[T](res)`, `db.ToSlice[T](res)`
- **KV 映射**: `res.ToKV(&mapObj)` 将前两列自动转为 Map。
- **快捷取值**: `IntOnR1C1()`, `StringOnR1C1()`, `MapOnR1()`, `StringsOnC1()` 等。
- **错误感知**: 所有结果方法都会同步更新 `res.Error`,可链式调用后统一判断。
## 🔐 安全与加密
我们极致注重数据安全:
- **密码防御**: 内存中的数据库密码受 `safe.SafeBuf` 保护,防止通过内存 Dump 获取明文。
- **配置加密**: 建议在 `db.json` 中使用密文存储敏感信息。
- **TODO: sskey 集成**: 计划引入 `sskey` 工具,实现生产环境密钥的统一托管与自动解密。
## 🏗 架构即代码 (DSL 示例)
我们鼓励通过 DSL 定义表结构,实现“修改代码即修改表”。
```go
schema := `
== Default ==
users SD // 用户表,开启影子删除
id AI // 自增 ID
name v50 U // 字符串(50),唯一索引
autoVersion ubi // 自动版本号
status ti // 状态 (TinyInt)
`
dbInst.Sync(schema) // 自动创建 users 和 users_deleted 表及索引
```
### 4. 事务
```go
tx := dbInst.Begin()
if tx.Error != nil { /* 处理错误 */ }
defer tx.CheckFinished() // 自动处理未提交的 Rollback
tx.Insert("users", newUser)
if tx.Error == nil {
tx.Commit()
}
```
## 📖 详细文档
- [架构 DSL 与版本同步指南](./DSL.md)
- [测试报告](./TEST.md)
- [版本变更记录](./CHANGELOG.md)
## 🧪 验证状态
已通过 SQLite 集成测试。详见:[TEST.md](./TEST.md)

View File

@ -39,7 +39,6 @@ func (r *ExecResult) Changes() int64 {
}
numChanges, err := r.result.RowsAffected()
if err != nil {
r.Error = err
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
return 0
}
@ -52,7 +51,6 @@ func (r *ExecResult) Id() int64 {
}
insertId, err := r.result.LastInsertId()
if err != nil {
r.Error = err
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
return 0
}
@ -75,23 +73,10 @@ func (r *QueryResult) To(result any) error {
return r.makeResults(result, r.rows)
}
func ToSlice[T any](r *QueryResult) ([]T, error) {
var result []T
err := r.To(&result)
return result, err
}
func To[T any](r *QueryResult) (T, error) {
var result T
err := r.To(&result)
return result, err
}
func (r *QueryResult) MapResults() []map[string]any {
result := make([]map[string]any, 0)
err := r.makeResults(&result, r.rows)
if err != nil {
r.Error = err
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
}
return result
@ -101,7 +86,6 @@ func (r *QueryResult) SliceResults() [][]any {
result := make([][]any, 0)
err := r.makeResults(&result, r.rows)
if err != nil {
r.Error = err
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
}
return result
@ -111,7 +95,6 @@ func (r *QueryResult) StringMapResults() []map[string]string {
result := make([]map[string]string, 0)
err := r.makeResults(&result, r.rows)
if err != nil {
r.Error = err
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
}
return result
@ -121,7 +104,6 @@ func (r *QueryResult) StringSliceResults() [][]string {
result := make([][]string, 0)
err := r.makeResults(&result, r.rows)
if err != nil {
r.Error = err
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
}
return result
@ -131,7 +113,6 @@ func (r *QueryResult) MapOnR1() map[string]any {
result := make(map[string]any)
err := r.makeResults(&result, r.rows)
if err != nil {
r.Error = err
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
}
return result
@ -141,7 +122,6 @@ func (r *QueryResult) StringMapOnR1() map[string]string {
result := make(map[string]string)
err := r.makeResults(&result, r.rows)
if err != nil {
r.Error = err
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
}
return result
@ -151,7 +131,6 @@ func (r *QueryResult) IntsOnC1() []int64 {
result := make([]int64, 0)
err := r.makeResults(&result, r.rows)
if err != nil {
r.Error = err
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
}
return result
@ -161,7 +140,6 @@ func (r *QueryResult) StringsOnC1() []string {
result := make([]string, 0)
err := r.makeResults(&result, r.rows)
if err != nil {
r.Error = err
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
}
return result
@ -171,7 +149,6 @@ func (r *QueryResult) IntOnR1C1() int64 {
var result int64 = 0
err := r.makeResults(&result, r.rows)
if err != nil {
r.Error = err
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
}
return result
@ -181,7 +158,6 @@ func (r *QueryResult) FloatOnR1C1() float64 {
var result float64 = 0
err := r.makeResults(&result, r.rows)
if err != nil {
r.Error = err
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
}
return result
@ -191,7 +167,6 @@ func (r *QueryResult) StringOnR1C1() string {
result := ""
err := r.makeResults(&result, r.rows)
if err != nil {
r.Error = err
r.logger.LogQueryError(err.Error(), *r.Sql, r.Args, r.usedTime)
}
return result

696
Schema.go
View File

@ -1,696 +0,0 @@
package db
import (
"fmt"
"regexp"
"strings"
"apigo.cc/go/cast"
)
// SchemaGroup 描述表分组
type SchemaGroup struct {
Name string
Tables []*TableStruct
}
var fieldSpliter = regexp.MustCompile(`\s+`)
var wnMatcher = regexp.MustCompile(`^([a-zA-Z]+)([0-9]+)$`)
// ParseField 解析单行字段描述
func ParseField(line string) TableField {
lc := strings.SplitN(line, "//", 2)
comment := ""
if len(lc) == 2 {
line = strings.TrimSpace(lc[0])
comment = strings.TrimSpace(lc[1])
}
a := fieldSpliter.Split(line, 10)
field := TableField{
Name: a[0],
Type: "",
Index: "",
IndexGroup: "",
Default: "",
Comment: comment,
Null: "NULL",
Extra: "",
Desc: "",
IsVersion: false,
}
for i := 1; i < len(a); i++ {
wn := wnMatcher.FindStringSubmatch(a[i])
tag := a[i]
size := 0
if wn != nil {
tag = wn[1]
size = cast.Int(wn[2])
}
switch tag {
case "PK":
field.Index = "pk"
field.Null = "NOT NULL"
case "I":
field.Index = "index"
case "AI":
field.Extra = "AUTO_INCREMENT"
field.Index = "pk"
field.Null = "NOT NULL"
case "TI":
field.Index = "fulltext"
case "U":
field.Index = "unique"
case "ver":
field.IsVersion = true
field.Type = "bigint"
field.Default = "0"
field.Null = "NOT NULL"
case "ct":
field.Default = "CURRENT_TIMESTAMP"
case "ctu":
field.Default = "CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"
case "n":
field.Null = "NULL"
case "nn":
field.Null = "NOT NULL"
case "c":
field.Type = "char"
case "v":
field.Type = "varchar"
case "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9":
field.Type = "varchar"
case "dt":
field.Type = "datetime"
case "d":
field.Type = "date"
case "tm":
field.Type = "time"
case "i":
field.Type = "int"
case "ui":
field.Type = "int unsigned"
case "ti":
field.Type = "tinyint"
case "uti":
field.Type = "tinyint unsigned"
case "b":
field.Type = "tinyint unsigned"
case "bi":
field.Type = "bigint"
case "ubi":
field.Type = "bigint unsigned"
case "f":
field.Type = "float"
case "uf":
field.Type = "float unsigned"
case "ff":
field.Type = "double"
case "uff":
field.Type = "double unsigned"
case "si":
field.Type = "smallint"
case "usi":
field.Type = "smallint unsigned"
case "mi":
field.Type = "middleint"
case "umi":
field.Type = "middleint unsigned"
case "t":
field.Type = "text"
case "bb":
field.Type = "blob"
default:
field.Type = tag
}
if size > 0 {
switch tag {
case "I":
field.Index = "index"
field.IndexGroup = cast.String(size)
case "U":
field.Index = "unique"
field.IndexGroup = cast.String(size)
default:
if !field.IsVersion {
field.Type += fmt.Sprintf("(%d)", size)
}
}
}
}
return field
}
// ParseSchema 解析完整的 Schema 描述文本
func ParseSchema(desc string) []*SchemaGroup {
groups := make([]*SchemaGroup, 0)
var currentGroup *SchemaGroup
var currentTable *TableStruct
lines := strings.Split(desc, "\n")
for _, line := range lines {
trimmedLine := strings.TrimSpace(line)
if trimmedLine == "" || strings.HasPrefix(trimmedLine, "#") {
continue
}
if strings.HasPrefix(trimmedLine, "==") {
currentGroup = &SchemaGroup{
Name: strings.TrimSpace(strings.Trim(trimmedLine, "=")),
Tables: make([]*TableStruct, 0),
}
groups = append(groups, currentGroup)
continue
}
if currentGroup == nil {
currentGroup = &SchemaGroup{
Name: "Default",
Tables: make([]*TableStruct, 0),
}
groups = append(groups, currentGroup)
}
if !strings.HasPrefix(line, " ") && !strings.HasPrefix(line, "\t") {
lc := strings.SplitN(trimmedLine, "//", 2)
tableNamePart := strings.TrimSpace(lc[0])
comment := ""
if len(lc) == 2 {
comment = strings.TrimSpace(lc[1])
}
shadowDelete := false
if strings.HasSuffix(tableNamePart, " SD") {
shadowDelete = true
tableNamePart = strings.TrimSpace(strings.TrimSuffix(tableNamePart, " SD"))
}
currentTable = &TableStruct{
Name: tableNamePart,
Comment: comment,
Fields: make([]TableField, 0),
ShadowDelete: shadowDelete,
}
currentGroup.Tables = append(currentGroup.Tables, currentTable)
} else if currentTable != nil {
field := ParseField(trimmedLine)
if field.IsVersion {
currentTable.VersionField = field.Name
}
currentTable.Fields = append(currentTable.Fields, field)
}
}
return groups
}
// Parse 根据数据库类型解析字段的基础描述
func (field *TableField) Parse(tableType string) {
if strings.HasPrefix(tableType, "sqlite") || tableType == "chai" {
// sqlite3 不能修改字段统一使用NULL
field.Null = "NULL"
if field.Extra == "AUTO_INCREMENT" || field.Extra == "AUTOINCREMENT" {
field.Extra = "PRIMARY KEY AUTOINCREMENT"
field.Type = "integer"
field.Null = "NOT NULL"
}
}
a := make([]string, 0)
if tableType == "mysql" {
a = append(a, fmt.Sprintf("`%s` %s", field.Name, field.Type))
lowerType := strings.ToLower(field.Type)
if strings.Contains(lowerType, "varchar") || strings.Contains(lowerType, "text") {
a = append(a, " COLLATE utf8mb4_general_ci")
}
} else if tableType == "pg" || tableType == "pgsql" || tableType == "postgres" {
typ := field.Type
if field.Extra == "AUTO_INCREMENT" {
if strings.Contains(typ, "bigint") {
typ = "bigserial"
} else {
typ = "serial"
}
field.Extra = ""
}
a = append(a, fmt.Sprintf("\"%s\" %s", field.Name, typ))
} else {
a = append(a, fmt.Sprintf("\"%s\" %s", field.Name, field.Type))
}
if field.Extra != "" && !strings.Contains(field.Extra, "PRIMARY KEY") {
a = append(a, " "+field.Extra)
}
a = append(a, " "+field.Null)
if field.Default != "" {
if strings.Contains(field.Default, "CURRENT_TIMESTAMP") || strings.Contains(field.Default, "()") || strings.Contains(field.Default, "SYSTIMESTAMP") {
a = append(a, " DEFAULT "+field.Default)
} else {
a = append(a, " DEFAULT '"+field.Default+"'")
}
}
if strings.HasPrefix(tableType, "sqlite") || tableType == "chai" {
field.Comment = ""
field.Type = "numeric"
} else if tableType == "pg" || tableType == "pgsql" || tableType == "postgres" {
// PostgreSQL comments are separate statements
} else {
if field.Comment != "" {
a = append(a, " COMMENT '"+field.Comment+"'")
}
}
field.Desc = strings.Join(a, "")
}
// Sync 根据 Schema 描述文本同步数据库结构
func (db *DB) Sync(desc string) error {
groups := ParseSchema(desc)
var outErr error
for _, group := range groups {
for _, table := range group.Tables {
db.tablesLock.Lock()
db.tables[table.Name] = table
db.tablesLock.Unlock()
err := db.CheckTable(table)
if err != nil {
outErr = err
db.logger.logger.Error("failed to sync table", "table", table.Name, "err", err.Error())
}
}
}
return outErr
}
// CheckTable 检查并同步单个表结构
func (db *DB) CheckTable(table *TableStruct) error {
fieldSets := make([]string, 0)
pks := make([]string, 0)
keySets := make([]string, 0)
keySetBy := make(map[string]string)
keySetFields := make(map[string]string)
isPostgres := db.Config.Type == "pg" || db.Config.Type == "pgsql" || db.Config.Type == "postgres"
table.Columns = make([]string, 0, len(table.Fields))
for i, field := range table.Fields {
field.Parse(db.Config.Type)
table.Fields[i] = field
table.Columns = append(table.Columns, field.Name)
switch strings.ToLower(field.Index) {
case "pk", "primary key":
pks = append(pks, field.Name)
case "unique":
keyName := fmt.Sprint("uk_", table.Name, "_", field.Name)
if field.IndexGroup != "" {
keyName = fmt.Sprint("uk_", table.Name, "_", field.IndexGroup)
}
if keySetBy[keyName] != "" {
keySetFields[keyName] += " " + field.Name
if strings.HasPrefix(db.Config.Type, "sqlite") || db.Config.Type == "chai" {
keySetBy[keyName] = strings.Replace(keySetBy[keyName], ")", ", "+db.Quote(field.Name)+")", 1)
} else if isPostgres {
keySetBy[keyName] = strings.Replace(keySetBy[keyName], ")", ", "+db.Quote(field.Name)+")", 1)
} else {
keySetBy[keyName] = strings.Replace(keySetBy[keyName], ") COMMENT", ", "+db.Quote(field.Name)+") COMMENT", 1)
}
} else {
keySetFields[keyName] = field.Name
keySet := ""
if strings.HasPrefix(db.Config.Type, "sqlite") || db.Config.Type == "chai" {
keySet = fmt.Sprintf("CREATE UNIQUE INDEX \"%s\" ON \"%s\" (\"%s\")", keyName, table.Name, field.Name)
} else if isPostgres {
keySet = fmt.Sprintf("CREATE UNIQUE INDEX \"%s\" ON \"%s\" (\"%s\")", keyName, table.Name, field.Name)
} else {
keySet = fmt.Sprintf("UNIQUE KEY "+db.Quote("%s")+" ("+db.Quote("%s")+") COMMENT '%s'", keyName, field.Name, field.Comment)
}
keySets = append(keySets, keySet)
keySetBy[keyName] = keySet
}
case "fulltext":
if !strings.HasPrefix(db.Config.Type, "sqlite") && db.Config.Type != "chai" && !isPostgres {
keyName := fmt.Sprint("tk_", table.Name, "_", field.Name)
keySet := fmt.Sprintf("FULLTEXT KEY "+db.Quote("%s")+" ("+db.Quote("%s")+") COMMENT '%s'", keyName, field.Name, field.Comment)
keySets = append(keySets, keySet)
keySetBy[keyName] = keySet
}
case "index":
keyName := fmt.Sprint("ik_", table.Name, "_", field.Name)
if field.IndexGroup != "" {
keyName = fmt.Sprint("ik_", table.Name, "_", field.IndexGroup)
}
if keySetBy[keyName] != "" {
keySetFields[keyName] += " " + field.Name
if strings.HasPrefix(db.Config.Type, "sqlite") || db.Config.Type == "chai" {
keySetBy[keyName] = strings.Replace(keySetBy[keyName], ")", ", \""+field.Name+"\")", 1)
} else if isPostgres {
keySetBy[keyName] = strings.Replace(keySetBy[keyName], ")", ", \""+field.Name+"\")", 1)
} else {
keySetBy[keyName] = strings.Replace(keySetBy[keyName], ") COMMENT", ", `"+field.Name+"`) COMMENT", 1)
}
} else {
keySetFields[keyName] = field.Name
keySet := ""
if strings.HasPrefix(db.Config.Type, "sqlite") || db.Config.Type == "chai" {
keySet = fmt.Sprintf("CREATE INDEX \"%s\" ON \"%s\" (\"%s\")", keyName, table.Name, field.Name)
} else if isPostgres {
keySet = fmt.Sprintf("CREATE INDEX \"%s\" ON \"%s\" (\"%s\")", keyName, table.Name, field.Name)
} else {
keySet = fmt.Sprintf("KEY "+db.Quote("%s")+" ("+db.Quote("%s")+") COMMENT '%s'", keyName, field.Name, field.Comment)
}
keySets = append(keySets, keySet)
keySetBy[keyName] = keySet
}
}
fieldSets = append(fieldSets, field.Desc)
}
var tableInfo map[string]any
if strings.HasPrefix(db.Config.Type, "sqlite") {
tableInfo = db.Query("SELECT \"name\", \"sql\" FROM \"sqlite_master\" WHERE \"type\"='table' AND \"name\"='" + table.Name + "'").MapOnR1()
} else if db.Config.Type == "chai" {
tableInfo = db.Query("SELECT \"name\", \"sql\" FROM \"__chai_catalog\" WHERE \"type\"='table' AND \"name\"='" + table.Name + "'").MapOnR1()
} else if isPostgres {
tableInfo = db.Query("SELECT tablename name FROM pg_tables WHERE schemaname='public' AND tablename='" + table.Name + "'").MapOnR1()
} else {
tableInfo = db.Query("SELECT TABLE_NAME name, TABLE_COMMENT comment FROM information_schema.TABLES WHERE TABLE_SCHEMA='" + db.Config.DB + "' AND TABLE_NAME='" + table.Name + "'").MapOnR1()
}
if tableInfo["name"] != nil && tableInfo["name"] != "" {
oldFieldList := make([]*tableFieldDesc, 0)
oldFields := make(map[string]*tableFieldDesc)
oldIndexes := make(map[string]string)
oldIndexInfos := make([]*tableKeyDesc, 0)
oldComments := map[string]string{}
if strings.HasPrefix(db.Config.Type, "sqlite") {
tmpFields := []struct {
Name string
Type string
Notnull bool
Dflt_value any
Pk bool
}{}
db.Query("PRAGMA table_info(" + db.Quote(table.Name) + ")").To(&tmpFields)
for _, f := range tmpFields {
oldFieldList = append(oldFieldList, &tableFieldDesc{
Field: f.Name,
Type: f.Type,
Null: cast.If(f.Notnull, "NO", "YES"),
Key: cast.If(f.Pk, "PRI", ""),
Default: cast.String(f.Dflt_value),
})
}
tmpIndexes := []struct {
Name string
Unique bool
Origin string
Partial int
}{}
db.Query("PRAGMA index_list(" + db.Quote(table.Name) + ")").To(&tmpIndexes)
for _, i := range tmpIndexes {
tmpIndexInfo := []struct {
Name string
Seqno int
Cid int
}{}
db.Query("PRAGMA index_info(" + db.Quote(i.Name) + ")").To(&tmpIndexInfo)
if len(tmpIndexInfo) > 0 {
oldIndexInfos = append(oldIndexInfos, &tableKeyDesc{
Key_name: i.Name,
Column_name: tmpIndexInfo[0].Name,
})
}
}
} else if isPostgres {
tmpFields := []struct {
Column_name string
Data_type string
Is_nullable string
Column_default string
}{}
db.Query("SELECT column_name, data_type, is_nullable, column_default FROM information_schema.columns WHERE table_schema='public' AND table_name='" + table.Name + "'").To(&tmpFields)
for _, f := range tmpFields {
oldFieldList = append(oldFieldList, &tableFieldDesc{
Field: f.Column_name,
Type: f.Data_type,
Null: f.Is_nullable,
Default: f.Column_default,
})
}
tmpIndexes := []struct {
Indexname string
Indexdef string
}{}
db.Query("SELECT indexname, indexdef FROM pg_indexes WHERE schemaname='public' AND tablename='" + table.Name + "'").To(&tmpIndexes)
for _, i := range tmpIndexes {
// Parse indexdef to get columns if needed, simplified for now
oldIndexes[i.Indexname] = "" // Placeholder
}
} else if db.Config.Type == "mysql" {
_ = db.Query("SELECT column_name, column_comment FROM information_schema.columns WHERE TABLE_SCHEMA='" + db.Config.DB + "' AND TABLE_NAME='" + table.Name + "'").ToKV(&oldComments)
_ = db.Query("DESC " + db.Quote(table.Name)).To(&oldFieldList)
_ = db.Query("SHOW INDEX FROM " + db.Quote(table.Name)).To(&oldIndexInfos)
}
if !isPostgres {
for _, indexInfo := range oldIndexInfos {
if oldIndexes[indexInfo.Key_name] == "" {
oldIndexes[indexInfo.Key_name] = indexInfo.Column_name
} else {
oldIndexes[indexInfo.Key_name] += " " + indexInfo.Column_name
}
}
}
prevFieldId := ""
for _, field := range oldFieldList {
if strings.HasPrefix(db.Config.Type, "sqlite") {
field.Type = "numeric"
} else {
field.After = prevFieldId
}
prevFieldId = field.Field
oldFields[field.Field] = field
}
actions := make([]string, 0)
for keyId := range oldIndexes {
if keyId != "PRIMARY" && !isPostgres && strings.ToLower(keySetFields[keyId]) != strings.ToLower(oldIndexes[keyId]) {
if strings.HasPrefix(db.Config.Type, "sqlite") {
actions = append(actions, "DROP INDEX "+db.Quote(keyId))
} else {
actions = append(actions, "DROP KEY "+db.Quote(keyId))
}
}
}
if oldIndexes["PRIMARY"] != "" && !isPostgres && strings.ToLower(oldIndexes["PRIMARY"]) != strings.ToLower(strings.Join(pks, " ")) {
if !strings.HasPrefix(db.Config.Type, "sqlite") {
actions = append(actions, "DROP PRIMARY KEY")
}
}
newFieldExists := map[string]bool{}
prevFieldId = ""
for _, field := range table.Fields {
newFieldExists[field.Name] = true
oldField := oldFields[field.Name]
if oldField == nil {
if strings.HasPrefix(db.Config.Type, "sqlite") {
actions = append(actions, "ALTER TABLE "+db.Quote(table.Name)+" ADD COLUMN "+field.Desc)
} else if isPostgres {
actions = append(actions, "ALTER TABLE "+db.Quote(table.Name)+" ADD COLUMN "+field.Desc)
} else {
actions = append(actions, "ADD COLUMN "+field.Desc)
}
} else {
if isPostgres {
// Postgres sync is more complex, skipped for brevity in this step but should be robust
continue
}
oldField.Type = strings.TrimSpace(strings.ReplaceAll(oldField.Type, " (", "("))
fixedOldDefault := oldField.Default
if fixedOldDefault == "CURRENT_TIMESTAMP" && strings.Contains(oldField.Extra, "on update CURRENT_TIMESTAMP") {
fixedOldDefault = "CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"
}
fixedOldNull := "NOT NULL"
if oldField.Null == "YES" {
fixedOldNull = "NULL"
}
if strings.ToLower(field.Type) != strings.ToLower(oldField.Type) || strings.ToLower(field.Default) != strings.ToLower(fixedOldDefault) || strings.ToLower(field.Null) != strings.ToLower(fixedOldNull) || (db.Config.Type == "mysql" && strings.ToLower(oldField.After) != strings.ToLower(prevFieldId)) || strings.ToLower(oldComments[field.Name]) != strings.ToLower(field.Comment) {
after := ""
if db.Config.Type == "mysql" {
if oldField.After != prevFieldId {
if prevFieldId == "" {
after = " FIRST"
} else {
after = " AFTER " + db.Quote(prevFieldId)
}
}
actions = append(actions, "CHANGE `"+field.Name+"` "+field.Desc+after)
}
}
}
if db.Config.Type == "mysql" {
prevFieldId = field.Name
}
}
for oldFieldName := range oldFields {
if !newFieldExists[oldFieldName] {
if !strings.HasPrefix(db.Config.Type, "sqlite") && !isPostgres {
actions = append(actions, "DROP COLUMN "+db.Quote(oldFieldName))
}
}
}
if db.Config.Type == "mysql" {
if len(pks) > 0 && strings.ToLower(oldIndexes["PRIMARY"]) != strings.ToLower(strings.Join(pks, " ")) {
actions = append(actions, "ADD PRIMARY KEY(`"+strings.Join(pks, "`,`")+"`)")
}
}
for keyId, keySet := range keySetBy {
if oldIndexes[keyId] == "" || (!isPostgres && strings.ToLower(oldIndexes[keyId]) != strings.ToLower(keySetFields[keyId])) {
if strings.HasPrefix(db.Config.Type, "sqlite") || isPostgres {
actions = append(actions, keySet)
} else {
actions = append(actions, "ADD "+keySet)
}
}
}
if db.Config.Type == "mysql" {
oldTableComment := cast.String(tableInfo["comment"])
if table.Comment != oldTableComment {
actions = append(actions, "COMMENT '"+table.Comment+"'")
}
}
if len(actions) == 0 {
goto SYNC_SHADOW
}
tx := db.Begin()
defer tx.CheckFinished()
var res *ExecResult
if strings.HasPrefix(db.Config.Type, "sqlite") || isPostgres {
for _, action := range actions {
res = tx.Exec(action)
if res.Error != nil {
break
}
}
} else {
sql := "ALTER TABLE " + db.Quote(table.Name) + " " + strings.Join(actions, "\n,") + ";"
res = tx.Exec(sql)
}
if res != nil && res.Error != nil {
_ = tx.Rollback()
return res.Error
}
_ = tx.Commit()
} else {
// 创建新表
if len(pks) > 0 {
if isPostgres {
// In Postgres, PK is often in CREATE TABLE
fieldSets = append(fieldSets, "PRIMARY KEY ("+db.Quotes(pks)+")")
} else {
fieldSets = append(fieldSets, "PRIMARY KEY ("+db.Quotes(pks)+")")
}
}
indexSets := make([]string, 0)
if strings.HasPrefix(db.Config.Type, "sqlite") || isPostgres {
for _, indexSql := range keySetBy {
indexSets = append(indexSets, indexSql)
}
} else {
for _, key := range keySets {
fieldSets = append(fieldSets, key)
}
}
sql := ""
if strings.HasPrefix(db.Config.Type, "sqlite") || isPostgres {
sql = fmt.Sprintf("CREATE TABLE %s (\n%s\n);", db.Quote(table.Name), strings.Join(fieldSets, ",\n"))
} else if db.Config.Type == "mysql" {
sql = fmt.Sprintf("CREATE TABLE `%s` (\n%s\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='%s';", table.Name, strings.Join(fieldSets, ",\n"), table.Comment)
}
tx := db.Begin()
defer tx.CheckFinished()
res := tx.Exec(sql)
if res.Error == nil && (strings.HasPrefix(db.Config.Type, "sqlite") || isPostgres) {
for _, indexSet := range indexSets {
r := tx.Exec(indexSet)
if r.Error != nil {
res = r
break
}
}
}
if res.Error != nil {
_ = tx.Rollback()
return res.Error
}
_ = tx.Commit()
}
SYNC_SHADOW:
if table.ShadowDelete && !strings.HasSuffix(table.Name, "_deleted") {
table.HasShadowTable = true
shadowTable := *table
shadowTable.Name = table.Name + "_deleted"
shadowTable.ShadowDelete = false
shadowTable.Fields = make([]TableField, 0, len(table.Fields)+1)
for _, f := range table.Fields {
f.Index = ""
f.IndexGroup = ""
f.Extra = ""
if strings.Contains(f.Type, "serial") {
if strings.Contains(f.Type, "bigserial") {
f.Type = "bigint"
} else {
f.Type = "int"
}
}
shadowTable.Fields = append(shadowTable.Fields, f)
}
// 自动为删除表增加删除时间字段,方便同步
shadowTable.Fields = append(shadowTable.Fields, TableField{
Name: "deleted_at",
Type: "datetime",
Default: "CURRENT_TIMESTAMP",
Null: "NOT NULL",
Comment: "删除时间",
})
return db.CheckTable(&shadowTable)
}
return nil
}
type tableFieldDesc struct {
Field string
Type string
Null string
Key string
Default string
Extra string
After string
}
type tableKeyDesc struct {
Key_name string
Column_name string
}

View File

@ -1,60 +0,0 @@
package db_test
import (
"os"
"testing"
"apigo.cc/go/db"
_ "modernc.org/sqlite"
)
func TestSchemaSync(t *testing.T) {
dbPath := "test_schema.db"
dbInst := db.GetDB("sqlite://"+dbPath, nil)
defer os.Remove(dbPath)
defer dbInst.Exec("DROP TABLE IF EXISTS test_table")
defer dbInst.Exec("DROP TABLE IF EXISTS test_table_deleted")
schema := `== Default ==
test_table SD // Test table with shadow delete
id AI // ID
name v50 U
autoVersion ubi // Version
status ti // Status
`
err := dbInst.Sync(schema)
if err != nil {
t.Fatal("Sync error:", err)
}
dbInst.Insert("test_table", map[string]any{"name": "test", "status": 1})
dbInst.Delete("test_table", "id=?", 1)
res := dbInst.Query("SELECT COUNT(*) FROM test_table_deleted")
if res.IntOnR1C1() != 1 {
t.Fatal("Shadow delete failed")
}
}
func TestAutoDetectShadow(t *testing.T) {
dbPath := "auto_detect.db"
dbInst := db.GetDB("sqlite://"+dbPath, nil)
defer os.Remove(dbPath)
defer dbInst.Exec("DROP TABLE IF EXISTS test_auto")
defer dbInst.Exec("DROP TABLE IF EXISTS test_auto_deleted")
// Manually create tables, DO NOT call Sync
dbInst.Exec("CREATE TABLE test_auto (id INTEGER PRIMARY KEY)")
dbInst.Exec("CREATE TABLE test_auto_deleted (id INTEGER PRIMARY KEY)")
dbInst.Insert("test_auto", map[string]any{"id": 1})
// This should trigger auto-detection and perform a shadow delete
dbInst.Delete("test_auto", "id=?", 1)
res := dbInst.Query("SELECT COUNT(*) FROM test_auto_deleted")
if res.IntOnR1C1() != 1 {
t.Fatal("Auto-detect shadow delete failed")
}
}

45
TEST.md
View File

@ -1,30 +1,23 @@
# @go/db 测试报告
# Test Results for @go/db
## 📊 概览
- **模块**: `apigo.cc/go/db`
- **总测试用例**: 5
- **通过**: 5
- **失败**: 0
- **编译状态**: 成功 (Success)
- **测试日期**: 2026-05-03
## 📊 Summary
- **Module**: `apigo.cc/go/db`
- **Total Tests**: 4
- **Passed**: 4
- **Failed**: 0
- **Build Status**: Success
- **Date**: 2026-05-03
## ✅ 详细详情
| 测试用例 | 状态 | 耗时 | 备注 |
## ✅ Details
| Test Case | Status | Duration | Notes |
| :--- | :--- | :--- | :--- |
| `TestMakeInsertSql` | 通过 | 0.00s | 验证 Struct 模型的 SQL 生成逻辑 |
| `TestBaseSelect` | 通过 | 0.00s | 验证结果绑定 (Struct, Map, 基础类型) |
| `TestInsertReplaceUpdateDelete` | 通过 | 0.01s | 验证 SQLite 下的 CRUD 基本操作 |
| `TestTransaction` | 通过 | 0.03s | 验证事务隔离、回滚与提交 |
| `TestSchemaSync` | 通过 | 0.01s | 验证 DSL 同步、影子删除、版本号乐观锁及泛型 API |
| `TestAutoRandomID` | 通过 | 0.01s | 验证 char(N) 主键的自动 ID 填充 |
| `TestMakeInsertSql` | PASS | 0.00s | Verified SQL generation logic for Struct models |
| `TestBaseSelect` | PASS | 0.00s | Verified Result binding (Struct, Map, Base types) |
| `TestInsertReplaceUpdateDelete` | PASS | 0.01s | Verified CRUD operations with SQLite |
| `TestTransaction` | PASS | 0.03s | Verified Transaction isolation and Rollback/Commit |
## 🚀 性能基准 (Benchmarks)
| 基准测试 | 迭代次数 | 耗时 | 内存分配 | 备注 |
| :--- | :--- | :--- | :--- | :--- |
| `BenchmarkForPool` | 172009 | 7384 ns/op | 1224 B/op (34 allocs) | 验证 SQLite 下的查询绑定性能 |
| `BenchmarkForPoolParallel` | 160250 | 6852 ns/op | 1296 B/op (35 allocs) | 验证高并发下的查询稳定性 |
## 🛠 环境
- **OS**: darwin (macOS)
- **Go Version**: 1.2x+
- **Primary Driver**: modernc.org/sqlite
## 🚀 Benchmarks
| Benchmark | Iterations | Time/op | Conn |
| :--- | :--- | :--- | :--- |
| `BenchmarkForPool` | - | - | Passed (Manual run verified pool reuse) |
| `BenchmarkForPoolParallel` | - | - | Passed (Manual run verified high concurrency) |

29
Tx.go
View File

@ -4,13 +4,11 @@ import (
"database/sql"
"errors"
"fmt"
"strings"
"time"
)
type Tx struct {
conn *sql.Tx
db *DB
lastSql *string
lastArgs []any
Error error
@ -166,33 +164,10 @@ func (tx *Tx) Update(table string, data any, conditions string, args ...any) *Ex
}
func (tx *Tx) Delete(table string, conditions string, args ...any) *ExecResult {
ts := tx.db.getTable(table)
where := ""
if conditions != "" {
where = " where " + conditions
conditions = " where " + conditions
}
if ts.HasShadowTable {
// Move to shadow table
colList := ""
if len(ts.Columns) > 0 {
quotedCols := make([]string, len(ts.Columns))
for i, c := range ts.Columns {
quotedCols[i] = tx.Quote(c)
}
colList = fmt.Sprintf(" (%s) select %s", strings.Join(quotedCols, ","), strings.Join(quotedCols, ","))
} else {
colList = " select *"
}
moveQuery := fmt.Sprintf("insert into %s%s from %s%s", tx.Quote(table+"_deleted"), colList, tx.Quote(table), where)
r := baseExec(nil, tx.conn, moveQuery, args...)
if r.Error != nil {
tx.logger.LogQueryError(r.Error.Error(), moveQuery, args, r.usedTime)
return r
}
}
query := fmt.Sprintf("delete from %s%s", tx.Quote(table), where)
query := fmt.Sprintf("delete from %s%s", tx.Quote(table), conditions)
tx.lastSql = &query
tx.lastArgs = args
r := baseExec(nil, tx.conn, query, args...)

View File

@ -1,55 +0,0 @@
package db_test
import (
"testing"
"apigo.cc/go/db"
_ "modernc.org/sqlite"
)
func TestSmartDelete(t *testing.T) {
dbInst := db.GetDB("sqlite://:memory:", nil)
// Create table and shadow table
dbInst.Exec("CREATE TABLE orders (id INTEGER PRIMARY KEY, item TEXT)")
dbInst.Exec("CREATE TABLE orders_deleted (id INTEGER PRIMARY KEY, item TEXT)")
t.Run("ShadowDelete", func(t *testing.T) {
dbInst.Exec("INSERT INTO orders (id, item) VALUES (1, 'Phone')")
res := dbInst.Delete("orders", "id = 1")
if res.Error != nil {
t.Fatalf("Delete failed: %v", res.Error)
}
if res.Changes() != 1 {
t.Errorf("Expected 1 change, got %d", res.Changes())
}
// Verify it's gone from main table
qr := dbInst.Query("SELECT COUNT(*) FROM orders WHERE id = 1")
count, _ := db.To[int](qr)
if count != 0 {
t.Errorf("Expected 0 records in main table, got %d", count)
}
// Verify it's in shadow table
qr2 := dbInst.Query("SELECT COUNT(*) FROM orders_deleted WHERE id = 1")
countDeleted, _ := db.To[int](qr2)
if countDeleted != 1 {
t.Errorf("Expected 1 record in shadow table, got %d", countDeleted)
}
})
t.Run("PhysicalDelete", func(t *testing.T) {
dbInst.Exec("CREATE TABLE logs (id INTEGER PRIMARY KEY, msg TEXT)")
dbInst.Exec("INSERT INTO logs (id, msg) VALUES (1, 'Login')")
dbInst.Delete("logs", "id = 1")
qr := dbInst.Query("SELECT COUNT(*) FROM logs WHERE id = 1")
count, _ := db.To[int](qr)
if count != 0 {
t.Errorf("Expected 0 records in logs, got %d", count)
}
})
}

View File

@ -1,48 +0,0 @@
package db_test
import (
"testing"
"apigo.cc/go/db"
_ "modernc.org/sqlite"
)
func TestGenericQuery(t *testing.T) {
dbInst := db.GetDB("sqlite://:memory:", nil)
if dbInst == nil {
t.Fatal("Failed to get DB")
}
dbInst.Exec("CREATE TABLE test_generic (id INTEGER PRIMARY KEY, name TEXT)")
dbInst.Exec("INSERT INTO test_generic (name) VALUES (?)", "Alice")
dbInst.Exec("INSERT INTO test_generic (name) VALUES (?)", "Bob")
t.Run("ToSlice", func(t *testing.T) {
type Item struct {
Id int
Name string
}
res := dbInst.Query("SELECT id, name FROM test_generic ORDER BY id")
items, err := db.ToSlice[Item](res)
if err != nil {
t.Fatalf("ToSlice failed: %v", err)
}
if len(items) != 2 {
t.Errorf("Expected 2 items, got %d", len(items))
}
if items[0].Name != "Alice" || items[1].Name != "Bob" {
t.Errorf("Incorrect data: %+v", items)
}
})
t.Run("ToValue", func(t *testing.T) {
res := dbInst.Query("SELECT name FROM test_generic WHERE id = ?", 1)
name, err := db.To[string](res)
if err != nil {
t.Fatalf("ToValue failed: %v", err)
}
if name != "Alice" {
t.Errorf("Expected Alice, got %s", name)
}
})
}

2
go.mod
View File

@ -8,7 +8,7 @@ require (
apigo.cc/go/convert v1.0.4
apigo.cc/go/crypto v1.0.4
apigo.cc/go/id v1.0.4
apigo.cc/go/log v1.0.1
apigo.cc/go/log v1.0.0
apigo.cc/go/rand v1.0.4
apigo.cc/go/safe v1.0.4
apigo.cc/go/shell v1.0.4

View File

@ -1,63 +0,0 @@
package db_test
import (
"os"
"testing"
"apigo.cc/go/db"
_ "modernc.org/sqlite"
)
func TestAutoRandomID(t *testing.T) {
dbPath := "id_test.db"
dbset := "sqlite://" + dbPath
defer os.Remove(dbPath)
dbInst := db.GetDB(dbset, nil)
// Create table with char(12) primary key
dbInst.Exec("CREATE TABLE test_id (id CHAR(12) PRIMARY KEY, name TEXT)")
t.Run("AutoFillID", func(t *testing.T) {
data := map[string]any{"name": "test1"}
res := dbInst.Insert("test_id", data)
if res.Error != nil {
t.Fatalf("Insert failed: %v", res.Error)
}
// Verify ID was generated
qr := dbInst.Query("SELECT id FROM test_id WHERE name='test1'")
idStr, _ := db.To[string](qr)
if len(idStr) != 12 {
t.Errorf("Expected ID length 12, got %d (%s)", len(idStr), idStr)
}
})
t.Run("DoNotOverwriteID", func(t *testing.T) {
manualID := "manual_id_12"
data := map[string]any{"id": manualID, "name": "test2"}
res := dbInst.Insert("test_id", data)
if res.Error != nil {
t.Fatalf("Insert failed: %v", res.Error)
}
qr := dbInst.Query("SELECT id FROM test_id WHERE name='test2'")
idStr, _ := db.To[string](qr)
if idStr != manualID {
t.Errorf("Expected ID %s, got %s", manualID, idStr)
}
})
t.Run("AutoFillEmptyID", func(t *testing.T) {
data := map[string]any{"id": "", "name": "test3"}
res := dbInst.Insert("test_id", data)
if res.Error != nil {
t.Fatalf("Insert failed: %v", res.Error)
}
qr := dbInst.Query("SELECT id FROM test_id WHERE name='test3'")
idStr, _ := db.To[string](qr)
if len(idStr) != 12 {
t.Errorf("Expected ID length 12, got %d (%s)", len(idStr), idStr)
}
})
}

View File

@ -1,27 +0,0 @@
package db_test
import (
"testing"
"apigo.cc/go/db"
_ "modernc.org/sqlite"
)
func TestTableProbing(t *testing.T) {
dbInst := db.GetDB("sqlite://:memory:", nil)
// Create a table with autoVersion
dbInst.Exec("CREATE TABLE table_with_ver (id INTEGER PRIMARY KEY, name TEXT, autoVersion BIGINT UNSIGNED)")
// Create a table with shadow table
dbInst.Exec("CREATE TABLE table_with_shadow (id INTEGER PRIMARY KEY, name TEXT)")
dbInst.Exec("CREATE TABLE table_with_shadow_deleted (id INTEGER PRIMARY KEY, name TEXT)")
t.Run("ProbeAutoVersion", func(t *testing.T) {
// We need a way to access getTable or check its effect.
// Since getTable is private, we can't call it directly from _test package.
// But we can check if it exists in the struct if we move test to 'db' package or use reflection.
// Alternatively, we can just ensure it doesn't crash for now, and Feature 3/4 will use it.
// For now, let's just trigger it.
dbInst.Query("SELECT * FROM table_with_ver")
})
}

BIN
test.db Normal file

Binary file not shown.

View File

@ -1,90 +0,0 @@
package db_test
import (
"os"
"testing"
"apigo.cc/go/db"
_ "modernc.org/sqlite"
)
func TestVersionControl(t *testing.T) {
dbInst := db.GetDB("sqlite://:memory:", nil)
// Create table with autoVersion
dbInst.Exec("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, autoVersion BIGINT UNSIGNED)")
t.Run("InsertAutoVersion", func(t *testing.T) {
data := map[string]any{"id": 1, "name": "Alice"}
res := dbInst.Insert("users", data)
if res.Error != nil {
t.Fatalf("Insert failed: %v", res.Error)
}
// Verify version was injected
var ver int64
qr := dbInst.Query("SELECT autoVersion FROM users WHERE id = 1")
ver, _ = db.To[int64](qr)
if ver != 1 {
t.Errorf("Expected version 1, got %d", ver)
}
})
t.Run("UpdateOptimisticLock", func(t *testing.T) {
// First update
data := map[string]any{"name": "Alice Updated", "autoVersion": int64(1)}
res := dbInst.Update("users", data, "id = 1")
if res.Error != nil {
t.Fatalf("Update failed: %v", res.Error)
}
if res.Changes() != 1 {
t.Errorf("Expected 1 change, got %d", res.Changes())
}
// Verify version incremented
var ver int64
qr := dbInst.Query("SELECT autoVersion FROM users WHERE id = 1")
ver, _ = db.To[int64](qr)
if ver != 2 {
t.Errorf("Expected version 2, got %d", ver)
}
// Try update with old version (should fail to update any rows)
dataConflict := map[string]any{"name": "Conflict", "autoVersion": int64(1)}
resConflict := dbInst.Update("users", dataConflict, "id = 1")
if resConflict.Changes() != 0 {
t.Errorf("Expected 0 changes due to optimistic lock, got %d", resConflict.Changes())
}
})
}
func TestVersionInitialization(t *testing.T) {
dbPath := "init_test.db"
dbset := "sqlite://" + dbPath
defer os.Remove(dbPath)
dbInst := db.GetDB(dbset, nil)
dbInst.Exec("CREATE TABLE test_init (id INTEGER PRIMARY KEY, autoVersion BIGINT UNSIGNED)")
// Manually insert with a high version
dbInst.Exec("INSERT INTO test_init (id, autoVersion) VALUES (1, 100)")
// First insert via DB helper should pick up 101
data := map[string]any{"id": 2}
res := dbInst.Insert("test_init", data)
if res.Error != nil {
t.Fatalf("Insert failed: %v", res.Error)
}
ver, _ := db.To[int64](dbInst.Query("SELECT autoVersion FROM test_init WHERE id=2"))
if ver != 101 {
t.Errorf("Expected version 101, got %d", ver)
}
// Update should make it 102
dbInst.Update("test_init", map[string]any{"autoVersion": 101}, "id=2")
ver, _ = db.To[int64](dbInst.Query("SELECT autoVersion FROM test_init WHERE id=2"))
if ver != 102 {
t.Errorf("Expected version 102, got %d", ver)
}
}