db/Schema.go
2026-05-03 23:51:30 +08:00

697 lines
21 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package db
import (
"fmt"
"regexp"
"strings"
"apigo.cc/go/cast"
)
// SchemaGroup describes a group of tables.
//
// ParseSchema creates one SchemaGroup per "==" header line of a schema description,
// plus a group named "Default" for tables that appear before any header line.
type SchemaGroup struct {
// Name is the header line with its surrounding '=' characters and spaces removed,
// or "Default" for the implicit first group.
Name string
// Tables holds the tables declared in this group, in declaration order.
Tables []*TableStruct
}
// Regular expressions used by ParseField, compiled once at package initialization.
var (
	// fieldSpliter splits a field description line on runs of whitespace.
	fieldSpliter = regexp.MustCompile(`\s+`)
	// wnMatcher splits a "letters followed by digits" tag such as "v32" or "I1" into
	// its letters (group 1) and its number (group 2).
	wnMatcher = regexp.MustCompile(`^([a-zA-Z]+)([0-9]+)$`)
)
// ParseField parses a single-line field description.
//
// The line has the form
//
//	name tag tag ... // comment
//
// Everything after the first "//" becomes the field comment. The rest is trimmed and
// split on whitespace (into at most 10 parts). The first part is the column name; each
// following part is a tag. A tag made of letters followed by digits (e.g. "v32", "I1") is
// split into the tag letters and a size. Recognized tags:
//
//   - keys: PK (primary key, NOT NULL), AI (AUTO_INCREMENT primary key, NOT NULL),
//     I (index), U (unique), TI (fulltext). For I and U the size names an index group,
//     so fields with the same group (e.g. "I1") share one composite key;
//   - attributes: ver (version column: bigint, default 0, NOT NULL), ct (default
//     CURRENT_TIMESTAMP), ctu (CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP),
//     n (NULL), nn (NOT NULL);
//   - types: c, v, dt, d, tm, i, ui, ti, uti, b, bi, ubi, f, uf, ff, uff, si, usi, mi,
//     umi, t, bb. Any other tag is used verbatim as the column type.
//
// For all other tags the size is appended to the current column type, placed before an
// " unsigned" suffix so the result is valid SQL ("v32" -> "varchar(32)", "ui10" ->
// "int(10) unsigned"). No size is appended once the field has been marked as a version
// field. Fields are NULL unless a tag says otherwise.
func ParseField(line string) TableField {
	line, comment, _ := strings.Cut(line, "//")
	// Always trim: a leading blank would otherwise produce an empty column name.
	line = strings.TrimSpace(line)
	comment = strings.TrimSpace(comment)

	a := fieldSpliter.Split(line, 10)
	field := TableField{
		Name:    a[0],
		Comment: comment,
		Null:    "NULL",
	}
	for _, raw := range a[1:] {
		tag := raw
		size := 0
		if wn := wnMatcher.FindStringSubmatch(raw); wn != nil {
			tag = wn[1]
			size = cast.Int(wn[2])
		}

		switch tag {
		case "PK":
			field.Index = "pk"
			field.Null = "NOT NULL"
		case "I":
			field.Index = "index"
		case "AI":
			field.Extra = "AUTO_INCREMENT"
			field.Index = "pk"
			field.Null = "NOT NULL"
		case "TI":
			field.Index = "fulltext"
		case "U":
			field.Index = "unique"
		case "ver":
			field.IsVersion = true
			field.Type = "bigint"
			field.Default = "0"
			field.Null = "NOT NULL"
		case "ct":
			field.Default = "CURRENT_TIMESTAMP"
		case "ctu":
			field.Default = "CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"
		case "n":
			field.Null = "NULL"
		case "nn":
			field.Null = "NOT NULL"
		case "c":
			field.Type = "char"
		case "v":
			// "v1".."v9" never reach this switch as-is: wnMatcher turns them into
			// tag "v" with a size.
			field.Type = "varchar"
		case "dt":
			field.Type = "datetime"
		case "d":
			field.Type = "date"
		case "tm":
			field.Type = "time"
		case "i":
			field.Type = "int"
		case "ui":
			field.Type = "int unsigned"
		case "ti":
			field.Type = "tinyint"
		case "uti", "b":
			field.Type = "tinyint unsigned"
		case "bi":
			field.Type = "bigint"
		case "ubi":
			field.Type = "bigint unsigned"
		case "f":
			field.Type = "float"
		case "uf":
			field.Type = "float unsigned"
		case "ff":
			field.Type = "double"
		case "uff":
			field.Type = "double unsigned"
		case "si":
			field.Type = "smallint"
		case "usi":
			field.Type = "smallint unsigned"
		case "mi":
			// MEDIUMINT is the canonical name MySQL reports back (MIDDLEINT is only an
			// alias), so the schema comparison in CheckTable sees no spurious change.
			field.Type = "mediumint"
		case "umi":
			field.Type = "mediumint unsigned"
		case "t":
			field.Type = "text"
		case "bb":
			field.Type = "blob"
		default:
			field.Type = tag
		}

		if size > 0 {
			switch tag {
			case "I":
				field.Index = "index"
				field.IndexGroup = cast.String(size)
			case "U":
				field.Index = "unique"
				field.IndexGroup = cast.String(size)
			default:
				if !field.IsVersion {
					sized := fmt.Sprintf("(%d)", size)
					// "int unsigned(10)" is a syntax error; the length must precede
					// the UNSIGNED attribute: "int(10) unsigned".
					if base := strings.TrimSuffix(field.Type, " unsigned"); base != field.Type {
						field.Type = base + sized + " unsigned"
					} else {
						field.Type += sized
					}
				}
			}
		}
	}
	return field
}
// ParseSchema parses a complete schema description text.
//
// The text is processed line by line:
//   - blank lines and lines starting with "#" (after trimming) are skipped;
//   - a line starting with "==" opens a new SchemaGroup, named by the line with the
//     surrounding '=' characters and spaces removed;
//   - a line that does not start with a space or tab declares a table:
//     "name [SD] [// comment]", where a trailing " SD" enables ShadowDelete;
//   - an indented line is a field of the most recent table in the current group and is
//     parsed with ParseField; a version field also sets the table's VersionField.
//
// Tables that appear before any "==" line go into a group named "Default". Indented
// lines with no preceding table in the current group are ignored.
func ParseSchema(desc string) []*SchemaGroup {
	groups := make([]*SchemaGroup, 0)
	var currentGroup *SchemaGroup
	var currentTable *TableStruct
	for _, line := range strings.Split(desc, "\n") {
		trimmedLine := strings.TrimSpace(line)
		if trimmedLine == "" || strings.HasPrefix(trimmedLine, "#") {
			continue
		}

		if strings.HasPrefix(trimmedLine, "==") {
			currentGroup = &SchemaGroup{
				Name:   strings.TrimSpace(strings.Trim(trimmedLine, "=")),
				Tables: make([]*TableStruct, 0),
			}
			groups = append(groups, currentGroup)
			// A group header closes the previous table: field lines that follow must
			// not be attached to a table belonging to the previous group.
			currentTable = nil
			continue
		}

		if currentGroup == nil {
			currentGroup = &SchemaGroup{
				Name:   "Default",
				Tables: make([]*TableStruct, 0),
			}
			groups = append(groups, currentGroup)
		}

		if !strings.HasPrefix(line, " ") && !strings.HasPrefix(line, "\t") {
			// Table declaration line.
			namePart, commentPart, _ := strings.Cut(trimmedLine, "//")
			tableNamePart := strings.TrimSpace(namePart)
			shadowDelete := false
			if strings.HasSuffix(tableNamePart, " SD") {
				shadowDelete = true
				tableNamePart = strings.TrimSpace(strings.TrimSuffix(tableNamePart, " SD"))
			}
			currentTable = &TableStruct{
				Name:         tableNamePart,
				Comment:      strings.TrimSpace(commentPart),
				Fields:       make([]TableField, 0),
				ShadowDelete: shadowDelete,
			}
			currentGroup.Tables = append(currentGroup.Tables, currentTable)
			continue
		}

		if currentTable == nil {
			// Field line without an owning table.
			continue
		}
		field := ParseField(trimmedLine)
		if field.IsVersion {
			currentTable.VersionField = field.Name
		}
		currentTable.Fields = append(currentTable.Fields, field)
	}
	return groups
}
// Parse renders field as a column definition for the database dialect named by
// tableType and stores the result in field.Desc.
//
// Dialects handled specially are "mysql", "pg"/"pgsql"/"postgres", and SQLite-like
// databases (any tableType starting with "sqlite", or "chai"). Any other value falls back
// to a double-quoted generic definition that includes a COMMENT clause.
//
// Parse also mutates field:
//   - SQLite/chai: Null is forced to "NULL". An AUTO_INCREMENT field instead becomes an
//     "integer" NOT NULL column whose Extra is "PRIMARY KEY AUTOINCREMENT". After
//     rendering, Comment is cleared and Type is set to "numeric".
//   - PostgreSQL: an AUTO_INCREMENT Extra is cleared and rendered as serial/bigserial.
//
// Non-expression defaults and comments are emitted as single-quoted SQL literals, with
// embedded single quotes doubled so the literal cannot be terminated early.
func (field *TableField) Parse(tableType string) {
	sqliteLike := strings.HasPrefix(tableType, "sqlite") || tableType == "chai"
	postgres := tableType == "pg" || tableType == "pgsql" || tableType == "postgres"
	// quoteLiteral renders s as a single-quoted SQL string literal.
	quoteLiteral := func(s string) string {
		return "'" + strings.ReplaceAll(s, "'", "''") + "'"
	}

	if sqliteLike {
		// SQLite cannot modify existing columns, so every column is declared nullable.
		field.Null = "NULL"
		if field.Extra == "AUTO_INCREMENT" || field.Extra == "AUTOINCREMENT" {
			field.Extra = "PRIMARY KEY AUTOINCREMENT"
			field.Type = "integer"
			field.Null = "NOT NULL"
		}
	}

	parts := make([]string, 0, 6)
	switch {
	case tableType == "mysql":
		parts = append(parts, fmt.Sprintf("`%s` %s", field.Name, field.Type))
		lowerType := strings.ToLower(field.Type)
		if strings.Contains(lowerType, "varchar") || strings.Contains(lowerType, "text") {
			parts = append(parts, " COLLATE utf8mb4_general_ci")
		}
	case postgres:
		typ := field.Type
		if field.Extra == "AUTO_INCREMENT" {
			// PostgreSQL has no AUTO_INCREMENT attribute; use the serial pseudo-types.
			if strings.Contains(typ, "bigint") {
				typ = "bigserial"
			} else {
				typ = "serial"
			}
			field.Extra = ""
		}
		parts = append(parts, fmt.Sprintf("\"%s\" %s", field.Name, typ))
	default:
		parts = append(parts, fmt.Sprintf("\"%s\" %s", field.Name, field.Type))
	}

	// A "PRIMARY KEY ..." extra (the SQLite AUTOINCREMENT form) is not emitted inline.
	// NOTE(review): the primary key itself appears to come from the table-level
	// PRIMARY KEY clause built by CheckTable; the AUTOINCREMENT keyword is not emitted.
	if field.Extra != "" && !strings.Contains(field.Extra, "PRIMARY KEY") {
		parts = append(parts, " "+field.Extra)
	}
	parts = append(parts, " "+field.Null)

	if def := field.Default; def != "" {
		if strings.Contains(def, "CURRENT_TIMESTAMP") || strings.Contains(def, "()") || strings.Contains(def, "SYSTIMESTAMP") {
			// Expressions and function calls are emitted verbatim.
			parts = append(parts, " DEFAULT "+def)
		} else {
			parts = append(parts, " DEFAULT "+quoteLiteral(def))
		}
	}

	switch {
	case sqliteLike:
		field.Comment = ""
		field.Type = "numeric"
	case postgres:
		// PostgreSQL column comments require separate COMMENT ON statements; none here.
	default:
		if field.Comment != "" {
			parts = append(parts, " COMMENT "+quoteLiteral(field.Comment))
		}
	}
	field.Desc = strings.Join(parts, "")
}
// Sync synchronizes the database structure with a schema description text.
//
// desc is parsed with ParseSchema. Each table found is first registered in db.tables
// (while holding db.tablesLock) and then reconciled with CheckTable. A failing table is
// logged and does not stop the remaining tables. The returned error is the one from the
// last table that failed, or nil when every table was synchronized.
func (db *DB) Sync(desc string) error {
	var lastErr error
	for _, group := range ParseSchema(desc) {
		for _, t := range group.Tables {
			db.tablesLock.Lock()
			db.tables[t.Name] = t
			db.tablesLock.Unlock()

			if err := db.CheckTable(t); err != nil {
				db.logger.logger.Error("failed to sync table", "table", t.Name, "err", err.Error())
				lastErr = err
			}
		}
	}
	return lastErr
}
// CheckTable checks the structure of a single table and synchronizes the database with it.
//
// Each field of table is rendered with TableField.Parse for db.Config.Type. The parsed
// fields are written back into table.Fields, and table.Columns is rebuilt from them.
// Fields indexed as primary key, unique, index or fulltext produce the table's primary
// key and keys. Fields that share an IndexGroup form one composite key.
//
// If the table does not exist, it is created together with its keys. SQLite, chai and
// PostgreSQL create indexes with separate statements; MySQL declares them inline. If the
// table exists, its current columns and indexes are read back and the differences are
// applied:
//   - missing columns are added;
//   - keys that are missing, or whose column list changed (not on PostgreSQL), are
//     (re-)created;
//   - except on PostgreSQL, existing indexes whose column list differs from the
//     description (including unknown ones) are dropped;
//   - on MySQL only, the following are also updated:
//   - changed columns are rewritten with CHANGE (type, default, NULL-ness, position,
//     comment);
//   - columns missing from the description are dropped;
//   - the primary key and the table comment are updated.
//
// The statements for the table run in one transaction.
//
// When table.ShadowDelete is set, and the table is not itself a "_deleted" table, a
// "<name>_deleted" shadow table is then synchronized recursively. It has the same columns
// as described, without keys or extras, plus a "deleted_at" column.
func (db *DB) CheckTable(table *TableStruct) error {
	dbType := db.Config.Type
	isSQLite := strings.HasPrefix(dbType, "sqlite")
	isChai := dbType == "chai"
	isPostgres := dbType == "pg" || dbType == "pgsql" || dbType == "postgres"
	isMySQL := dbType == "mysql"
	// SQLite and chai share the index DDL: standalone CREATE INDEX statements.
	isSQLiteLike := isSQLite || isChai
	// sqlString escapes s for use inside a single-quoted SQL string literal.
	sqlString := func(s string) string { return strings.ReplaceAll(s, "'", "''") }

	// Parse rewrites some attributes in place; on SQLite every Type becomes "numeric" so it
	// compares equal to the normalized old columns below. Keep the fields as described so
	// the shadow table is created with the real column types.
	rawFields := make([]TableField, len(table.Fields))
	copy(rawFields, table.Fields)

	fieldSets := make([]string, 0, len(table.Fields)+1)
	pks := make([]string, 0)
	keyOrder := make([]string, 0)           // key names in declaration order
	keySetBy := make(map[string]string)     // key name -> complete key DDL
	keySetFields := make(map[string]string) // key name -> space-separated column list

	table.Columns = make([]string, 0, len(table.Fields))
	for i, field := range table.Fields {
		field.Parse(dbType)
		table.Fields[i] = field
		table.Columns = append(table.Columns, field.Name)
		switch strings.ToLower(field.Index) {
		case "pk", "primary key":
			pks = append(pks, field.Name)
		case "unique":
			keyName := fmt.Sprint("uk_", table.Name, "_", field.Name)
			if field.IndexGroup != "" {
				keyName = fmt.Sprint("uk_", table.Name, "_", field.IndexGroup)
			}
			if keySetBy[keyName] != "" {
				// Another member of a grouped key: append this column to its definition.
				keySetFields[keyName] += " " + field.Name
				if isSQLiteLike || isPostgres {
					keySetBy[keyName] = strings.Replace(keySetBy[keyName], ")", ", "+db.Quote(field.Name)+")", 1)
				} else {
					keySetBy[keyName] = strings.Replace(keySetBy[keyName], ") COMMENT", ", "+db.Quote(field.Name)+") COMMENT", 1)
				}
			} else {
				var keySet string
				if isSQLiteLike || isPostgres {
					keySet = fmt.Sprintf("CREATE UNIQUE INDEX \"%s\" ON \"%s\" (\"%s\")", keyName, table.Name, field.Name)
				} else {
					keySet = fmt.Sprintf("UNIQUE KEY "+db.Quote("%s")+" ("+db.Quote("%s")+") COMMENT '%s'", keyName, field.Name, sqlString(field.Comment))
				}
				keyOrder = append(keyOrder, keyName)
				keySetBy[keyName] = keySet
				keySetFields[keyName] = field.Name
			}
		case "fulltext":
			// FULLTEXT keys are not generated for SQLite, chai or PostgreSQL.
			if !isSQLiteLike && !isPostgres {
				keyName := fmt.Sprint("tk_", table.Name, "_", field.Name)
				keyOrder = append(keyOrder, keyName)
				keySetBy[keyName] = fmt.Sprintf("FULLTEXT KEY "+db.Quote("%s")+" ("+db.Quote("%s")+") COMMENT '%s'", keyName, field.Name, sqlString(field.Comment))
				// Record the column so an unchanged key is not dropped and re-added on every sync.
				keySetFields[keyName] = field.Name
			}
		case "index":
			keyName := fmt.Sprint("ik_", table.Name, "_", field.Name)
			if field.IndexGroup != "" {
				keyName = fmt.Sprint("ik_", table.Name, "_", field.IndexGroup)
			}
			if keySetBy[keyName] != "" {
				// Another member of a grouped key: append this column to its definition.
				keySetFields[keyName] += " " + field.Name
				if isSQLiteLike || isPostgres {
					keySetBy[keyName] = strings.Replace(keySetBy[keyName], ")", ", \""+field.Name+"\")", 1)
				} else {
					keySetBy[keyName] = strings.Replace(keySetBy[keyName], ") COMMENT", ", `"+field.Name+"`) COMMENT", 1)
				}
			} else {
				var keySet string
				if isSQLiteLike || isPostgres {
					keySet = fmt.Sprintf("CREATE INDEX \"%s\" ON \"%s\" (\"%s\")", keyName, table.Name, field.Name)
				} else {
					keySet = fmt.Sprintf("KEY "+db.Quote("%s")+" ("+db.Quote("%s")+") COMMENT '%s'", keyName, field.Name, sqlString(field.Comment))
				}
				keyOrder = append(keyOrder, keyName)
				keySetBy[keyName] = keySet
				keySetFields[keyName] = field.Name
			}
		}
		fieldSets = append(fieldSets, field.Desc)
	}

	var tableInfo map[string]any
	switch {
	case isSQLite:
		tableInfo = db.Query("SELECT \"name\", \"sql\" FROM \"sqlite_master\" WHERE \"type\"='table' AND \"name\"='" + table.Name + "'").MapOnR1()
	case isChai:
		tableInfo = db.Query("SELECT \"name\", \"sql\" FROM \"__chai_catalog\" WHERE \"type\"='table' AND \"name\"='" + table.Name + "'").MapOnR1()
	case isPostgres:
		tableInfo = db.Query("SELECT tablename name FROM pg_tables WHERE schemaname='public' AND tablename='" + table.Name + "'").MapOnR1()
	default:
		tableInfo = db.Query("SELECT TABLE_NAME name, TABLE_COMMENT comment FROM information_schema.TABLES WHERE TABLE_SCHEMA='" + db.Config.DB + "' AND TABLE_NAME='" + table.Name + "'").MapOnR1()
	}

	if tableInfo["name"] != nil && tableInfo["name"] != "" {
		// --- Existing table: read its current structure and apply the differences. ---
		// NOTE(review): chai has no column/index introspection here, so every field is
		// treated as missing for an existing chai table.
		oldFieldList := make([]*tableFieldDesc, 0)
		oldFields := make(map[string]*tableFieldDesc)
		oldIndexes := make(map[string]string) // index name -> space-separated column list
		oldIndexInfos := make([]*tableKeyDesc, 0)
		oldComments := map[string]string{}
		switch {
		case isSQLite:
			tmpFields := []struct {
				Name       string
				Type       string
				Notnull    bool
				Dflt_value any
				Pk         bool
			}{}
			db.Query("PRAGMA table_info(" + db.Quote(table.Name) + ")").To(&tmpFields)
			for _, f := range tmpFields {
				oldFieldList = append(oldFieldList, &tableFieldDesc{
					Field:   f.Name,
					Type:    f.Type,
					Null:    cast.If(f.Notnull, "NO", "YES"),
					Key:     cast.If(f.Pk, "PRI", ""),
					Default: cast.String(f.Dflt_value),
				})
			}
			tmpIndexes := []struct {
				Name    string
				Unique  bool
				Origin  string
				Partial int
			}{}
			db.Query("PRAGMA index_list(" + db.Quote(table.Name) + ")").To(&tmpIndexes)
			for _, idx := range tmpIndexes {
				// Indexes backing PRIMARY KEY / UNIQUE constraints (origin "pk"/"u",
				// named sqlite_autoindex_*) belong to the table definition. They cannot
				// be removed with DROP INDEX, so they are not managed here.
				if idx.Origin == "pk" || idx.Origin == "u" || strings.HasPrefix(idx.Name, "sqlite_autoindex_") {
					continue
				}
				tmpIndexInfo := []struct {
					Name  string
					Seqno int
					Cid   int
				}{}
				db.Query("PRAGMA index_info(" + db.Quote(idx.Name) + ")").To(&tmpIndexInfo)
				// Record every column, not just the first, so composite indexes compare equal
				// to their description. NOTE(review): assumes rows are returned in seqno order.
				for _, col := range tmpIndexInfo {
					oldIndexInfos = append(oldIndexInfos, &tableKeyDesc{
						Key_name:    idx.Name,
						Column_name: col.Name,
					})
				}
			}
		case isPostgres:
			tmpFields := []struct {
				Column_name    string
				Data_type      string
				Is_nullable    string
				Column_default string
			}{}
			db.Query("SELECT column_name, data_type, is_nullable, column_default FROM information_schema.columns WHERE table_schema='public' AND table_name='" + table.Name + "'").To(&tmpFields)
			for _, f := range tmpFields {
				oldFieldList = append(oldFieldList, &tableFieldDesc{
					Field:   f.Column_name,
					Type:    f.Data_type,
					Null:    f.Is_nullable,
					Default: f.Column_default,
				})
			}
			tmpIndexes := []struct {
				Indexname string
				Indexdef  string
			}{}
			db.Query("SELECT indexname, indexdef FROM pg_indexes WHERE schemaname='public' AND tablename='" + table.Name + "'").To(&tmpIndexes)
			for _, idx := range tmpIndexes {
				// PostgreSQL indexes are tracked by name only; their columns are not parsed.
				oldIndexes[idx.Indexname] = ""
			}
		case isMySQL:
			_ = db.Query("SELECT column_name, column_comment FROM information_schema.columns WHERE TABLE_SCHEMA='" + db.Config.DB + "' AND TABLE_NAME='" + table.Name + "'").ToKV(&oldComments)
			_ = db.Query("DESC " + db.Quote(table.Name)).To(&oldFieldList)
			_ = db.Query("SHOW INDEX FROM " + db.Quote(table.Name)).To(&oldIndexInfos)
		}

		if !isPostgres {
			for _, info := range oldIndexInfos {
				if oldIndexes[info.Key_name] == "" {
					oldIndexes[info.Key_name] = info.Column_name
				} else {
					oldIndexes[info.Key_name] += " " + info.Column_name
				}
			}
		}

		prevFieldId := ""
		for _, field := range oldFieldList {
			if isSQLite {
				// Match the "numeric" type Parse assigns to SQLite fields.
				field.Type = "numeric"
			} else {
				field.After = prevFieldId
			}
			prevFieldId = field.Field
			oldFields[field.Field] = field
		}

		actions := make([]string, 0)
		pkCols := strings.Join(pks, " ")

		// Drop indexes whose column list no longer matches the description.
		if !isPostgres {
			for keyId, oldCols := range oldIndexes {
				if keyId != "PRIMARY" && !strings.EqualFold(keySetFields[keyId], oldCols) {
					if isSQLite {
						actions = append(actions, "DROP INDEX "+db.Quote(keyId))
					} else {
						actions = append(actions, "DROP KEY "+db.Quote(keyId))
					}
				}
			}
		}
		if !isPostgres && !isSQLite && oldIndexes["PRIMARY"] != "" && !strings.EqualFold(oldIndexes["PRIMARY"], pkCols) {
			actions = append(actions, "DROP PRIMARY KEY")
		}

		// Add missing columns; on MySQL also rewrite changed ones.
		newFieldExists := map[string]bool{}
		prevFieldId = ""
		for _, field := range table.Fields {
			newFieldExists[field.Name] = true
			oldField := oldFields[field.Name]
			if oldField == nil {
				if isSQLite || isPostgres {
					actions = append(actions, "ALTER TABLE "+db.Quote(table.Name)+" ADD COLUMN "+field.Desc)
				} else {
					actions = append(actions, "ADD COLUMN "+field.Desc)
				}
			} else {
				if isPostgres {
					// Column changes are not synchronized on PostgreSQL.
					continue
				}
				oldField.Type = strings.TrimSpace(strings.ReplaceAll(oldField.Type, " (", "("))
				fixedOldDefault := oldField.Default
				if fixedOldDefault == "CURRENT_TIMESTAMP" && strings.Contains(oldField.Extra, "on update CURRENT_TIMESTAMP") {
					fixedOldDefault = "CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"
				}
				fixedOldNull := "NOT NULL"
				if oldField.Null == "YES" {
					fixedOldNull = "NULL"
				}
				changed := !strings.EqualFold(field.Type, oldField.Type) ||
					!strings.EqualFold(field.Default, fixedOldDefault) ||
					!strings.EqualFold(field.Null, fixedOldNull) ||
					(isMySQL && !strings.EqualFold(oldField.After, prevFieldId)) ||
					!strings.EqualFold(oldComments[field.Name], field.Comment)
				if changed && isMySQL {
					after := ""
					if oldField.After != prevFieldId {
						if prevFieldId == "" {
							after = " FIRST"
						} else {
							after = " AFTER " + db.Quote(prevFieldId)
						}
					}
					actions = append(actions, "CHANGE `"+field.Name+"` "+field.Desc+after)
				}
			}
			if isMySQL {
				prevFieldId = field.Name
			}
		}

		// Drop columns that are no longer described (not supported on SQLite/PostgreSQL here).
		if !isSQLite && !isPostgres {
			for oldFieldName := range oldFields {
				if !newFieldExists[oldFieldName] {
					actions = append(actions, "DROP COLUMN "+db.Quote(oldFieldName))
				}
			}
		}

		if isMySQL && len(pks) > 0 && !strings.EqualFold(oldIndexes["PRIMARY"], pkCols) {
			actions = append(actions, "ADD PRIMARY KEY(`"+strings.Join(pks, "`,`")+"`)")
		}

		// (Re-)create missing or changed keys, in declaration order.
		for _, keyId := range keyOrder {
			keySet := keySetBy[keyId]
			oldCols, exists := oldIndexes[keyId]
			var needed bool
			if isPostgres {
				// PostgreSQL indexes are tracked by name only (see above).
				needed = !exists
			} else {
				needed = oldCols == "" || !strings.EqualFold(oldCols, keySetFields[keyId])
			}
			if !needed {
				continue
			}
			if isSQLite || isPostgres {
				actions = append(actions, keySet)
			} else {
				actions = append(actions, "ADD "+keySet)
			}
		}

		if isMySQL {
			if oldTableComment := cast.String(tableInfo["comment"]); table.Comment != oldTableComment {
				actions = append(actions, "COMMENT '"+sqlString(table.Comment)+"'")
			}
		}

		if len(actions) > 0 {
			tx := db.Begin()
			defer tx.CheckFinished()
			var res *ExecResult
			if isSQLite || isPostgres {
				// One statement per action.
				for _, action := range actions {
					res = tx.Exec(action)
					if res.Error != nil {
						break
					}
				}
			} else {
				// MySQL: a single ALTER TABLE with comma-separated actions.
				res = tx.Exec("ALTER TABLE " + db.Quote(table.Name) + " " + strings.Join(actions, "\n,") + ";")
			}
			if res != nil && res.Error != nil {
				_ = tx.Rollback()
				return res.Error
			}
			_ = tx.Commit()
		}
	} else {
		// --- New table: create it together with its keys. ---
		if len(pks) > 0 {
			fieldSets = append(fieldSets, "PRIMARY KEY ("+db.Quotes(pks)+")")
		}
		separateIndexes := isSQLiteLike || isPostgres
		indexSets := make([]string, 0, len(keyOrder))
		for _, keyName := range keyOrder {
			// keySetBy holds the complete definition, including every column of a grouped key.
			if separateIndexes {
				indexSets = append(indexSets, keySetBy[keyName])
			} else {
				fieldSets = append(fieldSets, keySetBy[keyName])
			}
		}

		var sql string
		switch {
		case separateIndexes:
			sql = fmt.Sprintf("CREATE TABLE %s (\n%s\n);", db.Quote(table.Name), strings.Join(fieldSets, ",\n"))
		case isMySQL:
			sql = fmt.Sprintf("CREATE TABLE `%s` (\n%s\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='%s';", table.Name, strings.Join(fieldSets, ",\n"), sqlString(table.Comment))
		default:
			return fmt.Errorf("create table %q: unsupported database type %q", table.Name, dbType)
		}

		tx := db.Begin()
		defer tx.CheckFinished()
		res := tx.Exec(sql)
		if res.Error == nil {
			for _, indexSet := range indexSets {
				if r := tx.Exec(indexSet); r.Error != nil {
					res = r
					break
				}
			}
		}
		if res.Error != nil {
			_ = tx.Rollback()
			return res.Error
		}
		_ = tx.Commit()
	}

	if table.ShadowDelete && !strings.HasSuffix(table.Name, "_deleted") {
		table.HasShadowTable = true
		shadowTable := *table
		shadowTable.Name = table.Name + "_deleted"
		shadowTable.ShadowDelete = false
		shadowTable.Fields = make([]TableField, 0, len(rawFields)+1)
		for _, f := range rawFields {
			f.Index = ""
			f.IndexGroup = ""
			f.Extra = ""
			if strings.Contains(f.Type, "serial") {
				if strings.Contains(f.Type, "bigserial") {
					f.Type = "bigint"
				} else {
					f.Type = "int"
				}
			}
			shadowTable.Fields = append(shadowTable.Fields, f)
		}
		// Add a deletion-time column to the shadow table so deletions are easy to sync.
		shadowTable.Fields = append(shadowTable.Fields, TableField{
			Name:    "deleted_at",
			Type:    "datetime",
			Default: "CURRENT_TIMESTAMP",
			Null:    "NOT NULL",
			Comment: "删除时间",
		})
		return db.CheckTable(&shadowTable)
	}
	return nil
}
// tableFieldDesc describes one existing column as read back from the database by
// CheckTable. On MySQL, CheckTable decodes the output of `DESC <table>` straight into it
// via Query(...).To. The field names follow DESC's column names (presumably matched by
// name — confirm against the Query/To implementation). The SQLite and PostgreSQL paths
// fill it by hand.
type tableFieldDesc struct {
// Field is the column name.
Field string
// Type is the column type as reported by the database; CheckTable overwrites it with
// "numeric" on SQLite.
Type string
// Null is expected to be "YES" for nullable columns; CheckTable treats any other value
// as NOT NULL.
Null string
// Key is "PRI" for primary-key columns (set explicitly only on the SQLite path).
Key string
// Default is the column default as reported by the database.
Default string
// Extra holds extra column attributes; CheckTable looks for "on update CURRENT_TIMESTAMP".
Extra string
// After is the name of the preceding column (empty for the first). CheckTable computes
// it from the column order rather than reading it from the database, and does not set it on SQLite.
After string
}
// tableKeyDesc is one (index, column) pair of an existing index, as read back by
// CheckTable. On MySQL it is decoded from `SHOW INDEX FROM <table>` via Query(...).To. The
// underscored field names mirror SHOW INDEX's Key_name / Column_name columns (presumably
// matched by name — confirm against the Query/To implementation). The SQLite path fills it
// by hand. CheckTable joins the Column_name values of rows sharing a Key_name into a
// space-separated column list.
type tableKeyDesc struct {
Key_name string
Column_name string
}