tableDB/db.go

313 lines
8.0 KiB
Go
Raw Normal View History

package tableDB
import (
	"fmt"
	"sort"
	"strings"
	"time"

	"apigo.cc/go/cast"
	"apigo.cc/go/db"
	"apigo.cc/go/log"
)
// SystemUserID is the userID assigned to TableDB instances created by GetDB;
// Auth derives instances that carry a caller-supplied userID instead.
const SystemUserID = "_system"
// Hooks holds optional callbacks attached to a TableDB. GetDB installs an empty
// (all-nil) *Hooks, and Auth passes the same pointer to the instances it
// creates, so assigning a callback through one instance affects all of them.
//
// NOTE(review): none of these callbacks are invoked in this file; the
// per-field descriptions below are inferred from names and signatures only —
// confirm against the call sites elsewhere in the package.
type Hooks struct {
	// OnCreatedTable presumably runs after a table is created, with its metadata record.
	OnCreatedTable func(tableName string, record map[string]any)
	// OnRemovedTable presumably runs after a table is removed.
	OnRemovedTable func(tableName string)
	// OnUpdatedField presumably runs after a field of table tableId is created or updated.
	OnUpdatedField func(tableId, fieldName string, record map[string]any)
	// OnRemovedField presumably runs after a field of table tableId is removed.
	OnRemovedField func(tableId, fieldName string)
	// OnUpdatingRow presumably runs before a row is written; a non-nil error
	// presumably vetoes the write — TODO confirm.
	OnUpdatingRow func(tableName string, row map[string]any) error
	// OnUpdatedRows presumably runs after rows are written, with the affected count.
	OnUpdatedRows func(tableName string, count int)
	// OnRemovedRows presumably runs after the rows with the given ids are removed.
	OnRemovedRows func(tableName string, ids []string)
}
// TableDB wraps the base go/db implementation to provide high-level abstractions.
// Instances are created with GetDB (acting as SystemUserID) or derived with
// Auth, which keeps the same base handle and Hooks pointer but sets a
// different userID.
type TableDB struct {
	// base is the underlying apigo.cc/go/db handle; exposed via Base().
	base *db.DB
	// userID identifies who this instance acts for.
	// NOTE(review): not read in this chunk; presumably consumed by Table operations.
	userID string
	// Hooks holds optional callbacks; shared by pointer with instances derived via Auth.
	Hooks *Hooks
}
// App is a type alias for TableDB (the identical type, not a distinct one);
// Auth returns *App.
type App = TableDB
// GetDB returns a TableDB backed by the handle that db.GetDB(name, logger)
// yields. The returned instance acts as SystemUserID and starts with an empty,
// non-nil Hooks value so callbacks can be assigned directly.
func GetDB(name string, logger *log.Logger) *TableDB {
	d := new(TableDB)
	d.base = db.GetDB(name, logger)
	d.userID = SystemUserID
	d.Hooks = new(Hooks)
	return d
}
// Auth returns a new TableDB that shares d's database handle and Hooks pointer
// but acts on behalf of userID. The receiver itself is left unchanged.
func (d *TableDB) Auth(userID string) *App {
	scoped := *d
	scoped.userID = userID
	return &scoped
}
// SyncSchema applies the DSL schema to the underlying database and refreshes
// the derived metadata.
//
// Steps:
//  1. injectUndergroundRules rewrites the DSL so every table gets an autoIndex
//     column and an "id c10 U" field.
//  2. The rewritten DSL is synced via the base db.DB; a failure is returned.
//  3. If a _Table table exists, one _Table row per table and one _Field row per
//     field are upserted (matched by table name, and by tableId+field name).
//     Rows that already exist keep their original createTime. These metadata
//     reads/writes are best-effort: their errors are ignored so that a
//     metadata problem never aborts an otherwise successful schema sync.
//  4. GlobalCache is reloaded and its error is returned.
func (d *TableDB) SyncSchema(schemaDSL string) error {
	// 1. Auto-inject autoIndex and ensure id c10 for all tables in DSL.
	schemaDSL = injectUndergroundRules(schemaDSL)

	// 2. Sync to actual DB.
	if err := d.base.Sync(schemaDSL); err != nil {
		return err
	}

	// 3. Update _Table and _Field metadata, only when the _Table meta table exists.
	// Run a single dialect-appropriate probe; the SQLite probe used to run
	// unconditionally first, which fails (and is wasted work) on MySQL.
	probeSQL := "SELECT name FROM sqlite_master WHERE type='table' AND name='_Table'"
	var probeArgs []any
	if d.base.Config.Type == "mysql" {
		probeSQL = "SELECT TABLE_NAME name FROM information_schema.TABLES WHERE TABLE_SCHEMA=? AND TABLE_NAME='_Table'"
		probeArgs = []any{d.base.Config.DB}
	}
	res := d.base.Query(probeSQL, probeArgs...)
	if res.Error == nil && res.MapOnR1()["name"] != nil {
		now := time.Now().UnixMilli()
		for _, group := range db.ParseSchema(schemaDSL) {
			for _, table := range group.Tables {
				tid := ""
				tRecord := map[string]any{
					"name":       table.Name,
					"memo":       table.Comment,
					"createTime": now,
				}
				// Lookup errors are ignored (best-effort): a failed lookup is
				// treated the same as "no existing row".
				if existing, _ := d.Table("_Table").List(map[string]any{"name": table.Name}); len(existing) > 0 {
					tid = cast.String(existing[0]["id"])
					tRecord["id"] = tid
					// Keep the original creation time instead of resetting it on every sync.
					if ct := existing[0]["createTime"]; ct != nil {
						tRecord["createTime"] = ct
					}
				}
				// Best-effort: a failed metadata write must not abort the schema sync.
				_ = d.Table("_Table").Set(tRecord)
				if tid == "" {
					// Newly inserted row: read it back to learn its generated id.
					if created, _ := d.Table("_Table").List(map[string]any{"name": table.Name}); len(created) > 0 {
						tid = cast.String(created[0]["id"])
					}
				}
				if tid == "" {
					// Without a table id the _Field rows cannot be linked; skip them.
					continue
				}

				for _, field := range table.Fields {
					fRecord := map[string]any{
						"tableId":    tid,
						"name":       field.Name,
						"type":       field.Type,
						"isIndex":    cast.If(field.Index != "", 1, 0),
						"memo":       field.Comment,
						"createTime": now,
					}
					if existing, _ := d.Table("_Field").List(map[string]any{"tableId": tid, "name": field.Name}); len(existing) > 0 {
						fRecord["id"] = existing[0]["id"]
						// Keep the original creation time instead of resetting it on every sync.
						if ct := existing[0]["createTime"]; ct != nil {
							fRecord["createTime"] = ct
						}
					}
					// Best-effort, as above.
					_ = d.Table("_Field").Set(fRecord)
				}
			}
		}
	}

	// 4. Reload cache.
	return GlobalCache.Load(d)
}
// injectUndergroundRules rewrites a schema DSL before it is synced.
//
// A table is a non-indented line (its header) followed by indented field
// lines. It ends at the next header, at a blank line, at a line starting with
// "#" or "==", or at the end of the input. Two rules are applied:
//   - a table none of whose field lines mentions "autoIndex" gets the line
//     " autoIndex bi AI" appended at its end;
//   - an indented "id ..." line that does not contain "c10" is replaced by
//     " id c10 U" (keeping any trailing "//" comment), while one that does
//     contain "c10" has its first "PK" replaced by "U".
func injectUndergroundRules(dsl string) string {
	const autoIndexLine = " autoIndex bi AI"

	var out []string
	inTable, sawAutoIndex := false, false

	// closeTable finishes the current table, adding autoIndex if it was missing.
	closeTable := func() {
		if inTable && !sawAutoIndex {
			out = append(out, autoIndexLine)
		}
		inTable, sawAutoIndex = false, false
	}

	for _, line := range strings.Split(dsl, "\n") {
		trimmed := strings.TrimSpace(line)
		switch {
		case trimmed == "", strings.HasPrefix(trimmed, "#"), strings.HasPrefix(trimmed, "=="):
			// Separator or comment: ends any open table.
			closeTable()
			out = append(out, line)
		case line[0] != ' ' && line[0] != '\t':
			// Non-indented (and non-empty, since trimmed != ""): a new table header.
			closeTable()
			inTable = true
			out = append(out, line)
		default:
			// Indented field line.
			if strings.Contains(trimmed, "autoIndex") {
				sawAutoIndex = true
			}
			if strings.HasPrefix(trimmed, "id ") {
				switch {
				case !strings.Contains(trimmed, "c10"):
					rewritten := " id c10 U"
					if _, comment, found := strings.Cut(line, "//"); found {
						rewritten += " //" + comment
					}
					line = rewritten
				case strings.Contains(trimmed, "PK"):
					line = strings.Replace(line, "PK", "U", 1)
				}
			}
			out = append(out, line)
		}
	}
	closeTable()

	return strings.Join(out, "\n")
}
// Table returns a *Table handle (built by NewTable) for the table called name,
// bound to this TableDB. It is the entry point for row-level operations.
func (d *TableDB) Table(name string) *Table {
	t := NewTable(name, d)
	return t
}
// Base exposes the underlying apigo.cc/go/db handle, for raw queries that the
// structured API does not cover.
func (d *TableDB) Base() *db.DB {
	underlying := d.base
	return underlying
}
// Query builds SQL from req with BuildQuery, runs it on the base handle and
// returns the rows as maps. A build error or a query error is returned as-is
// with nil rows.
func (d *TableDB) Query(req QueryRequest) ([]map[string]any, error) {
	query, params, buildErr := d.BuildQuery(req)
	if buildErr != nil {
		return nil, buildErr
	}
	result := d.base.Query(query, params...)
	if err := result.Error; err != nil {
		return nil, err
	}
	return result.MapResults(), nil
}
// BuildQuery constructs a SQL SELECT statement and its argument list from req.
//
// Identifiers are validated against GlobalCache and quoted with backticks:
// req.Table and every join table must be known tables, and every req.Select
// entry and the ORDER BY column must be fields of req.Table (joined tables'
// fields cannot be selected or ordered by name here). Join types are limited to
// LEFT (the default), INNER, RIGHT, FULL and CROSS; every join except CROSS
// needs a non-empty On condition.
//
// SECURITY: req.Where and join.On are copied into the SQL verbatim and are NOT
// validated. They must come from trusted code, with any user-supplied values
// passed through req.Args as placeholders. req.Args is returned unchanged.
//
// req.OrderBy accepts a single "field" or "field ASC|DESC" (case-insensitive);
// anything else is rejected rather than partially applied. Limit and Offset
// are emitted only when positive. MySQL and SQLite reject OFFSET without LIMIT,
// so when only Offset is set, the dialect's "no upper bound" LIMIT is added.
func (d *TableDB) BuildQuery(req QueryRequest) (string, []any, error) {
	if GlobalCache.GetTable(req.Table) == nil {
		return "", nil, fmt.Errorf("invalid table: %s", req.Table)
	}

	// Fields of the main table, used to validate both SELECT and ORDER BY.
	validFields := make(map[string]bool)
	for _, f := range GlobalCache.GetValidFields(req.Table) {
		validFields[f] = true
	}

	fields := "*"
	if len(req.Select) > 0 {
		quoted := make([]string, 0, len(req.Select))
		for _, s := range req.Select {
			if !validFields[s] {
				return "", nil, fmt.Errorf("invalid field %s in table %s", s, req.Table)
			}
			quoted = append(quoted, "`"+s+"`")
		}
		fields = strings.Join(quoted, ", ")
	}

	var sql strings.Builder
	fmt.Fprintf(&sql, "SELECT %s FROM `%s`", fields, req.Table)

	for _, join := range req.Joins {
		if GlobalCache.GetTable(join.Table) == nil {
			return "", nil, fmt.Errorf("invalid join table: %s", join.Table)
		}
		jt := strings.ToUpper(join.Type)
		if jt == "" {
			jt = "LEFT"
		}
		switch jt {
		case "LEFT", "INNER", "RIGHT", "FULL", "CROSS":
		default:
			return "", nil, fmt.Errorf("invalid join type: %s", join.Type)
		}
		switch {
		case join.On != "":
			fmt.Fprintf(&sql, " %s JOIN `%s` ON %s", jt, join.Table, join.On)
		case jt == "CROSS":
			// A CROSS JOIN needs no join condition.
			fmt.Fprintf(&sql, " CROSS JOIN `%s`", join.Table)
		default:
			// "JOIN `t` ON " with nothing after ON is always a syntax error.
			return "", nil, fmt.Errorf("missing join condition for table: %s", join.Table)
		}
	}

	args := req.Args
	if req.Where != "" {
		sql.WriteString(" WHERE ")
		sql.WriteString(req.Where)
	}

	if parts := strings.Fields(req.OrderBy); len(parts) > 0 {
		if len(parts) > 2 {
			// Previously extra tokens were silently dropped, changing the ordering.
			return "", nil, fmt.Errorf("invalid order by: %s", req.OrderBy)
		}
		fieldName := parts[0]
		if !validFields[fieldName] {
			return "", nil, fmt.Errorf("invalid order by field: %s", fieldName)
		}
		fmt.Fprintf(&sql, " ORDER BY `%s`", fieldName)
		if len(parts) == 2 {
			dir := strings.ToUpper(parts[1])
			if dir != "ASC" && dir != "DESC" {
				return "", nil, fmt.Errorf("invalid order by direction: %s", parts[1])
			}
			sql.WriteString(" " + dir)
		}
	}

	if req.Limit > 0 {
		fmt.Fprintf(&sql, " LIMIT %d", req.Limit)
	} else if req.Offset > 0 && d != nil && d.base != nil {
		// MySQL and SQLite reject OFFSET without LIMIT; use each one's documented
		// "all remaining rows" value. Other dialects (e.g. PostgreSQL) accept a
		// bare OFFSET, so nothing is added for them.
		switch dbType := d.base.Config.Type; dbType {
		case "mysql":
			sql.WriteString(" LIMIT 18446744073709551615")
		case "", "sqlite", "sqlite3":
			sql.WriteString(" LIMIT -1")
		}
	}
	if req.Offset > 0 {
		fmt.Fprintf(&sql, " OFFSET %d", req.Offset)
	}

	return sql.String(), args, nil
}
// buildWhere converts a map of conditions into a SQL WHERE fragment (without
// the "WHERE" keyword) and its positional arguments.
//
// Each key is either a bare column name, compared with "=", or a column name
// followed by an operator, separated by any whitespace, e.g. "age >=" or
// "name NOT LIKE". Every condition becomes "<column> <operator> ?" and the
// conditions are joined with " AND ". Keys are processed in sorted order so
// the generated SQL is deterministic despite Go's random map iteration order.
// The args are returned in the same order as the placeholders.
//
// SECURITY: column names and operators are written into the SQL verbatim and
// are neither validated nor quoted, so keys must come from trusted code. Only
// the values are passed as bound arguments.
//
// An empty or nil filter yields ("", nil).
func buildWhere(filter map[string]any) (string, []any) {
	if len(filter) == 0 {
		return "", nil
	}

	keys := make([]string, 0, len(filter))
	for k := range filter {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var builder strings.Builder
	args := make([]any, 0, len(filter))
	for i, key := range keys {
		if i > 0 {
			builder.WriteString(" AND ")
		}
		// strings.Fields tolerates repeated spaces and tabs between column and operator.
		column, operator := "", "="
		if parts := strings.Fields(key); len(parts) > 0 {
			column = parts[0]
			if len(parts) > 1 {
				operator = strings.Join(parts[1:], " ")
			}
		}
		builder.WriteString(column)
		builder.WriteString(" ")
		builder.WriteString(operator)
		builder.WriteString(" ?")
		args = append(args, filter[key])
	}
	return builder.String(), args
}