// Package tableDB layers table metadata, per-user contexts and access-policy
// management on top of apigo.cc/go/db.
package tableDB

import (
	"fmt"
	"strings"
	"time"

	"apigo.cc/go/cast"
	"apigo.cc/go/db"
	"apigo.cc/go/log"
)

// SystemUserID is the user ID that SetPolicy and GetRawDB treat as privileged:
// it bypasses the policy-type checks and is the only ID allowed to obtain the
// raw *db.DB.
const SystemUserID = "_system"

// SystemSchema is the schema DSL describing the built-in metadata tables
// _Table, _Field and _Policy. syncSchema prepends it to any DSL that does not
// already contain "_Table".
//
// NOTE(review): the literal appears whitespace-collapsed in this view; if the
// DSL parser is line-oriented, confirm the line breaks against the canonical file.
const SystemSchema = ` == System == _Table SD // 核心表:存储所有表的元数据 id c10 PK name v64 U // 表名 memo t // 备注 enableRLS b // 是否开启行级安全 settings o // 设置 createTime bi // 创建时间 creator v64 // 创建者 _Field SD // 核心表:存储所有字段的元数据 id c10 PK tableId c10 I // 所属表 ID name v64 // 字段名 type v32 // 字段类型 isIndex b // 是否索引 memo t // 备注 settings o // 设置 createTime bi // 创建时间 creator v64 // 创建者 _Policy SD // 核心表:访问策略 id c10 PK userID c10 I // 策略拥有者 type v32 // 策略类型 (inherit, table) targets o // 作用目标数组 (inherit: userIDs, table: tableNames) action v16 // 动作 (read, write, full) condition v1024 // SQL WHERE 条件片段 conditionArgs o // 条件对应的参数数组 createTime bi // 创建时间 creator v64 // 创建者 `

// Hooks groups callbacks for table, field and row events.
//
// The call sites are not visible in this part of the file; the field names
// suggest each callback fires around the corresponding create/update/remove
// operation — verify against the Table implementation.
type Hooks struct {
	OnCreatedTable func(tableName string, record map[string]any)
	OnRemovedTable func(tableName string)
	OnUpdatedField func(tableId, fieldName string, record map[string]any)
	OnRemovedField func(tableId, fieldName string)
	// OnUpdatingRow is the only callback that returns an error.
	OnUpdatingRow func(tableName string, row map[string]any) error
	OnUpdatedRows func(tableName string, count int)
	OnRemovedRows func(tableName string, ids []string)
}

// TableDBUnauthorized is a database handle that has no user attached yet.
// Its exported Hooks field can be populated before calling Auth, which shares
// the same *Hooks with every TableDB derived from it.
type TableDBUnauthorized struct {
	base  *db.DB
	Hooks *Hooks
}

// TableDB wraps the base go/db implementation to provide high-level abstractions.
// Every instance is bound to the userID it was created with via Auth.
type TableDB struct {
	base   *db.DB
	userID string
	hooks  *Hooks
}

// GetDB retrieves a configured database instance. Must call Auth() before use.
//
// name and logger are passed straight to db.GetDB; the returned handle starts
// with an empty, non-nil Hooks value.
func GetDB(name string, logger *log.Logger) *TableDBUnauthorized {
	baseDB := db.GetDB(name, logger)
	return &TableDBUnauthorized{
		base:  baseDB,
		Hooks: &Hooks{},
	}
}

// Auth creates a new instance with the specified userID context.
// The new TableDB shares d's underlying *db.DB and *Hooks.
func (d *TableDBUnauthorized) Auth(userID string) *TableDB {
	return &TableDB{
		base:   d.base,
		userID: userID,
		hooks:  d.Hooks,
	}
}

// Auth creates a new instance with the specified userID context from an existing authorized instance.
func (d *TableDB) Auth(userID string) *TableDB {
	// Same underlying DB and hooks, different acting user.
	rebound := TableDB{
		base:  d.base,
		hooks: d.hooks,
	}
	rebound.userID = userID
	return &rebound
}

