tableDB/table.go

674 lines
15 KiB
Go

package tableDB
import (
"fmt"
"strings"
"time"
"apigo.cc/go/cast"
"apigo.cc/go/db"
"apigo.cc/go/id"
)
// Table provides an AI-friendly interface for interacting with structured data or schema.
//
// A Table is bound to one physical table name and to the identity of the
// caller (userID); the public methods below refuse to run for an empty userID.
type Table struct {
	// Name is the physical table name used in SQL. For handles created from a
	// "<table>:Field" name (see NewTable) this is "_Field".
	Name string
	// userID identifies the caller; SystemUserID bypasses the row-level
	// security checks in checkAuth and appendAuthAndConstraint.
	userID string
	// db is the database handle queries run against (TableDB.base in NewTable).
	db *db.DB
	// app is the owning TableDB, used for hooks and for opening sibling tables.
	app *TableDB
	// constraint holds column=value pairs forced onto every record written by
	// Set and ANDed into the List/Count/CountBy filters; nil when unused.
	constraint map[string]any
}
// NewTable creates a new Table instance bound to app's database and user.
// Handles ":Field" suffix.
//
// A name of the form "<table>:Field" yields a handle on the shared "_Field"
// metadata table, scoped to <table>'s fields: the handle's constraint pins
// tableId to <table>'s id, so Set stamps it onto written records and
// List/Count/CountBy only see that table's fields.
//
// If <table> cannot be resolved (lookup error or no matching _Table row), the
// scope is pinned to an empty tableId instead of being dropped. Filtered reads
// then match nothing (assuming no field row has an empty tableId — confirm),
// rather than exposing every table's field metadata.
func NewTable(name string, app *TableDB) *Table {
	actualName := name
	var constraint map[string]any
	if strings.HasSuffix(name, ":Field") {
		tableName := strings.TrimSuffix(name, ":Field")
		actualName = "_Field"

		tableID := ""
		res := app.base.Query("SELECT id FROM `_Table` WHERE name = ? LIMIT 1", tableName)
		if res.Error == nil {
			if rec := res.MapOnR1(); len(rec) > 0 {
				tableID = cast.String(rec["id"])
			}
		}
		// Always scope a ":Field" handle. A nil constraint would let this
		// handle read and modify field metadata belonging to every table.
		constraint = map[string]any{"tableId": tableID}
	}
	return &Table{
		Name:       actualName,
		userID:     app.userID,
		db:         app.base,
		app:        app,
		constraint: constraint,
	}
}
// checkAuth reports whether the current user may perform action ("read" or
// "write") on the record with the given id.
//
// Rules, in order:
//   - SystemUserID is always allowed.
//   - Non-system users may never write _Policy or _Backup.
//   - Tables without enableRLS (per GlobalCache) are open, except tables whose
//     name starts with "_", which are always policy-checked.
//   - A policy with an empty Condition grants full access.
//   - Otherwise the record must match "creator = userID" (only when the
//     table's cached field metadata declares a "creator" field) OR any policy
//     condition. All of these are evaluated in one SELECT against the record.
//
// It returns nil when access is allowed. If the check query itself fails, it
// returns that error wrapped with context. In every other case it returns a
// "permission denied" error.
func (t *Table) checkAuth(id string, action string) error {
	if t.userID == SystemUserID {
		return nil
	}
	if action == "write" && (t.Name == "_Policy" || t.Name == "_Backup") {
		return fmt.Errorf("permission denied for %s", t.Name)
	}

	tableRec := GlobalCache.GetTable(t.Name)
	enableRLS := tableRec != nil && cast.Bool(tableRec["enableRLS"])
	// Internal ("_"-prefixed) tables are policy-checked even without the RLS flag.
	if !enableRLS && !strings.HasPrefix(t.Name, "_") {
		return nil
	}

	policies := GlobalCache.GetFlatPolicies(t.userID, t.Name, action)
	for _, p := range policies {
		if p.Condition == "" {
			return nil // an unconditional policy grants full access
		}
	}

	// Every remaining policy has a non-empty Condition. Combine them with the
	// creator check into a single "id = ? AND (c1 OR c2 ...)" probe so the
	// decision costs one round trip.
	var conds []string
	var condArgs []any
	if tableRec != nil {
		for _, f := range GlobalCache.GetFields(cast.String(tableRec["id"])) {
			if f.Name == "creator" {
				conds = append(conds, "creator = ?")
				condArgs = append(condArgs, t.userID)
				break
			}
		}
	}
	for _, p := range policies {
		conds = append(conds, "("+p.Condition+")")
		condArgs = append(condArgs, p.ConditionArgs...)
	}
	if len(conds) == 0 {
		return fmt.Errorf("permission denied for %s record %s", t.Name, id)
	}

	checkSQL := fmt.Sprintf("SELECT 1 FROM `%s` WHERE id = ? AND (%s) LIMIT 1", t.Name, strings.Join(conds, " OR "))
	res := t.db.Query(checkSQL, append([]any{id}, condArgs...)...)
	if res.Error != nil {
		// Still fail closed, but keep the cause (e.g. a malformed policy condition).
		return fmt.Errorf("checking %s access to %s record %s: %w", action, t.Name, id, res.Error)
	}
	if len(res.MapOnR1()) == 0 {
		return fmt.Errorf("permission denied for %s record %s", t.Name, id)
	}
	return nil
}
// appendAuthAndConstraint extends a caller-supplied WHERE expression (without
// the WHERE keyword) and its args with this handle's scope and access rules.
//
//  1. Every t.constraint column is ANDed in as "col = ?". The caller's
//     expression is parenthesized first, so an OR inside it cannot escape the
//     constraint. whereStr must therefore be a plain boolean expression (no
//     ORDER BY / LIMIT).
//  2. For non-system users on tables with enableRLS (or any "_"-prefixed
//     table), and unless a "read" policy without Condition grants full access,
//     the result is further ANDed with "(creator = ? OR (cond1) OR ...)". When
//     no such conditions exist it is ANDed with the literal "0" instead, which
//     matches no rows.
//
// The returned args follow the placeholder order of the returned expression.
// The error result is currently always nil.
func (t *Table) appendAuthAndConstraint(whereStr string, args []any) (string, []any, error) {
	if len(t.constraint) > 0 {
		parts := make([]string, 0, len(t.constraint)+1)
		if whereStr != "" {
			// Without parentheses, "a OR b AND tableId = ?" binds the
			// constraint to b only and leaks rows outside the scope.
			parts = append(parts, "("+whereStr+")")
		}
		for k, v := range t.constraint {
			parts = append(parts, k+" = ?")
			args = append(args, v)
		}
		whereStr = strings.Join(parts, " AND ")
	}
	if t.userID == SystemUserID {
		return whereStr, args, nil
	}

	tableRec := GlobalCache.GetTable(t.Name)
	enableRLS := tableRec != nil && cast.Bool(tableRec["enableRLS"])
	if !enableRLS && !strings.HasPrefix(t.Name, "_") {
		return whereStr, args, nil
	}

	policies := GlobalCache.GetFlatPolicies(t.userID, t.Name, "read")
	for _, p := range policies {
		if p.Condition == "" {
			return whereStr, args, nil // unconditional read policy
		}
	}

	// Row-level filter: own rows (when a creator field exists) OR any policy condition.
	var conds []string
	if tableRec != nil {
		for _, f := range GlobalCache.GetFields(cast.String(tableRec["id"])) {
			if f.Name == "creator" {
				conds = append(conds, "creator = ?")
				args = append(args, t.userID)
				break
			}
		}
	}
	for _, p := range policies {
		conds = append(conds, "("+p.Condition+")")
		args = append(args, p.ConditionArgs...)
	}

	authPart := "0" // nothing grants access: match no rows
	if len(conds) > 0 {
		authPart = "(" + strings.Join(conds, " OR ") + ")"
	}
	if whereStr == "" {
		return authPart, args, nil
	}
	return "(" + whereStr + ") AND " + authPart, args, nil
}
// reconstructAndSyncSchema rebuilds a schema description from the _Table and
// _Field metadata rows and applies it with t.db.Sync.
//
// Tables with an empty name or with no _Field rows are skipped. Each remaining
// table produces:
//   - a header line "<name> SD", plus " //<memo>" when a memo is set;
//   - a synthetic " id c10 PK" line when no field is named "id";
//   - one line " <field> <type>[ I][ //<memo>]" per field, where an empty type
//     defaults to v255 and " I" marks isIndex;
//   - a blank separator line.
//
// Newlines in memos are replaced with spaces so each entry stays on one line.
//
// If either metadata query fails, the error is returned and Sync is not called,
// so a transient read failure is never synced as "no tables/fields".
func (t *Table) reconstructAndSyncSchema() error {
	tablesRes := t.db.Query("SELECT * FROM `_Table`")
	if tablesRes.Error != nil {
		return fmt.Errorf("loading _Table metadata: %w", tablesRes.Error)
	}
	fieldsRes := t.db.Query("SELECT * FROM `_Field` ORDER BY tableId")
	if fieldsRes.Error != nil {
		return fmt.Errorf("loading _Field metadata: %w", fieldsRes.Error)
	}

	fieldMap := make(map[string][]map[string]any)
	for _, f := range fieldsRes.MapResults() {
		tid := cast.String(f["tableId"])
		fieldMap[tid] = append(fieldMap[tid], f)
	}

	var sb strings.Builder
	for _, tbl := range tablesRes.MapResults() {
		name := cast.String(tbl["name"])
		if name == "" {
			continue
		}
		tblFields := fieldMap[cast.String(tbl["id"])]
		if len(tblFields) == 0 {
			continue // Skip tables with no fields to avoid SQL errors
		}

		sb.WriteString(name + " SD")
		if memo := cast.String(tbl["memo"]); memo != "" {
			sb.WriteString(" //" + strings.ReplaceAll(memo, "\n", " "))
		}
		sb.WriteString("\n")

		hasID := false
		for _, f := range tblFields {
			if cast.String(f["name"]) == "id" {
				hasID = true
				break
			}
		}
		if !hasID {
			sb.WriteString(" id c10 PK\n")
		}

		for _, f := range tblFields {
			ftype := cast.String(f["type"])
			if ftype == "" {
				ftype = "v255"
			}
			sb.WriteString(" " + cast.String(f["name"]) + " " + ftype)
			if cast.Bool(f["isIndex"]) || cast.Int(f["isIndex"]) == 1 {
				sb.WriteString(" I")
			}
			if fmemo := cast.String(f["memo"]); fmemo != "" {
				sb.WriteString(" //" + strings.ReplaceAll(fmemo, "\n", " "))
			}
			sb.WriteString("\n")
		}
		sb.WriteString("\n")
	}
	return t.db.Sync(sb.String())
}
// Set performs an upsert of one or more records.
//
// Each element of data goes through these steps:
//  1. It is converted to a map with cast.Convert.
//  2. The handle's constraint columns are forced onto it.
//  3. It is passed to the OnUpdatingRow hook (except for _Table/_Field).
//  4. If its id is missing or empty, or the id is not found in the table, it
//     is inserted. Otherwise it must pass checkAuth "write" and is updated,
//     with createTime and creator stripped so they cannot be overwritten.
//
// Records are written one at a time, not in a transaction. The first error
// aborts the remaining records, but earlier ones stay written.
//
// When any _Table/_Field/_Policy record was written, the schema is re-synced
// (except for _Policy) and GlobalCache is reloaded before Set returns. This
// also happens on the error path, so a partially applied batch does not leave
// the cache stale.
func (t *Table) Set(data ...any) error {
	if t.userID == "" {
		return fmt.Errorf("authentication required")
	}
	isMeta := t.Name == "_Table" || t.Name == "_Field" || t.Name == "_Policy"
	metaTouched := false
	// Refresh derived schema/cache state whenever at least one metadata row was
	// written, including when a later record in the batch fails.
	defer func() {
		if !metaTouched {
			return
		}
		if t.Name != "_Policy" {
			_ = t.reconstructAndSyncSchema() // best-effort: the row writes already succeeded
		}
		_ = GlobalCache.Load(t.app)
	}()

	for _, d := range data {
		record := make(map[string]any)
		cast.Convert(&record, d)
		// Constraint columns (e.g. tableId for ":Field" handles) always win.
		for k, v := range t.constraint {
			record[k] = v
		}
		if t.app.hooks.OnUpdatingRow != nil && t.Name != "_Table" && t.Name != "_Field" {
			if err := t.app.hooks.OnUpdatingRow(t.Name, record); err != nil {
				return err
			}
		}

		idStr := ""
		if record["id"] != nil {
			idStr = cast.String(record["id"])
		}
		isInsert := true
		if idStr == "" {
			record["id"] = t.db.NextID(t.Name)
			if cast.String(record["id"]) == "" {
				record["id"] = id.MakeID(10)
			}
		} else {
			res := t.db.Query(fmt.Sprintf("SELECT id FROM `%s` WHERE id = ? LIMIT 1", t.Name), idStr)
			if res.Error != nil {
				return res.Error
			}
			if len(res.MapOnR1()) > 0 {
				if err := t.checkAuth(idStr, "write"); err != nil {
					return err
				}
				isInsert = false
				// Prevent overwriting createTime and creator on update.
				delete(record, "createTime")
				delete(record, "creator")
			}
		}

		var err error
		if isInsert {
			hasCreator, hasCreateTime := false, false
			if tableRec := GlobalCache.GetTable(t.Name); tableRec != nil {
				for _, f := range GlobalCache.GetFields(cast.String(tableRec["id"])) {
					switch f.Name {
					case "creator":
						hasCreator = true
					case "createTime":
						hasCreateTime = true
					}
				}
			}
			if hasCreateTime || t.Name == "_Table" || t.Name == "_Field" {
				record["createTime"] = time.Now().UnixMilli()
			}
			if hasCreator || strings.HasPrefix(t.Name, "_") {
				if t.userID != SystemUserID {
					if t.Name == "_Policy" || t.Name == "_Backup" {
						return fmt.Errorf("permission denied for %s", t.Name)
					}
					record["creator"] = t.userID
				} else if record["creator"] == nil {
					// The system user may insert on behalf of an explicit creator.
					record["creator"] = t.userID
				}
			}
			err = t.db.Insert(t.Name, record).Error
		} else {
			err = t.db.Update(t.Name, record, "id = ?", idStr).Error
		}
		if err != nil {
			return err
		}

		if !isMeta {
			if t.app.hooks.OnUpdatedRows != nil {
				t.app.hooks.OnUpdatedRows(t.Name, 1)
			}
			continue
		}
		metaTouched = true
		switch t.Name {
		case "_Table":
			if isInsert && t.app.hooks.OnCreatedTable != nil {
				t.app.hooks.OnCreatedTable(cast.String(record["name"]), record)
			}
		case "_Field":
			if t.app.hooks.OnUpdatedField != nil {
				t.app.hooks.OnUpdatedField(cast.String(record["tableId"]), cast.String(record["name"]), record)
			}
		}
	}
	return nil
}
// SetField adds or updates one or more fields. If table doesn't exist, it will be created.
//
// When GlobalCache has no entry for the table, the table is first created via
// TableDB.SetTable. A default "id" field (c10, indexed, "Primary Key") is then
// registered on a best-effort basis.
//
// A FieldSchema without an ID is matched by name against the table's existing
// fields, so it updates the existing field instead of duplicating it. A failed
// lookup is returned as an error rather than silently treated as "not found".
// All field records are written with a single fieldTable.Set call.
func (t *Table) SetField(fields ...FieldSchema) error {
	if t.userID == "" {
		return fmt.Errorf("authentication required")
	}

	// 1. Ensure the table exists in _Table.
	createdTable := false
	if GlobalCache.GetTable(t.Name) == nil {
		if err := t.app.SetTable(TableSchema{Name: t.Name}); err != nil {
			return err
		}
		if GlobalCache.GetTable(t.Name) == nil {
			return fmt.Errorf("failed to create table entry for %s", t.Name)
		}
		createdTable = true
	}

	// Open the scoped field handle only after the table row exists. The
	// ":Field" handle presumably resolves tableId when it is constructed (see
	// NewTable) — confirm that app.Table delegates to it.
	fieldTable := t.app.Table(t.Name + ":Field")
	if createdTable {
		// Best-effort: reconstructAndSyncSchema emits "id c10 PK" on its own
		// when no id field is registered, so a failure here is not fatal.
		_ = fieldTable.Set(map[string]any{
			"name":    "id",
			"type":    "c10",
			"isIndex": 1,
			"memo":    "Primary Key",
		})
	}

	// 2. Prepare field records.
	fieldRecords := make([]any, 0, len(fields))
	for _, f := range fields {
		fRecord := map[string]any{
			"name":     f.Name,
			"type":     f.Type,
			"isIndex":  f.IsIndex,
			"memo":     f.Memo,
			"settings": f.Settings,
		}
		if f.ID != "" {
			fRecord["id"] = f.ID
		} else {
			// Reuse the id of an existing field with the same name so Set updates it.
			existing, err := fieldTable.List(map[string]any{"name": f.Name})
			if err != nil {
				return fmt.Errorf("looking up field %q of %s: %w", f.Name, t.Name, err)
			}
			if len(existing) > 0 {
				fRecord["id"] = existing[0]["id"]
			}
		}
		fieldRecords = append(fieldRecords, fRecord)
	}

	// 3. Batch Set fields.
	return fieldTable.Set(fieldRecords...)
}
// RemoveField deletes one or more fields by name.
//
// It returns an error if the caller is unauthenticated or if GlobalCache has
// no entry for the table. Names with no matching field are skipped. Every
// field row matching a name is removed, duplicates included. The first lookup
// or removal error is returned; fields removed before it stay removed.
func (t *Table) RemoveField(names ...string) error {
	if t.userID == "" {
		return fmt.Errorf("authentication required")
	}
	if GlobalCache.GetTable(t.Name) == nil {
		return fmt.Errorf("table %s not found", t.Name)
	}
	fieldTable := t.app.Table(t.Name + ":Field")
	for _, name := range names {
		existing, err := fieldTable.List(map[string]any{"name": name})
		if err != nil {
			// Previously swallowed, which made the removal silently do nothing.
			return fmt.Errorf("looking up field %q of %s: %w", name, t.Name, err)
		}
		for _, rec := range existing {
			if err := fieldTable.Remove(cast.String(rec["id"])); err != nil {
				return err
			}
		}
	}
	return nil
}
// Get retrieves a single record.
//
// The caller must be authenticated and pass checkAuth(id, "read"). The
// handle's constraint columns (e.g. tableId for "<table>:Field" handles) are
// ANDed into the lookup, as List and Count do. A record outside the handle's
// scope is therefore reported as not found. Get returns (nil, nil) when no
// row matches.
func (t *Table) Get(id string) (map[string]any, error) {
	if t.userID == "" {
		return nil, fmt.Errorf("authentication required")
	}
	if err := t.checkAuth(id, "read"); err != nil {
		return nil, err
	}
	query := fmt.Sprintf("SELECT * FROM `%s` WHERE id = ?", t.Name)
	args := []any{id}
	for k, v := range t.constraint {
		query += " AND " + k + " = ?"
		args = append(args, v)
	}
	query += " LIMIT 1"

	res := t.db.Query(query, args...)
	if res.Error != nil {
		return nil, res.Error
	}
	record := res.MapOnR1()
	if len(record) == 0 {
		return nil, nil
	}
	return record, nil
}
// Remove deletes one or more records by ID.
//
// Each id must pass checkAuth(id, "write"). The handle's constraint columns are
// ANDed into the DELETE, and into the pre-delete lookup for metadata tables.
// A "<table>:Field" handle therefore cannot delete another table's fields.
// Ids are processed in order and the first error is returned; earlier
// deletions are not rolled back.
//
// For _Table/_Field/_Policy, every deletion is followed by a schema re-sync
// (skipped for _Policy) and a GlobalCache reload. OnRemovedTable or
// OnRemovedField is then called with the deleted row. For other tables,
// OnRemovedRows is called once per id.
func (t *Table) Remove(ids ...string) error {
	if t.userID == "" {
		return fmt.Errorf("authentication required")
	}
	isMeta := t.Name == "_Table" || t.Name == "_Field" || t.Name == "_Policy"
	for _, recID := range ids {
		if err := t.checkAuth(recID, "write"); err != nil {
			return err
		}

		cond := "id = ?"
		args := []any{recID}
		for k, v := range t.constraint {
			cond += " AND " + k + " = ?"
			args = append(args, v)
		}

		// Metadata rows are read before deletion so the hooks can report them.
		var record map[string]any
		if isMeta {
			res := t.db.Query(fmt.Sprintf("SELECT * FROM `%s` WHERE %s", t.Name, cond), args...)
			record = res.MapOnR1()
		}

		if res := t.db.Delete(t.Name, cond, args...); res.Error != nil {
			return res.Error
		}

		if !isMeta {
			if t.app.hooks.OnRemovedRows != nil {
				t.app.hooks.OnRemovedRows(t.Name, []string{recID})
			}
			continue
		}
		if t.Name != "_Policy" {
			_ = t.reconstructAndSyncSchema()
		}
		_ = GlobalCache.Load(t.app)
		if len(record) == 0 {
			continue
		}
		switch {
		case t.Name == "_Table" && t.app.hooks.OnRemovedTable != nil:
			t.app.hooks.OnRemovedTable(cast.String(record["name"]))
		case t.Name == "_Field" && t.app.hooks.OnRemovedField != nil:
			t.app.hooks.OnRemovedField(cast.String(record["tableId"]), cast.String(record["name"]))
		}
	}
	return nil
}
// List retrieves multiple records.
//
// where may be one of:
//   - nil, for no caller filter;
//   - a SQL boolean expression string whose "?" placeholders are bound to args;
//   - a map[string]any converted by buildWhere, in which case args is replaced
//     by the generated arguments.
//
// Any other type is rejected with an error. Silently ignoring it would return
// every row the user can see instead of the filtered subset the caller asked
// for. The handle's constraint and, for non-system users, row-level-security
// conditions are ANDed in by appendAuthAndConstraint.
func (t *Table) List(where any, args ...any) ([]map[string]any, error) {
	if t.userID == "" {
		return nil, fmt.Errorf("authentication required")
	}
	var whereStr string
	switch v := where.(type) {
	case nil:
		// no caller filter
	case string:
		whereStr = v
	case map[string]any:
		whereStr, args = buildWhere(v)
	default:
		return nil, fmt.Errorf("unsupported where type %T", where)
	}

	whereStr, args, err := t.appendAuthAndConstraint(whereStr, args)
	if err != nil {
		return nil, err
	}
	query := fmt.Sprintf("SELECT * FROM `%s`", t.Name)
	if whereStr != "" {
		query += " WHERE " + whereStr
	}

	res := t.db.Query(query, args...)
	if res.Error != nil {
		return nil, res.Error
	}
	return res.MapResults(), nil
}
// Query performs a structured query on the current table.
//
// The SQL text and its bind arguments come from TableDB.buildQuery for
// t.Name. This method only enforces authentication and executes the result.
// NOTE(review): t.constraint is not passed to buildQuery here; confirm how
// "<table>:Field" handles are scoped on this path.
func (t *Table) Query(req QueryRequest) ([]map[string]any, error) {
	if t.userID == "" {
		return nil, fmt.Errorf("authentication required")
	}
	stmt, stmtArgs, buildErr := t.app.buildQuery(t.Name, req)
	if buildErr != nil {
		return nil, buildErr
	}
	result := t.db.Query(stmt, stmtArgs...)
	if qErr := result.Error; qErr != nil {
		return nil, qErr
	}
	return result.MapResults(), nil
}
// Count returns the number of records.
//
// where accepts the same forms as List: nil, a SQL boolean expression string
// bound to args, or a map[string]any converted by buildWhere. Any other type
// is rejected rather than silently counting every visible row. The handle's
// constraint and row-level-security conditions are ANDed in by
// appendAuthAndConstraint.
func (t *Table) Count(where any, args ...any) (int64, error) {
	if t.userID == "" {
		return 0, fmt.Errorf("authentication required")
	}
	var whereStr string
	switch v := where.(type) {
	case nil:
		// no caller filter
	case string:
		whereStr = v
	case map[string]any:
		whereStr, args = buildWhere(v)
	default:
		return 0, fmt.Errorf("unsupported where type %T", where)
	}

	whereStr, args, err := t.appendAuthAndConstraint(whereStr, args)
	if err != nil {
		return 0, err
	}
	query := fmt.Sprintf("SELECT COUNT(*) FROM `%s`", t.Name)
	if whereStr != "" {
		query += " WHERE " + whereStr
	}

	res := t.db.Query(query, args...)
	if res.Error != nil {
		return 0, res.Error
	}
	return res.IntOnR1C1(), nil
}
// CountBy returns counts grouped by a field.
//
// The result maps each distinct value of field, as returned by the driver, to
// the number of visible rows holding it. The handle's constraint and
// row-level-security conditions are applied via appendAuthAndConstraint.
//
// field is an identifier, not a value, so it cannot be bound as a parameter.
// Embedded backticks are doubled so the name cannot break out of its quoting
// and inject SQL.
func (t *Table) CountBy(field string) (map[any]int64, error) {
	if t.userID == "" {
		return nil, fmt.Errorf("authentication required")
	}
	quoted := "`" + strings.ReplaceAll(field, "`", "``") + "`"

	whereStr, args, err := t.appendAuthAndConstraint("", nil)
	if err != nil {
		return nil, err
	}
	query := fmt.Sprintf("SELECT %s, COUNT(*) AS cnt FROM `%s`", quoted, t.Name)
	if whereStr != "" {
		query += " WHERE " + whereStr
	}
	query += " GROUP BY " + quoted

	res := t.db.Query(query, args...)
	if res.Error != nil {
		return nil, res.Error
	}
	rows := res.MapResults()
	result := make(map[any]int64, len(rows))
	for _, row := range rows {
		result[row[field]] = cast.Int64(row["cnt"])
	}
	return result, nil
}
// Fields returns field metadata.
//
// The fields are read from GlobalCache, keyed by the id of the table record
// cached for t.Name. It fails when the caller is unauthenticated or when the
// cache holds no record for the table.
func (t *Table) Fields() ([]FieldSchema, error) {
	if t.userID == "" {
		return nil, fmt.Errorf("authentication required")
	}
	cached := GlobalCache.GetTable(t.Name)
	if cached == nil {
		return nil, fmt.Errorf("table metadata not found in cache: %s", t.Name)
	}
	return GlobalCache.GetFields(cast.String(cached["id"])), nil
}