tableDB/cache.go

224 lines
5.6 KiB
Go

package tableDB
import (
"sync"
"time"
"apigo.cc/go/cast"
)
// SchemaCache is an in-memory snapshot of the metadata stored in the
// _Table, _Field and _Policy system tables.
//
// Load replaces every map wholesale while holding the write lock; the
// accessor methods take the read lock.
// NOTE(review): the maps are exported, so code that reads them directly
// bypasses the lock.
type SchemaCache struct {
	// Tables maps a table name to its _Table record.
	Tables map[string]map[string]any
	// Fields maps a table ID to the fields declared for it in _Field.
	Fields map[string][]FieldSchema
	// ValidFieldsMap is a per-table set of field names: tableName -> fieldName -> true.
	ValidFieldsMap map[string]map[string]bool
	// FlatPolicies holds inheritance-resolved policies indexed as
	// userID -> tableName -> action -> []FlatPolicy.
	FlatPolicies map[string]map[string]map[string][]FlatPolicy

	lock     sync.RWMutex // guards the maps above
	lastLoad time.Time    // set at the end of each successful Load
}

// GlobalCache is the shared schema cache. It starts with empty (non-nil)
// maps so lookups are safe before the first Load.
var GlobalCache = &SchemaCache{
	Tables:         make(map[string]map[string]any),
	Fields:         make(map[string][]FieldSchema),
	ValidFieldsMap: make(map[string]map[string]bool),
	FlatPolicies:   make(map[string]map[string]map[string][]FlatPolicy),
}
// Load rebuilds the cache from the _Table, _Field and _Policy system tables
// and swaps the new snapshot in under the write lock.
//
// If the _Table system table cannot be found, or the probe query fails,
// Load returns nil without touching the cache. This makes it safe to call
// before the system tables have been created.
//
// If listing any system table fails, that error is returned unchanged and
// the previous snapshot is kept.
//
// Policies are flattened per user:
//   - A "table" policy records its action for every target table. A "full"
//     action is also recorded under "read" and "write".
//   - An "inherit" policy copies in every flattened policy of each target
//     user, recursively.
//
// An inheritance cycle is cut at the first user revisited along the
// current chain.
func (c *SchemaCache) Load(tDB *TableDB) error {
	dbInst := tDB.base

	// Probe for _Table using only the query for the configured dialect.
	// The SQLite probe used to run unconditionally. On MySQL that cost an
	// extra round trip that always failed, because sqlite_master does not
	// exist there.
	probeSQL := "SELECT name FROM sqlite_master WHERE type='table' AND name='_Table'"
	var probeArgs []any
	if dbInst.Config.Type == "mysql" {
		probeSQL = "SELECT TABLE_NAME name FROM information_schema.TABLES WHERE TABLE_SCHEMA=? AND TABLE_NAME='_Table'"
		probeArgs = append(probeArgs, dbInst.Config.DB)
	}
	res := dbInst.Query(probeSQL, probeArgs...)
	if res.Error != nil || res.MapOnR1()["name"] == nil {
		return nil // System tables not yet created
	}

	sysApp := tDB.Auth(SystemUserID)
	tables, err := sysApp.Table("_Table").List(nil)
	if err != nil {
		return err
	}
	fields, err := sysApp.Table("_Field").List(nil)
	if err != nil {
		return err
	}
	policies, err := sysApp.Table("_Policy").List(nil)
	if err != nil {
		return err
	}

	// Index tables by name, and remember id -> name for resolving fields.
	newTables := make(map[string]map[string]any, len(tables))
	tableIDToName := make(map[string]string, len(tables))
	for _, t := range tables {
		name := cast.String(t["name"])
		newTables[name] = t
		tableIDToName[cast.String(t["id"])] = name
	}

	// Group fields by table ID. Fields whose table ID matches a known table
	// are also added to the per-table-name validity set.
	newFields := make(map[string][]FieldSchema)
	newValidFieldsMap := make(map[string]map[string]bool)
	for _, f := range fields {
		var fs FieldSchema
		cast.Convert(&fs, f)
		tid := fs.TableID
		newFields[tid] = append(newFields[tid], fs)
		tableName := tableIDToName[tid]
		if tableName == "" {
			continue
		}
		if newValidFieldsMap[tableName] == nil {
			newValidFieldsMap[tableName] = make(map[string]bool)
		}
		newValidFieldsMap[tableName][fs.Name] = true
	}

	// Flatten policies: table -> action -> []FlatPolicy for each user.
	type rawPolicy struct {
		UserID        string
		Type          string
		Targets       []string
		Action        string
		Condition     string
		ConditionArgs []any
	}
	var rawPolicies []rawPolicy
	cast.Convert(&rawPolicies, policies)
	userToRaw := make(map[string][]rawPolicy)
	for _, p := range rawPolicies {
		userToRaw[p.UserID] = append(userToRaw[p.UserID], p)
	}

	var flatten func(userID string, visited map[string]bool) map[string]map[string][]FlatPolicy
	flatten = func(userID string, visited map[string]bool) map[string]map[string][]FlatPolicy {
		// visited tracks only the current inheritance chain (entries are
		// removed on return). This breaks cycles without blocking the same
		// parent from being reached through a different branch.
		if visited[userID] {
			return nil
		}
		visited[userID] = true
		defer delete(visited, userID)

		result := make(map[string]map[string][]FlatPolicy) // table -> action -> []FlatPolicy
		addPolicy := func(table, action string, fp FlatPolicy) {
			if result[table] == nil {
				result[table] = make(map[string][]FlatPolicy)
			}
			result[table][action] = append(result[table][action], fp)
			if action == "full" {
				result[table]["read"] = append(result[table]["read"], fp)
				result[table]["write"] = append(result[table]["write"], fp)
			}
		}
		for _, p := range userToRaw[userID] {
			if p.Type == "table" {
				cond := p.Condition
				if cond == "null" {
					// A literal "null" condition is treated as no condition.
					cond = ""
				}
				var condArgs []any
				for _, arg := range p.ConditionArgs {
					if arg != nil {
						condArgs = append(condArgs, arg)
					}
				}
				for _, targetTable := range p.Targets {
					addPolicy(targetTable, p.Action, FlatPolicy{
						Condition:     cond,
						ConditionArgs: condArgs,
					})
				}
			} else if p.Type == "inherit" {
				for _, parentID := range p.Targets {
					parentFlat := flatten(parentID, visited)
					for table, actions := range parentFlat {
						for action, fps := range actions {
							if result[table] == nil {
								result[table] = make(map[string][]FlatPolicy)
							}
							result[table][action] = append(result[table][action], fps...)
						}
					}
				}
			}
		}
		return result
	}

	newFlatPolicies := make(map[string]map[string]map[string][]FlatPolicy, len(userToRaw))
	for userID := range userToRaw {
		newFlatPolicies[userID] = flatten(userID, make(map[string]bool))
	}

	c.lock.Lock()
	defer c.lock.Unlock()
	c.Tables = newTables
	c.Fields = newFields
	c.ValidFieldsMap = newValidFieldsMap
	c.FlatPolicies = newFlatPolicies
	c.lastLoad = time.Now()
	return nil
}
// IsValidField reports whether fieldName is in the field set recorded for
// tableName by the last Load. An unknown table yields false.
func (c *SchemaCache) IsValidField(tableName, fieldName string) bool {
	c.lock.RLock()
	defer c.lock.RUnlock()
	// Indexing a missing (nil) inner map returns the zero value, false.
	return c.ValidFieldsMap[tableName][fieldName]
}
// GetFlatPolicies returns the flattened policies for the given user, table
// and action. It returns nil when any level of the lookup has no entry.
// The returned slice is the cached one, not a copy.
func (c *SchemaCache) GetFlatPolicies(userID, tableName, action string) []FlatPolicy {
	c.lock.RLock()
	defer c.lock.RUnlock()
	// Each nil intermediate map indexes to nil, so the chain is safe.
	return c.FlatPolicies[userID][tableName][action]
}
// GetTable returns the cached _Table record for the named table, or nil if
// it is unknown. The returned map is shared with the cache, not copied.
func (c *SchemaCache) GetTable(name string) map[string]any {
	c.lock.RLock()
	record := c.Tables[name]
	c.lock.RUnlock()
	return record
}
// GetFields returns the cached field schemas for a table ID, or nil if none
// were loaded. The returned slice is shared with the cache, not copied.
func (c *SchemaCache) GetFields(tableID string) []FieldSchema {
	c.lock.RLock()
	schemas := c.Fields[tableID]
	c.lock.RUnlock()
	return schemas
}
// GetValidFields returns the names of the fields loaded for tableName. The
// names come from the table's "id" and keep the order in which they were
// loaded.
//
// It returns nil if the table is unknown or has no fields.
func (c *SchemaCache) GetValidFields(tableName string) []string {
	c.lock.RLock()
	defer c.lock.RUnlock()

	record := c.Tables[tableName]
	if record == nil {
		return nil
	}
	schemas := c.Fields[cast.String(record["id"])]

	var out []string
	for i := range schemas {
		out = append(out, schemas[i].Name)
	}
	return out
}