// Tables lists the rows of the _Table metadata table, as returned by
// Table("_Table").List for the current user, converted to TableSchema values.
func (d *TableDB) Tables() ([]TableSchema, error) {
	rows, err := d.Table("_Table").List(nil)
	if err != nil {
		return nil, err
	}

	var schemas []TableSchema
	cast.Convert(&schemas, rows)
	return schemas, nil
}

// SetTable writes the table metadata in schema to _Table. The "id" key is
// only included when schema.ID is non-empty.
func (d *TableDB) SetTable(schema TableSchema) error {
	fields := make(map[string]any, 5)
	fields["name"] = schema.Name
	fields["memo"] = schema.Memo
	fields["enableRLS"] = schema.EnableRLS
	fields["settings"] = schema.Settings
	if id := schema.ID; id != "" {
		fields["id"] = id
	}

	return d.Table("_Table").Set(fields)
}

// RemoveTable removes the _Table row for the named table. The row ID is looked
// up in GlobalCache; an error is returned when the cache has no such table.
func (d *TableDB) RemoveTable(name string) error {
	if cached := GlobalCache.GetTable(name); cached != nil {
		return d.Table("_Table").Remove(cast.String(cached["id"]))
	}
	return fmt.Errorf("table %s not found", name)
}

// SetPolicy updates or creates an access policy.
func (d *TableDB) SetPolicy(policy PolicySchema) error { if d.userID != SystemUserID { if policy.Type == "inherit" { return fmt.Errorf("only system user can set inherit policy") } if policy.Type == "table" { // Only users with permission to the table can set type=table policy for _, targetTable := range policy.Targets { // Check if current user has 'full' access to the target table pols := GlobalCache.GetFlatPolicies(d.userID, targetTable, "full") hasFullAccess := false for _, p := range pols { if p.Condition == "" { hasFullAccess = true break } } if !hasFullAccess { return fmt.Errorf("permission denied to set policy for table %s", targetTable) } } } } record := map[string]any{ "userID": policy.UserID, "type": policy.Type, "targets": policy.Targets, "action": policy.Action, "condition": policy.Condition, "conditionArgs": policy.ConditionArgs, "creator": d.userID, } if policy.ID != "" { record["id"] = policy.ID } return d.Auth(SystemUserID).Table("_Policy").Set(record) } // ListPolicy retrieves policies based on filter. func (d *TableDB) ListPolicy(filter map[string]any) ([]PolicySchema, error) { res, err := d.Table("_Policy").List(filter) if err != nil { return nil, err } var policies []PolicySchema cast.Convert(&policies, res) return policies, nil } // syncSchema automatically applies the DSL schema to the underlying database. func (d *TableDB) syncSchema(schemaDSL string) error { finalDSL := schemaDSL if !strings.Contains(schemaDSL, "_Table") { finalDSL = SystemSchema + "\n" + schemaDSL } err := d.base.Sync(finalDSL) if err != nil { return err } // 2. Update _Table and _Field metadata res := d.base.Query("SELECT name FROM sqlite_master WHERE type='table' AND name='_Table'") if d.base.Config.Type == "mysql" { res = d.base.Query("SELECT TABLE_NAME name FROM information_schema.TABLES WHERE TABLE_SCHEMA=? 
AND TABLE_NAME='_Table'", d.base.Config.DB) } if res.Error == nil && res.MapOnR1()["name"] != nil { groups := db.ParseSchema(finalDSL) for _, group := range groups { for _, table := range group.Tables { // Upsert _Table tRecord := map[string]any{ "name": table.Name, "memo": table.Comment, "createTime": time.Now().UnixMilli(), } existingTable, _ := d.Table("_Table").List(map[string]any{"name": table.Name}) var tid string if len(existingTable) > 0 { tid = cast.String(existingTable[0]["id"]) tRecord["id"] = tid } _ = d.Table("_Table").Set(tRecord) if tid == "" { newTable, _ := d.Table("_Table").List(map[string]any{"name": table.Name}) if len(newTable) > 0 { tid = cast.String(newTable[0]["id"]) } } if tid != "" { // Update _Field for _, field := range table.Fields { fRecord := map[string]any{ "tableId": tid, "name": field.Name, "type": field.Type, "isIndex": cast.If(field.Index != "", 1, 0), "memo": field.Comment, "createTime": time.Now().UnixMilli(), } existingField, _ := d.Table("_Field").List(map[string]any{"tableId": tid, "name": field.Name}) if len(existingField) > 0 { fRecord["id"] = existingField[0]["id"] } _ = d.Table("_Field").Set(fRecord) } } } } } // 3. Reload cache return GlobalCache.Load(d) } // Table returns an AI-friendly interface for multi-dimensional operations on a specific table. func (d *TableDB) Table(name string) *Table { return NewTable(name, d) } // GetRawDB returns the underlying apigo.cc/go/db.DB. Only allowed for SystemUserID. func (d *TableDB) GetRawDB() (*db.DB, error) { if d.userID != SystemUserID { return nil, fmt.Errorf("permission denied for GetRawDB") } return d.base, nil } // buildQuery constructs a SQL query from a QueryRequest with strict identifier validation and auth filtering. 
func (d *TableDB) buildQuery(tableName string, req QueryRequest) (string, []any, error) { if GlobalCache.GetTable(tableName) == nil { return "", nil, fmt.Errorf("invalid table: %s", tableName) } fields := "*" if len(req.Select) > 0 { var validatedSelect []string for _, s := range req.Select { if !GlobalCache.IsValidField(tableName, s) { return "", nil, fmt.Errorf("invalid field %s in table %s", s, tableName) } validatedSelect = append(validatedSelect, "`"+s+"`") } fields = strings.Join(validatedSelect, ", ") } var sql strings.Builder fmt.Fprintf(&sql, "SELECT %s FROM `%s` ", fields, tableName) for _, join := range req.Joins { if GlobalCache.GetTable(join.Table) == nil { return "", nil, fmt.Errorf("invalid join table: %s", join.Table) } joinType := join.Type if joinType == "" { joinType = "LEFT" } jt := strings.ToUpper(joinType) if jt != "LEFT" && jt != "INNER" && jt != "RIGHT" && jt != "FULL" && jt != "CROSS" { return "", nil, fmt.Errorf("invalid join type: %s", joinType) } fmt.Fprintf(&sql, "%s JOIN `%s` ON %s ", jt, join.Table, join.On) } args := req.Args whereStr := req.Where // Apply auth filtering for the main table dummyTable := &Table{Name: tableName, userID: d.userID, db: d.base} var err error whereStr, args, err = dummyTable.appendAuthAndConstraint(whereStr, args) if err != nil { return "", nil, err } if whereStr != "" { sql.WriteString(" WHERE ") sql.WriteString(whereStr) } if req.OrderBy != "" { parts := strings.Fields(req.OrderBy) if len(parts) > 0 { fieldName := parts[0] if !GlobalCache.IsValidField(tableName, fieldName) { return "", nil, fmt.Errorf("invalid order by field: %s", fieldName) } direction := "" if len(parts) > 1 { dir := strings.ToUpper(parts[1]) if dir == "ASC" || dir == "DESC" { direction = " " + dir } else { return "", nil, fmt.Errorf("invalid order by direction: %s", parts[1]) } } fmt.Fprintf(&sql, " ORDER BY `%s` %s", fieldName, direction) } } if req.Limit > 0 { fmt.Fprintf(&sql, " LIMIT %d", req.Limit) } if req.Offset > 0 { 
fmt.Fprintf(&sql, " OFFSET %d", req.Offset) } return sql.String(), args, nil } // buildWhere is a helper to convert a map of conditions into a SQL WHERE clause and args. func buildWhere(filter map[string]any) (string, []any) { if len(filter) == 0 { return "", nil } var builder strings.Builder var args []any first := true for k, v := range filter { if !first { builder.WriteString(" AND ") } first = false k = strings.TrimSpace(k) operator := "=" parts := strings.Split(k, " ") if len(parts) > 1 { k = parts[0] operator = strings.Join(parts[1:], " ") } builder.WriteString(k) builder.WriteString(" ") builder.WriteString(operator) builder.WriteString(" ?") args = append(args, v) } return builder.String(), args }