381 lines
10 KiB
Go
381 lines
10 KiB
Go
package tableDB
|
|
|
|
import (
	"fmt"
	"sort"
	"strings"
	"time"

	"apigo.cc/go/cast"
	"apigo.cc/go/db"
	"apigo.cc/go/log"
)
|
|
|
|
// SystemUserID is the user ID treated as the privileged system identity:
// SetPolicy skips its permission checks for it, GetRawDB is only allowed for
// it, and SetPolicy performs its final write under it.
const SystemUserID = "_system"
|
|
|
|
// SystemSchema is the schema DSL describing the built-in metadata tables:
// _Table (metadata of every table), _Field (metadata of every field) and
// _Policy (access policies). syncSchema prepends it to any caller-supplied
// DSL that does not already mention "_Table".
//
// The trailing "//" notes inside the literal are part of the DSL text itself,
// not Go comments, so they are left exactly as written.
const SystemSchema = `
|
|
== System ==
|
|
_Table SD // 核心表:存储所有表的元数据
|
|
id c10 PK
|
|
name v64 U // 表名
|
|
memo t // 备注
|
|
enableRLS b // 是否开启行级安全
|
|
settings o // 设置
|
|
createTime bi // 创建时间
|
|
creator v64 // 创建者
|
|
|
|
_Field SD // 核心表:存储所有字段的元数据
|
|
id c10 PK
|
|
tableId c10 I // 所属表 ID
|
|
name v64 // 字段名
|
|
type v32 // 字段类型
|
|
isIndex b // 是否索引
|
|
memo t // 备注
|
|
settings o // 设置
|
|
createTime bi // 创建时间
|
|
creator v64 // 创建者
|
|
|
|
_Policy SD // 核心表:访问策略
|
|
id c10 PK
|
|
userID c10 I // 策略拥有者
|
|
type v32 // 策略类型 (inherit, table)
|
|
targets o // 作用目标数组 (inherit: userIDs, table: tableNames)
|
|
action v16 // 动作 (read, write, full)
|
|
condition v1024 // SQL WHERE 条件片段
|
|
conditionArgs o // 条件对应的参数数组
|
|
createTime bi // 创建时间
|
|
creator v64 // 创建者
|
|
`
|
|
|
|
// Hooks is a set of optional callbacks shared by a TableDBUnauthorized and
// every TableDB derived from it. All fields are nil by default.
//
// NOTE(review): the call sites are not in this part of the file; the
// per-field descriptions below are inferred from the callback names and
// signatures and should be confirmed against the code that fires them.
type Hooks struct {
	// Table-level callbacks.

	// OnCreatedTable receives the table name and its metadata record.
	OnCreatedTable func(tableName string, record map[string]any)
	// OnRemovedTable receives the name of the removed table.
	OnRemovedTable func(tableName string)

	// Field-level callbacks.

	// OnUpdatedField receives the owning table ID, the field name and the
	// field's metadata record.
	OnUpdatedField func(tableId, fieldName string, record map[string]any)
	// OnRemovedField receives the owning table ID and the field name.
	OnRemovedField func(tableId, fieldName string)

	// Row-level callbacks.

	// OnUpdatingRow receives a row and may return an error; presumably a
	// non-nil error vetoes the write (confirm at the call site).
	OnUpdatingRow func(tableName string, row map[string]any) error
	// OnUpdatedRows receives the table name and a row count.
	OnUpdatedRows func(tableName string, count int)
	// OnRemovedRows receives the table name and the affected row IDs.
	OnRemovedRows func(tableName string, ids []string)
}
|
|
|
|
type TableDBUnauthorized struct {
|
|
base *db.DB
|
|
Hooks *Hooks
|
|
}
|
|
|
|
// TableDB wraps the base go/db implementation to provide high-level abstractions.
|
|
type TableDB struct {
|
|
base *db.DB
|
|
userID string
|
|
hooks *Hooks
|
|
}
|
|
|
|
// GetDB retrieves a configured database instance. Must call Auth() before use.
|
|
func GetDB(name string, logger *log.Logger) *TableDBUnauthorized {
|
|
baseDB := db.GetDB(name, logger)
|
|
return &TableDBUnauthorized{
|
|
base: baseDB,
|
|
Hooks: &Hooks{},
|
|
}
|
|
}
|
|
|
|
// Auth creates a new instance with the specified userID context.
|
|
func (d *TableDBUnauthorized) Auth(userID string) *TableDB {
|
|
return &TableDB{
|
|
base: d.base,
|
|
userID: userID,
|
|
hooks: d.Hooks,
|
|
}
|
|
}
|
|
|
|
// Auth creates a new instance with the specified userID context from an existing authorized instance.
|
|
func (d *TableDB) Auth(userID string) *TableDB {
|
|
return &TableDB{
|
|
base: d.base,
|
|
userID: userID,
|
|
hooks: d.hooks,
|
|
}
|
|
}
|
|
|
|
// Tables returns a list of tables that the current user has access to.
|
|
func (d *TableDB) Tables() ([]TableSchema, error) {
|
|
res, err := d.Table("_Table").List(nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var tables []TableSchema
|
|
cast.Convert(&tables, res)
|
|
return tables, nil
|
|
}
|
|
|
|
// SetTable updates or creates the table metadata.
|
|
func (d *TableDB) SetTable(schema TableSchema) error {
|
|
record := map[string]any{
|
|
"name": schema.Name,
|
|
"memo": schema.Memo,
|
|
"enableRLS": schema.EnableRLS,
|
|
"settings": schema.Settings,
|
|
}
|
|
if schema.ID != "" {
|
|
record["id"] = schema.ID
|
|
}
|
|
return d.Table("_Table").Set(record)
|
|
}
|
|
|
|
// RemoveTable deletes a table.
|
|
func (d *TableDB) RemoveTable(name string) error {
|
|
tableRec := GlobalCache.GetTable(name)
|
|
if tableRec == nil {
|
|
return fmt.Errorf("table %s not found", name)
|
|
}
|
|
return d.Table("_Table").Remove(cast.String(tableRec["id"]))
|
|
}
|
|
|
|
// SetPolicy updates or creates an access policy.
|
|
func (d *TableDB) SetPolicy(policy PolicySchema) error {
|
|
if d.userID != SystemUserID {
|
|
if policy.Type == "inherit" {
|
|
return fmt.Errorf("only system user can set inherit policy")
|
|
}
|
|
|
|
if policy.Type == "table" {
|
|
// Only users with permission to the table can set type=table policy
|
|
for _, targetTable := range policy.Targets {
|
|
// Check if current user has 'full' access to the target table
|
|
pols := GlobalCache.GetFlatPolicies(d.userID, targetTable, "full")
|
|
hasFullAccess := false
|
|
for _, p := range pols {
|
|
if p.Condition == "" {
|
|
hasFullAccess = true
|
|
break
|
|
}
|
|
}
|
|
if !hasFullAccess {
|
|
return fmt.Errorf("permission denied to set policy for table %s", targetTable)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
record := map[string]any{
|
|
"userID": policy.UserID,
|
|
"type": policy.Type,
|
|
"targets": policy.Targets,
|
|
"action": policy.Action,
|
|
"condition": policy.Condition,
|
|
"conditionArgs": policy.ConditionArgs,
|
|
"creator": d.userID,
|
|
}
|
|
if policy.ID != "" {
|
|
record["id"] = policy.ID
|
|
}
|
|
return d.Auth(SystemUserID).Table("_Policy").Set(record)
|
|
}
|
|
|
|
// ListPolicy retrieves policies based on filter.
|
|
func (d *TableDB) ListPolicy(filter map[string]any) ([]PolicySchema, error) {
|
|
res, err := d.Table("_Policy").List(filter)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var policies []PolicySchema
|
|
cast.Convert(&policies, res)
|
|
return policies, nil
|
|
}
|
|
|
|
// syncSchema automatically applies the DSL schema to the underlying database.
|
|
func (d *TableDB) syncSchema(schemaDSL string) error {
|
|
finalDSL := schemaDSL
|
|
if !strings.Contains(schemaDSL, "_Table") {
|
|
finalDSL = SystemSchema + "\n" + schemaDSL
|
|
}
|
|
|
|
err := d.base.Sync(finalDSL)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 2. Update _Table and _Field metadata
|
|
res := d.base.Query("SELECT name FROM sqlite_master WHERE type='table' AND name='_Table'")
|
|
if d.base.Config.Type == "mysql" {
|
|
res = d.base.Query("SELECT TABLE_NAME name FROM information_schema.TABLES WHERE TABLE_SCHEMA=? AND TABLE_NAME='_Table'", d.base.Config.DB)
|
|
}
|
|
if res.Error == nil && res.MapOnR1()["name"] != nil {
|
|
groups := db.ParseSchema(finalDSL)
|
|
for _, group := range groups {
|
|
for _, table := range group.Tables {
|
|
// Upsert _Table
|
|
tRecord := map[string]any{
|
|
"name": table.Name,
|
|
"memo": table.Comment,
|
|
"createTime": time.Now().UnixMilli(),
|
|
}
|
|
existingTable, _ := d.Table("_Table").List(map[string]any{"name": table.Name})
|
|
var tid string
|
|
if len(existingTable) > 0 {
|
|
tid = cast.String(existingTable[0]["id"])
|
|
tRecord["id"] = tid
|
|
}
|
|
_ = d.Table("_Table").Set(tRecord)
|
|
|
|
if tid == "" {
|
|
newTable, _ := d.Table("_Table").List(map[string]any{"name": table.Name})
|
|
if len(newTable) > 0 {
|
|
tid = cast.String(newTable[0]["id"])
|
|
}
|
|
}
|
|
|
|
if tid != "" {
|
|
// Update _Field
|
|
for _, field := range table.Fields {
|
|
fRecord := map[string]any{
|
|
"tableId": tid,
|
|
"name": field.Name,
|
|
"type": field.Type,
|
|
"isIndex": cast.If(field.Index != "", 1, 0),
|
|
"memo": field.Comment,
|
|
"createTime": time.Now().UnixMilli(),
|
|
}
|
|
existingField, _ := d.Table("_Field").List(map[string]any{"tableId": tid, "name": field.Name})
|
|
if len(existingField) > 0 {
|
|
fRecord["id"] = existingField[0]["id"]
|
|
}
|
|
_ = d.Table("_Field").Set(fRecord)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// 3. Reload cache
|
|
return GlobalCache.Load(d)
|
|
}
|
|
|
|
// Table returns an AI-friendly interface for multi-dimensional operations on a specific table.
|
|
func (d *TableDB) Table(name string) *Table {
|
|
return NewTable(name, d)
|
|
}
|
|
|
|
// GetRawDB returns the underlying apigo.cc/go/db.DB. Only allowed for SystemUserID.
|
|
func (d *TableDB) GetRawDB() (*db.DB, error) {
|
|
if d.userID != SystemUserID {
|
|
return nil, fmt.Errorf("permission denied for GetRawDB")
|
|
}
|
|
return d.base, nil
|
|
}
|
|
|
|
// buildQuery constructs a SQL query from a QueryRequest with strict identifier validation and auth filtering.
|
|
func (d *TableDB) buildQuery(tableName string, req QueryRequest) (string, []any, error) {
|
|
if GlobalCache.GetTable(tableName) == nil {
|
|
return "", nil, fmt.Errorf("invalid table: %s", tableName)
|
|
}
|
|
|
|
fields := "*"
|
|
if len(req.Select) > 0 {
|
|
var validatedSelect []string
|
|
for _, s := range req.Select {
|
|
if !GlobalCache.IsValidField(tableName, s) {
|
|
return "", nil, fmt.Errorf("invalid field %s in table %s", s, tableName)
|
|
}
|
|
validatedSelect = append(validatedSelect, "`"+s+"`")
|
|
}
|
|
fields = strings.Join(validatedSelect, ", ")
|
|
}
|
|
|
|
var sql strings.Builder
|
|
fmt.Fprintf(&sql, "SELECT %s FROM `%s` ", fields, tableName)
|
|
|
|
for _, join := range req.Joins {
|
|
if GlobalCache.GetTable(join.Table) == nil {
|
|
return "", nil, fmt.Errorf("invalid join table: %s", join.Table)
|
|
}
|
|
joinType := join.Type
|
|
if joinType == "" {
|
|
joinType = "LEFT"
|
|
}
|
|
jt := strings.ToUpper(joinType)
|
|
if jt != "LEFT" && jt != "INNER" && jt != "RIGHT" && jt != "FULL" && jt != "CROSS" {
|
|
return "", nil, fmt.Errorf("invalid join type: %s", joinType)
|
|
}
|
|
fmt.Fprintf(&sql, "%s JOIN `%s` ON %s ", jt, join.Table, join.On)
|
|
}
|
|
|
|
args := req.Args
|
|
whereStr := req.Where
|
|
|
|
// Apply auth filtering for the main table
|
|
dummyTable := &Table{Name: tableName, userID: d.userID, db: d.base}
|
|
var err error
|
|
whereStr, args, err = dummyTable.appendAuthAndConstraint(whereStr, args)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
|
|
if whereStr != "" {
|
|
sql.WriteString(" WHERE ")
|
|
sql.WriteString(whereStr)
|
|
}
|
|
|
|
if req.OrderBy != "" {
|
|
parts := strings.Fields(req.OrderBy)
|
|
if len(parts) > 0 {
|
|
fieldName := parts[0]
|
|
if !GlobalCache.IsValidField(tableName, fieldName) {
|
|
return "", nil, fmt.Errorf("invalid order by field: %s", fieldName)
|
|
}
|
|
|
|
direction := ""
|
|
if len(parts) > 1 {
|
|
dir := strings.ToUpper(parts[1])
|
|
if dir == "ASC" || dir == "DESC" {
|
|
direction = " " + dir
|
|
} else {
|
|
return "", nil, fmt.Errorf("invalid order by direction: %s", parts[1])
|
|
}
|
|
}
|
|
fmt.Fprintf(&sql, " ORDER BY `%s` %s", fieldName, direction)
|
|
}
|
|
}
|
|
|
|
if req.Limit > 0 {
|
|
fmt.Fprintf(&sql, " LIMIT %d", req.Limit)
|
|
}
|
|
if req.Offset > 0 {
|
|
fmt.Fprintf(&sql, " OFFSET %d", req.Offset)
|
|
}
|
|
|
|
return sql.String(), args, nil
|
|
}
|
|
|
|
// buildWhere converts a filter map into a SQL WHERE fragment (without the
// "WHERE" keyword) and the matching positional arguments.
//
// Each key is "<column>" or "<column> <operator>"; a key without an operator
// uses "=". Conditions are joined with " AND " and each contributes one "?"
// placeholder whose argument is the map value. An empty or nil filter yields
// ("", nil).
//
// Conditions are emitted in ascending key order, so the same filter always
// produces the same SQL text and argument order. Ranging over the map
// directly would make the output order random from call to call.
//
// NOTE(review): column and operator text is written into the SQL unquoted
// and unvalidated; filter keys must come from trusted code, never from user
// input.
func buildWhere(filter map[string]any) (string, []any) {
	if len(filter) == 0 {
		return "", nil
	}

	keys := make([]string, 0, len(filter))
	for key := range filter {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	var builder strings.Builder
	args := make([]any, 0, len(keys))
	for i, key := range keys {
		if i > 0 {
			builder.WriteString(" AND ")
		}

		// Split once at the first space: the column, then an optional
		// operator that may itself contain spaces (e.g. "NOT LIKE").
		column, operator, hasOperator := strings.Cut(strings.TrimSpace(key), " ")
		if hasOperator {
			operator = strings.TrimSpace(operator)
		} else {
			operator = "="
		}

		builder.WriteString(column)
		builder.WriteString(" ")
		builder.WriteString(operator)
		builder.WriteString(" ?")
		args = append(args, filter[key])
	}

	return builder.String(), args
}
|