tableDB/db.go

386 lines
10 KiB
Go

package tableDB
import (
"fmt"
"strings"
"apigo.cc/go/cast"
"apigo.cc/go/db"
"apigo.cc/go/log"
)
const SystemUserID = "_system"
const SystemSchema = `
== System ==
_Table SD // 核心表:存储所有表的元数据
id c10 PK
name v64 U // 表名
memo t // 备注
enableRLS b // 是否开启行级安全
settings o // 设置
isSecret b // 是否敏感表(不被索引)
createTime bi I // 创建时间
creator v64 // 创建者
updateTime bi I // 更新时间
updater v64 // 更新者
_Field SD // 核心表:存储所有字段的元数据
id c10 PK
tableId c10 I // 所属表 ID
name v64 // 字段名
type v32 // 字段类型
isIndex b // 是否索引
memo t // 备注
settings o // 设置
createTime bi I // 创建时间
creator v64 // 创建者
updateTime bi I // 更新时间
updater v64 // 更新者
_Policy SD // 核心表:访问策略
id c10 PK
userID c10 I // 策略拥有者
type v32 // 策略类型 (inherit, table)
targets o // 作用目标数组 (inherit: userIDs, table: tableNames)
action v16 // 动作 (read, write, full)
condition v1024 // SQL WHERE 条件片段
conditionArgs o // 条件对应的参数数组
createTime bi I // 创建时间
creator v64 // 创建者
updateTime bi I // 更新时间
updater v64 // 更新者
`
type Hooks struct {
OnCreatedTable func(table *TableSchema)
OnRemovedTable func(table *TableSchema)
OnUpdatedField func(table *TableSchema, field *FieldSchema)
OnRemovedField func(table *TableSchema, fieldName string)
OnUpdatingRow func(row map[string]any, table *TableSchema, fields []FieldSchema) error
OnUpdatedRows func(rows []map[string]any, table *TableSchema, fields []FieldSchema)
OnRemovedRows func(ids []string, table *TableSchema, fields []FieldSchema)
}
type TableDBUnauthorized struct {
base *db.DB
Hooks *Hooks
}
// TableDB wraps the base go/db implementation to provide high-level abstractions.
type TableDB struct {
base *db.DB
userID string
hooks *Hooks
}
// GetDB retrieves a configured database instance. Must call Auth() before use.
func GetDB(name string, logger *log.Logger, redis string) *TableDBUnauthorized {
baseDB := db.GetDB(name, logger)
if redis != "" {
baseDB.Config.Redis = redis
}
return &TableDBUnauthorized{
base: baseDB,
Hooks: &Hooks{},
}
}
// Bootstrap initializes the system tables metadata required for tableDB to function.
func (d *TableDBUnauthorized) Bootstrap() error {
err := d.base.Sync(SystemSchema)
if err != nil {
return err
}
sys := d.Auth(SystemUserID)
groups := db.ParseSchema(SystemSchema)
for _, group := range groups {
for _, table := range group.Tables {
tRecord := map[string]any{
"name": table.Name,
"memo": table.Comment,
}
res := sys.base.Query("SELECT id FROM `_Table` WHERE name = ?", table.Name)
var tid string
rows := res.MapResults()
if len(rows) > 0 {
tid = cast.String(rows[0]["id"])
tRecord["id"] = tid
}
_ = sys.Table("_Table").Set(tRecord)
if tid == "" {
res = sys.base.Query("SELECT id FROM `_Table` WHERE name = ?", table.Name)
rows = res.MapResults()
if len(rows) > 0 {
tid = cast.String(rows[0]["id"])
}
}
if tid != "" {
for _, field := range table.Fields {
fName := field.Name
if fName == "id" || fName == "createTime" || fName == "creator" || fName == "updateTime" || fName == "updater" {
continue
}
fRecord := map[string]any{
"tableId": tid,
"name": fName,
"type": field.Type,
"isIndex": cast.If(field.Index != "", 1, 0),
"memo": field.Comment,
}
fRes := sys.base.Query("SELECT id FROM `_Field` WHERE tableId = ? AND name = ?", tid, fName)
fRows := fRes.MapResults()
if len(fRows) > 0 {
fRecord["id"] = fRows[0]["id"]
}
_ = sys.Table("_Field").Set(fRecord)
}
}
}
}
return GlobalCache.Load(sys)
}
// Auth creates a new instance with the specified userID context.
func (d *TableDBUnauthorized) Auth(userID string) *TableDB {
return &TableDB{
base: d.base,
userID: userID,
hooks: d.Hooks,
}
}
// Auth creates a new instance with the specified userID context from an existing authorized instance.
func (d *TableDB) Auth(userID string) *TableDB {
return &TableDB{
base: d.base,
userID: userID,
hooks: d.hooks,
}
}
// Tables returns a list of tables that the current user has access to.
func (d *TableDB) Tables() ([]TableSchema, error) {
res, err := d.Table("_Table").List(nil)
if err != nil {
return nil, err
}
var tables []TableSchema
cast.Convert(&tables, res)
return tables, nil
}
// SetTable updates or creates the table metadata.
func (d *TableDB) SetTable(schema TableSchema) error {
record := map[string]any{
"name": schema.Name,
"memo": schema.Memo,
"enableRLS": schema.EnableRLS,
"settings": schema.Settings,
}
if schema.ID != "" {
record["id"] = schema.ID
}
return d.Table("_Table").Set(record)
}
// RemoveTable deletes a table.
func (d *TableDB) RemoveTable(name string) error {
tableRec := GlobalCache.GetTable(name)
if tableRec == nil {
return fmt.Errorf("table %s not found", name)
}
return d.Table("_Table").Remove(cast.String(tableRec["id"]))
}
// SetPolicy updates or creates an access policy.
func (d *TableDB) SetPolicy(policy PolicySchema) error {
if d.userID != SystemUserID {
if policy.Type == "inherit" {
return fmt.Errorf("only system user can set inherit policy")
}
if policy.Type == "table" {
// Only users with permission to the table can set type=table policy
for _, targetTable := range policy.Targets {
// Check if current user has 'full' access to the target table
pols := GlobalCache.GetFlatPolicies(d.userID, targetTable, "full")
hasFullAccess := false
for _, p := range pols {
if p.Condition == "" {
hasFullAccess = true
break
}
}
if !hasFullAccess {
return fmt.Errorf("permission denied to set policy for table %s", targetTable)
}
}
}
}
record := map[string]any{
"userID": policy.UserID,
"type": policy.Type,
"targets": policy.Targets,
"action": policy.Action,
"condition": policy.Condition,
"conditionArgs": policy.ConditionArgs,
"creator": d.userID,
}
if policy.ID != "" {
record["id"] = policy.ID
}
return d.Auth(SystemUserID).Table("_Policy").Set(record)
}
// ListPolicy retrieves policies based on filter.
func (d *TableDB) ListPolicy(filter map[string]any) ([]PolicySchema, error) {
res, err := d.Table("_Policy").List(filter)
if err != nil {
return nil, err
}
var policies []PolicySchema
cast.Convert(&policies, res)
return policies, nil
}
// Table returns an AI-friendly interface for multi-dimensional operations on a specific table.
func (d *TableDB) Table(name string) *Table {
return NewTable(name, d)
}
// GetRawDB returns the underlying apigo.cc/go/db.DB. Only allowed for SystemUserID.
func (d *TableDB) GetRawDB() (*db.DB, error) {
if d.userID != SystemUserID {
return nil, fmt.Errorf("permission denied for GetRawDB")
}
return d.base, nil
}
// buildQuery constructs a SQL query from a QueryRequest with strict identifier validation and auth filtering.
func (d *TableDB) buildQuery(t *Table, req QueryRequest) (string, []any, error) {
tableName := t.Name
if GlobalCache.GetTable(tableName) == nil {
return "", nil, fmt.Errorf("invalid table: %s", tableName)
}
fieldList := req.Select
if len(fieldList) == 0 {
fieldList = GlobalCache.GetValidFields(tableName)
}
var validatedSelect []string
for _, s := range fieldList {
if !GlobalCache.IsValidField(tableName, s) {
return "", nil, fmt.Errorf("invalid field %s in table %s", s, tableName)
}
validatedSelect = append(validatedSelect, "`"+s+"`")
}
fieldsStr := strings.Join(validatedSelect, ", ")
var sql strings.Builder
fmt.Fprintf(&sql, "SELECT %s FROM `%s` ", fieldsStr, tableName)
for _, join := range req.Joins {
if GlobalCache.GetTable(join.Table) == nil {
return "", nil, fmt.Errorf("invalid join table: %s", join.Table)
}
joinType := join.Type
if joinType == "" {
joinType = "LEFT"
}
jt := strings.ToUpper(joinType)
if jt != "LEFT" && jt != "INNER" && jt != "RIGHT" && jt != "FULL" && jt != "CROSS" {
return "", nil, fmt.Errorf("invalid join type: %s", joinType)
}
fmt.Fprintf(&sql, "%s JOIN `%s` ON %s ", jt, join.Table, join.On)
}
args := req.Args
whereStr := req.Where
// Apply auth filtering for the main table
var err error
whereStr, args, err = t.appendAuthAndConstraint(whereStr, args)
if err != nil {
return "", nil, err
}
if whereStr != "" {
sql.WriteString(" WHERE ")
sql.WriteString(whereStr)
}
orderBy := req.OrderBy
if orderBy == "" {
orderBy = "createTime DESC"
}
parts := strings.Fields(orderBy)
if len(parts) > 0 {
fieldName := parts[0]
if !GlobalCache.IsValidField(tableName, fieldName) {
return "", nil, fmt.Errorf("invalid order by field: %s", fieldName)
}
direction := ""
if len(parts) > 1 {
dir := strings.ToUpper(parts[1])
if dir == "ASC" || dir == "DESC" {
direction = " " + dir
} else {
return "", nil, fmt.Errorf("invalid order by direction: %s", parts[1])
}
}
fmt.Fprintf(&sql, " ORDER BY `%s` %s", fieldName, direction)
}
if req.Limit > 0 {
fmt.Fprintf(&sql, " LIMIT %d", req.Limit)
}
if req.Offset > 0 {
fmt.Fprintf(&sql, " OFFSET %d", req.Offset)
}
return sql.String(), args, nil
}
// buildWhere is a helper to convert a map of conditions into a SQL WHERE clause and args.
func buildWhere(filter map[string]any) (string, []any) {
if len(filter) == 0 {
return "", nil
}
var builder strings.Builder
var args []any
first := true
for k, v := range filter {
if !first {
builder.WriteString(" AND ")
}
first = false
k = strings.TrimSpace(k)
operator := "="
parts := strings.Split(k, " ")
if len(parts) > 1 {
k = parts[0]
operator = strings.Join(parts[1:], " ")
}
builder.WriteString(k)
builder.WriteString(" ")
builder.WriteString(operator)
builder.WriteString(" ?")
args = append(args, v)
}
return builder.String(), args
}