// tableDB/cache.go (246 lines, 6.4 KiB, Go)

package tableDB
import (
"sync"
"time"
"apigo.cc/go/cast"
)
// SchemaCache is an in-memory snapshot of the _Table, _Field and _Policy
// system tables. Load rebuilds every map and swaps them in at once.
//
// NOTE(review): the maps are exported but lock is not, so code outside this
// package that reads the fields directly bypasses the lock; prefer the
// accessor methods.
type SchemaCache struct {
	Tables         map[string]map[string]any                     // name -> table record
	TableIDMap     map[string]string                             // id -> name
	Fields         map[string][]FieldSchema                      // tableId -> fields
	ValidFieldsMap map[string]map[string]bool                    // tableName -> fieldName -> true
	FlatPolicies   map[string]map[string]map[string][]FlatPolicy // userID -> tableName -> action -> []FlatPolicy

	// lock guards the maps above: Load replaces them under the write lock and
	// the accessor methods read them under the read lock.
	lock sync.RWMutex
	// lastLoad is set to time.Now() each time Load swaps in a new snapshot.
	// NOTE(review): not read anywhere in the visible part of this file.
	lastLoad time.Time
}
// GlobalCache is the package-level schema cache. It is created with every map
// already allocated (non-nil) so it can be used before the first Load.
var GlobalCache = func() *SchemaCache {
	sc := new(SchemaCache)
	sc.Tables = map[string]map[string]any{}
	sc.TableIDMap = map[string]string{}
	sc.Fields = map[string][]FieldSchema{}
	sc.ValidFieldsMap = map[string]map[string]bool{}
	sc.FlatPolicies = map[string]map[string]map[string][]FlatPolicy{}
	return sc
}()
// Load rebuilds the cache from the _Table, _Field and _Policy system tables of
// tDB and swaps the new snapshot in under the write lock.
//
// If the _Table existence probe fails or finds no _Table, Load returns nil and
// leaves the current cache untouched (system tables not yet created). An error
// from reading _Table, _Field or _Policy is returned unchanged and likewise
// leaves the cache untouched.
func (c *SchemaCache) Load(tDB *TableDB) error {
	dbInst := tDB.base

	// Check that _Table exists, running only the catalog query that matches the
	// configured dialect ("mysql" uses information_schema; any other type uses
	// sqlite_master). MySQL no longer pays for a sqlite_master query whose
	// result it would discard.
	// NOTE(review): a probe error (e.g. a lost connection) is treated the same
	// as "not created yet" and reported as success — confirm this is intended.
	var hasSystemTables bool
	if dbInst.Config.Type == "mysql" {
		res := dbInst.Query("SELECT TABLE_NAME name FROM information_schema.TABLES WHERE TABLE_SCHEMA=? AND TABLE_NAME='_Table'", dbInst.Config.DB)
		hasSystemTables = res.Error == nil && res.MapOnR1()["name"] != nil
	} else {
		res := dbInst.Query("SELECT name FROM sqlite_master WHERE type='table' AND name='_Table'")
		hasSystemTables = res.Error == nil && res.MapOnR1()["name"] != nil
	}
	if !hasSystemTables {
		return nil // System tables not yet created
	}

	tablesRes := dbInst.Query("SELECT * FROM `_Table`")
	if tablesRes.Error != nil {
		return tablesRes.Error
	}
	tables := tablesRes.MapResults()

	fieldsRes := dbInst.Query("SELECT * FROM `_Field`")
	if fieldsRes.Error != nil {
		return fieldsRes.Error
	}
	fields := fieldsRes.MapResults()

	policiesRes := dbInst.Query("SELECT * FROM `_Policy`")
	if policiesRes.Error != nil {
		return policiesRes.Error
	}
	policies := policiesRes.MapResults()

	// Index table records by name, and table ids back to names.
	newTables := make(map[string]map[string]any, len(tables))
	newTableIDMap := make(map[string]string, len(tables))
	for _, t := range tables {
		name := cast.String(t["name"])
		newTables[name] = t
		newTableIDMap[cast.String(t["id"])] = name
	}

	// Group field schemas by table id and record the valid field names per
	// table name. Fields whose table id matches no loaded table are still
	// grouped by id but add nothing to the valid-field map.
	newFields := make(map[string][]FieldSchema)
	newValidFieldsMap := make(map[string]map[string]bool, len(newTables))
	for _, f := range fields {
		var fs FieldSchema
		cast.Convert(&fs, f)
		// Fall back to the raw column if the conversion left TableID empty.
		if fs.TableID == "" {
			fs.TableID = cast.String(f["tableId"])
		}
		tid := fs.TableID
		newFields[tid] = append(newFields[tid], fs)
		if tableName := newTableIDMap[tid]; tableName != "" {
			if newValidFieldsMap[tableName] == nil {
				newValidFieldsMap[tableName] = make(map[string]bool)
			}
			newValidFieldsMap[tableName][fs.Name] = true
		}
	}

	// Flatten policies: expand "inherit" policies into each user's own
	// table -> action lists so GetFlatPolicies is a plain map lookup.
	type rawPolicy struct {
		UserID        string
		Type          string
		Targets       []string
		Action        string
		Condition     string
		ConditionArgs []any
	}
	var rawPolicies []rawPolicy
	cast.Convert(&rawPolicies, policies)
	userToRaw := make(map[string][]rawPolicy)
	for _, p := range rawPolicies {
		userToRaw[p.UserID] = append(userToRaw[p.UserID], p)
	}

	// flatten returns table -> action -> policies for userID, including
	// everything inherited transitively. visited holds the users on the current
	// inheritance path; reaching one of them again (a cycle) contributes nothing.
	var flatten func(userID string, visited map[string]bool) map[string]map[string][]FlatPolicy
	flatten = func(userID string, visited map[string]bool) map[string]map[string][]FlatPolicy {
		if visited[userID] {
			return nil
		}
		visited[userID] = true
		defer delete(visited, userID)

		result := make(map[string]map[string][]FlatPolicy) // table -> action -> []FlatPolicy
		addPolicy := func(table, action string, fp FlatPolicy) {
			if result[table] == nil {
				result[table] = make(map[string][]FlatPolicy)
			}
			result[table][action] = append(result[table][action], fp)
			// A "full" grant is also recorded under "read" and "write".
			if action == "full" {
				result[table]["read"] = append(result[table]["read"], fp)
				result[table]["write"] = append(result[table]["write"], fp)
			}
		}

		for _, p := range userToRaw[userID] {
			if p.Type == "table" {
				// A literal "null" condition means no condition.
				// NOTE(review): presumably a stringified JSON null — confirm.
				cond := p.Condition
				if cond == "null" {
					cond = ""
				}
				// Drop nil placeholders from the condition arguments.
				var condArgs []any
				for _, arg := range p.ConditionArgs {
					if arg != nil {
						condArgs = append(condArgs, arg)
					}
				}
				for _, targetTable := range p.Targets {
					addPolicy(targetTable, p.Action, FlatPolicy{
						Condition:     cond,
						ConditionArgs: condArgs,
					})
				}
			} else if p.Type == "inherit" {
				// For "inherit", Targets are the user ids to inherit from.
				for _, parentID := range p.Targets {
					parentFlat := flatten(parentID, visited)
					for table, actions := range parentFlat {
						for action, fps := range actions {
							if result[table] == nil {
								result[table] = make(map[string][]FlatPolicy)
							}
							result[table][action] = append(result[table][action], fps...)
						}
					}
				}
			}
		}
		return result
	}

	newFlatPolicies := make(map[string]map[string]map[string][]FlatPolicy, len(userToRaw))
	for userID := range userToRaw {
		newFlatPolicies[userID] = flatten(userID, make(map[string]bool))
	}

	// Swap the fully built snapshot in at once so readers never see a mix of
	// old and new maps.
	c.lock.Lock()
	defer c.lock.Unlock()
	c.Tables = newTables
	c.TableIDMap = newTableIDMap
	c.Fields = newFields
	c.ValidFieldsMap = newValidFieldsMap
	c.FlatPolicies = newFlatPolicies
	c.lastLoad = time.Now()
	return nil
}
// IsValidField reports whether fieldName may be used on tableName. The
// standard columns id, createTime, creator, updateTime and updater are always
// accepted. Any other name must be a field recorded for that table by the
// last Load.
func (c *SchemaCache) IsValidField(tableName, fieldName string) bool {
	c.lock.RLock()
	defer c.lock.RUnlock()

	switch fieldName {
	case "id", "createTime", "creator", "updateTime", "updater":
		return true
	}
	// Indexing a missing (nil) inner map yields false, so an unknown table
	// needs no separate check.
	return c.ValidFieldsMap[tableName][fieldName]
}
// GetFlatPolicies returns the flattened policies (own and inherited) granted
// to userID for action on tableName, or nil if there are none. The returned
// slice is the cache's own; callers should treat it as read-only.
func (c *SchemaCache) GetFlatPolicies(userID, tableName, action string) []FlatPolicy {
	c.lock.RLock()
	defer c.lock.RUnlock()

	// Lookups through nil maps return the zero value, so an unknown user or
	// table falls through to a nil slice.
	return c.FlatPolicies[userID][tableName][action]
}
// GetTable returns the cached _Table record for name, or nil if it is unknown.
// The returned map is shared with the cache and should not be modified.
func (c *SchemaCache) GetTable(name string) map[string]any {
	c.lock.RLock()
	record := c.Tables[name]
	c.lock.RUnlock()
	return record
}
// GetFields returns the field schemas cached for the table with id tableID,
// or nil if there are none. The returned slice is shared with the cache.
func (c *SchemaCache) GetFields(tableID string) []FieldSchema {
	c.lock.RLock()
	schemas := c.Fields[tableID]
	c.lock.RUnlock()
	return schemas
}
// GetValidFields returns every column name usable on tableName. The declared
// fields come first, in cache order. They are followed by whichever standard
// columns (id, createTime, creator, updateTime, updater) were not declared, in
// that fixed order. It returns nil if tableName is not a cached table.
//
// The standard columns used to be appended by ranging over a map, so their
// order changed from call to call. They now come from a fixed list, which
// makes the result deterministic.
func (c *SchemaCache) GetValidFields(tableName string) []string {
	c.lock.RLock()
	defer c.lock.RUnlock()

	table := c.Tables[tableName]
	if table == nil {
		return nil
	}
	fields := c.Fields[cast.String(table["id"])]

	standardFields := [...]string{"id", "createTime", "creator", "updateTime", "updater"}
	names := make([]string, 0, len(fields)+len(standardFields))
	declared := make(map[string]bool, len(fields))
	for _, f := range fields {
		names = append(names, f.Name)
		declared[f.Name] = true
	}
	// Add standard columns the table did not declare itself, in fixed order.
	for _, name := range standardFields {
		if !declared[name] {
			names = append(names, name)
		}
	}
	return names
}