indexDB/indexDB.go

421 lines
9.8 KiB
Go
Raw Normal View History

package indexDB
import (
"fmt"
"log"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"github.com/blevesearch/bleve/v2/search/query"
"github.com/blevesearch/bleve/v2"
"apigo.cc/go/cast"
)
// SystemUserID is the reserved "_system" user ID. RebuildAll skips the
// matching "U-_system" document field when it reconstructs each document's
// allowUsers list.
const (
	SystemUserID = "_system"
)
// IndexDBUnauthorized is the shared handle to an on-disk index database: a
// fulltext engine plus an optional vector store, both living under
// indexDBPath. It is not bound to any user; call Auth to obtain a per-user
// *IndexDB view.
type IndexDBUnauthorized struct {
// indexDBPath is the root directory holding the version files and the
// versioned index directories.
indexDBPath string
// embedding turns text into a vector; when nil no vector store is opened.
embedding func(text string) ([]float32, error)
// logger receives non-fatal rebuild diagnostics.
logger *log.Logger
// fulltext is the active fulltext engine; RebuildAll replaces it.
fulltext *Engine
// vector is the active vector store (nil without an embedding function);
// RebuildAll replaces it.
vector *Store
// mu guards fulltext/vector: Add, Remove and Search take the read lock,
// RebuildAll takes the write lock while it swaps engines.
mu sync.RWMutex
}
// GetDB opens the index database stored under indexDBPath, creating the
// directory and its version files on first use.
//
// embedding converts text to a vector; when it is nil no vector store is
// opened and only fulltext indexing/search is available. A nil logger falls
// back to log.Default(). The returned handle is not bound to a user; call
// Auth on it to get a per-user *IndexDB.
func GetDB(indexDBPath string, embedding func(string) ([]float32, error), logger *log.Logger) (*IndexDBUnauthorized, error) {
	db := &IndexDBUnauthorized{
		indexDBPath: indexDBPath,
		embedding:   embedding,
		logger:      logger,
	}
	if db.logger == nil {
		db.logger = log.Default()
	}
	if err := db.load(); err != nil {
		return nil, err
	}
	return db, nil
}
// load opens the active fulltext index and, when an embedding function is
// configured, the active vector store.
//
// The active versions come from "fulltextVersion.txt" and
// "vectorVersion.txt" under indexDBPath. A missing file is created with
// version "1", and an empty one is treated as "1". The indexes themselves
// live in "fulltextV<version>" and "vectorV<version>".
//
// Filesystem errors (creating the root directory, creating or reading a
// version file) are returned instead of being ignored. If opening the vector
// store fails, the already-opened fulltext engine is closed so no handle
// leaks.
func (db *IndexDBUnauthorized) load() error {
	if err := os.MkdirAll(db.indexDBPath, 0755); err != nil {
		return fmt.Errorf("create index dir: %w", err)
	}

	// activeVersion returns the trimmed contents of a version file, creating
	// the file with "1" when it does not exist yet.
	activeVersion := func(name string) (string, error) {
		path := filepath.Join(db.indexDBPath, name)
		raw, err := os.ReadFile(path)
		if os.IsNotExist(err) {
			if err := os.WriteFile(path, []byte("1"), 0644); err != nil {
				return "", fmt.Errorf("create %s: %w", name, err)
			}
			return "1", nil
		}
		if err != nil {
			return "", fmt.Errorf("read %s: %w", name, err)
		}
		if v := strings.TrimSpace(string(raw)); v != "" {
			return v, nil
		}
		return "1", nil
	}

	fVersion, err := activeVersion("fulltextVersion.txt")
	if err != nil {
		return err
	}
	vVersion, err := activeVersion("vectorVersion.txt")
	if err != nil {
		return err
	}

	fEngine, err := newFulltextEngine(filepath.Join(db.indexDBPath, "fulltextV"+fVersion))
	if err != nil {
		return fmt.Errorf("init fulltext error: %w", err)
	}
	var vStore *Store
	if db.embedding != nil {
		vStore, err = newVectorStore(Config{StorageDir: filepath.Join(db.indexDBPath, "vectorV"+vVersion)})
		if err != nil {
			// Do not leak the fulltext engine when the database fails to open.
			fEngine.Close()
			return fmt.Errorf("init vector error: %w", err)
		}
	}
	db.fulltext = fEngine
	db.vector = vStore
	return nil
}
// Auth returns a view of db bound to userId. The view shares db's indexes
// and lock; its Search passes userId down to the fulltext and vector
// searches. No validation of userId is performed here.
func (db *IndexDBUnauthorized) Auth(userId string) *IndexDB {
	view := IndexDB{db: db, userId: userId}
	return &view
}
// IndexDB is a user-scoped view of an IndexDBUnauthorized, created by Auth.
// Search forwards userId to both the fulltext engine and the vector
// collection (presumably to enforce per-document allowUsers lists — verify
// in Engine/Store).
type IndexDB struct {
// db is the shared, unscoped database this view operates on.
db *IndexDBUnauthorized
// userId is the user on whose behalf searches are executed.
userId string
}
// Add indexes (or re-indexes) a document under id in the fulltext index and,
// when an embedding function is configured, in the "main" vector collection.
//
// metadata is stored with the document for filtering. allowUsers is
// forwarded to both indexes; its access semantics live in the fulltext engine
// and vector store.
//
// The embedding is computed before anything is written and outside the
// database lock. An embedding failure therefore leaves both indexes
// untouched. A slow embedding call also cannot pin the read lock while
// RebuildAll waits for the write lock, which would stall every new reader
// queued behind that writer.
//
// The two writes are still not atomic: if the vector write fails after the
// fulltext write succeeded, the error is returned and the fulltext copy
// stays in place.
func (idx *IndexDB) Add(id string, text string, metadata map[string]any, allowUsers []string) error {
	db := idx.db

	// db.embedding is only assigned in GetDB, so reading it unlocked is safe.
	var vec []float32
	if db.embedding != nil {
		v, err := db.embedding(text)
		if err != nil {
			return fmt.Errorf("embedding: %w", err)
		}
		vec = v
	}

	db.mu.RLock()
	defer db.mu.RUnlock()

	if err := db.fulltext.AddDocument(id, text, metadata, allowUsers); err != nil {
		return fmt.Errorf("fulltext add: %w", err)
	}
	if db.embedding == nil || db.vector == nil {
		return nil
	}
	coll, err := db.vector.GetOrCreateCollection("main")
	if err != nil {
		return fmt.Errorf("vector collection: %w", err)
	}
	if err := coll.AddRecord(id, vec, metadata, allowUsers); err != nil {
		return fmt.Errorf("vector add: %w", err)
	}
	return nil
}
// Remove deletes document id from the fulltext index and then, on a
// best-effort basis, from the "main" vector collection when a vector store
// is configured. Only a fulltext failure is reported. Errors from opening
// the collection or deleting the vector record are deliberately discarded.
func (idx *IndexDB) Remove(id string) error {
	db := idx.db
	db.mu.RLock()
	defer db.mu.RUnlock()

	if err := db.fulltext.RemoveDocument(id); err != nil {
		return fmt.Errorf("fulltext remove: %w", err)
	}
	if db.vector == nil {
		return nil
	}
	// Best effort by design: vector-side failures are not surfaced.
	coll, err := db.vector.GetOrCreateCollection("main")
	if err != nil {
		return nil
	}
	_ = coll.DeleteRecord(id)
	return nil
}
// Search runs a hybrid query on behalf of the bound user. It fuses the
// fulltext and vector result lists with mergeAndRRF and returns at most topK
// results.
//
// The fulltext search always runs. The vector search runs concurrently with
// it only when a vector store and an embedding function are configured and
// queryStr is non-empty; an empty embedding skips it silently. Each side is
// asked for 5*topK candidates, and idPrefix and filter are forwarded to both.
// Either error aborts the search, and a fulltext error takes precedence over
// a vector error.
func (idx *IndexDB) Search(idPrefix string, queryStr string, topK int, filter []Condition) ([]SearchResult, error) {
	db := idx.db
	db.mu.RLock()
	defer db.mu.RUnlock()

	// Over-fetch so the rank fusion has enough candidates from each side.
	candidates := topK * 5

	var (
		wg       sync.WaitGroup
		vResults []SearchResult
		vErr     error
	)
	if queryStr != "" && db.embedding != nil && db.vector != nil {
		wg.Add(1)
		go func() {
			defer wg.Done()
			vec, err := db.embedding(queryStr)
			if err != nil {
				vErr = fmt.Errorf("embed query: %w", err)
				return
			}
			if len(vec) == 0 {
				return
			}
			coll, err := db.vector.GetOrCreateCollection("main")
			if err != nil {
				vErr = err
				return
			}
			vResults, vErr = coll.Search(vec, candidates, idx.userId, idPrefix, filter)
		}()
	}

	// The fulltext side runs on the calling goroutine while the vector side
	// (if any) proceeds in the background.
	fResults, fErr := db.fulltext.Search(idPrefix, queryStr, candidates, idx.userId, filter)
	wg.Wait()

	switch {
	case fErr != nil:
		return nil, fmt.Errorf("fulltext search: %w", fErr)
	case vErr != nil:
		return nil, fmt.Errorf("vector search: %w", vErr)
	}
	return mergeAndRRF(fResults, vResults, topK), nil
}
// mergeAndRRF fuses the fulltext and vector result lists with Reciprocal Rank
// Fusion and returns at most topK results, best first.
//
// Each list contributes 1/(k+rank) to a result's score, where rank is 1-based
// and k = 60. A result found by both searches therefore accumulates both
// contributions, and the returned Score is that fused score. When a result
// appears in both lists, the fulltext copy of its other fields is kept.
//
// Equal scores are common (e.g. rank 1 of one list vs rank 1 of the other).
// They are ordered by ID so the output is deterministic across calls; map
// iteration and sort.Slice are both unordered for equal keys. A topK <= 0
// yields no results instead of panicking on a negative slice bound.
func mergeAndRRF(fResults []SearchResult, vResults []SearchResult, topK int) []SearchResult {
	const k = 60
	scores := make(map[string]float32, len(fResults)+len(vResults))
	items := make(map[string]SearchResult, len(fResults)+len(vResults))

	for i, r := range fResults {
		items[r.ID] = r
		scores[r.ID] += 1.0 / (k + float32(i+1))
	}
	for i, r := range vResults {
		if _, seen := items[r.ID]; !seen {
			items[r.ID] = r
		}
		scores[r.ID] += 1.0 / (k + float32(i+1))
	}

	var merged []SearchResult
	for id, score := range scores {
		item := items[id]
		item.Score = score
		merged = append(merged, item)
	}
	sort.Slice(merged, func(i, j int) bool {
		if merged[i].Score != merged[j].Score {
			return merged[i].Score > merged[j].Score
		}
		return merged[i].ID < merged[j].ID
	})

	if topK < 0 {
		topK = 0
	}
	if len(merged) > topK {
		merged = merged[:topK]
	}
	return merged
}
// evalCondition reports whether a record passes the idPrefix check and every
// condition in filters (conditions are ANDed).
//
// A condition on a field absent from metadata fails. Values are converted
// with cast:
//   - "eq"/"="/"==" and "contains" compare string forms.
//   - "gt"/">", "lt"/"<", "ge"/">=" and "le"/"<=" compare float64 forms.
//   - "in" expects cond.Value to be a slice and matches when any element's
//     string form equals the field's string form.
//   - "between" expects a two-element slice [min, max] and is inclusive at
//     both ends.
//
// The slice types accepted by "in" and "between" are []any, []string, []int,
// []int64 and []float64, plus []float32 for "between". Any other Value shape
// for those operators, or an unknown operator, makes the condition fail. A
// malformed "between" therefore no longer silently matches every record,
// which makes it consistent with "in".
func evalCondition(metadata map[string]any, id string, idPrefix string, filters []Condition) bool {
	// listContains reports whether the slice-valued list holds an element
	// whose string form equals want.
	listContains := func(list any, want string) bool {
		switch s := list.(type) {
		case []any:
			for _, e := range s {
				if cast.To[string](e) == want {
					return true
				}
			}
		case []string:
			for _, e := range s {
				if e == want {
					return true
				}
			}
		case []int:
			for _, e := range s {
				if cast.To[string](e) == want {
					return true
				}
			}
		case []int64:
			for _, e := range s {
				if cast.To[string](e) == want {
					return true
				}
			}
		case []float64:
			for _, e := range s {
				if cast.To[string](e) == want {
					return true
				}
			}
		}
		return false
	}
	// rangeBounds extracts [min, max] from a two-element slice; ok is false
	// for any other shape.
	rangeBounds := func(v any) (lo, hi float64, ok bool) {
		switch s := v.(type) {
		case []any:
			if len(s) == 2 {
				return cast.To[float64](s[0]), cast.To[float64](s[1]), true
			}
		case []float64:
			if len(s) == 2 {
				return s[0], s[1], true
			}
		case []float32:
			if len(s) == 2 {
				return float64(s[0]), float64(s[1]), true
			}
		case []int:
			if len(s) == 2 {
				return float64(s[0]), float64(s[1]), true
			}
		case []int64:
			if len(s) == 2 {
				return float64(s[0]), float64(s[1]), true
			}
		case []string:
			if len(s) == 2 {
				return cast.To[float64](s[0]), cast.To[float64](s[1]), true
			}
		}
		return 0, 0, false
	}

	if idPrefix != "" && !strings.HasPrefix(id, idPrefix) {
		return false
	}
	for _, cond := range filters {
		val, ok := metadata[cond.Field]
		if !ok {
			return false
		}
		// Conversions are done per operator so each condition only pays for
		// the casts it actually needs.
		switch cond.Operator {
		case "eq", "=", "==":
			if cast.To[string](val) != cast.To[string](cond.Value) {
				return false
			}
		case "gt", ">":
			if cast.To[float64](val) <= cast.To[float64](cond.Value) {
				return false
			}
		case "lt", "<":
			if cast.To[float64](val) >= cast.To[float64](cond.Value) {
				return false
			}
		case "ge", ">=":
			if cast.To[float64](val) < cast.To[float64](cond.Value) {
				return false
			}
		case "le", "<=":
			if cast.To[float64](val) > cast.To[float64](cond.Value) {
				return false
			}
		case "contains":
			if !strings.Contains(cast.To[string](val), cast.To[string](cond.Value)) {
				return false
			}
		case "in":
			if !listContains(cond.Value, cast.To[string](val)) {
				return false
			}
		case "between":
			lo, hi, ok := rangeBounds(cond.Value)
			if !ok {
				return false
			}
			if v := cast.To[float64](val); v < lo || v > hi {
				return false
			}
		default:
			return false
		}
	}
	return true
}
// RebuildAll re-indexes every document into brand-new fulltext and vector
// indexes and switches the database over to them.
//
// Documents are read back from the active fulltext index via its stored
// "id", "content", "metadata.*" and "U-*" fields. They are re-added to a new
// fulltext engine in "fulltextV<n+1>". When an embedding function is
// configured, each document is also re-embedded into a new vector store in
// "vectorV<m+1>". Per-document add and embedding failures are logged and
// skipped.
//
// Commit order:
//  1. The version files are rewritten (temp file + rename) only after the
//     copy completed.
//  2. Only then are the in-memory engines swapped.
//  3. Only then are the old directories removed, in the background.
//
// If anything fails before the commit, the active indexes are left untouched
// and the half-built new ones are closed and deleted.
//
// The database write lock is held for the whole rebuild, so Add, Remove and
// Search block until it returns.
func (idx *IndexDB) RebuildAll() error {
	db := idx.db
	db.mu.Lock()
	defer db.mu.Unlock()

	// 1. Current versions. A missing, empty or unparsable file counts as 1.
	fvPath := filepath.Join(db.indexDBPath, "fulltextVersion.txt")
	vvPath := filepath.Join(db.indexDBPath, "vectorVersion.txt")
	readVersion := func(path string) (int, error) {
		raw, err := os.ReadFile(path)
		if err != nil && !os.IsNotExist(err) {
			return 0, err
		}
		v := cast.To[int](strings.TrimSpace(string(raw)))
		if v <= 0 {
			v = 1
		}
		return v, nil
	}
	fVersion, err := readVersion(fvPath)
	if err != nil {
		return fmt.Errorf("rebuild read fulltext version err: %w", err)
	}
	vVersion, err := readVersion(vvPath)
	if err != nil {
		return fmt.Errorf("rebuild read vector version err: %w", err)
	}
	versionDir := func(prefix string, v int) string {
		return filepath.Join(db.indexDBPath, prefix+cast.To[string](v))
	}
	newFVersion, newVVersion := fVersion+1, vVersion+1
	newFPath := versionDir("fulltextV", newFVersion)
	newVPath := versionDir("vectorV", newVVersion)

	// 2. Open fresh engines. Leftovers of an earlier interrupted rebuild are
	// wiped first; otherwise their stale documents would survive into the new
	// index.
	for _, p := range []string{newFPath, newVPath} {
		if err := os.RemoveAll(p); err != nil {
			return fmt.Errorf("rebuild clear stale dir err: %w", err)
		}
	}
	newFEngine, err := newFulltextEngine(newFPath)
	if err != nil {
		os.RemoveAll(newFPath)
		return fmt.Errorf("rebuild new fulltext err: %w", err)
	}
	// Until committed is set, every return discards the half-built indexes.
	// keepNewFulltext is set only when the on-disk fulltext pointer already
	// names the new directory and could not be restored.
	committed, keepNewFulltext := false, false
	defer func() {
		if committed {
			return
		}
		newFEngine.Close()
		if !keepNewFulltext {
			os.RemoveAll(newFPath)
		}
		// NOTE(review): the new vector store is not closed (no Close is used
		// elsewhere in this file); confirm Store holds no open handles.
		os.RemoveAll(newVPath)
	}()

	var newVStore *Store
	// addVector re-embeds one document into the new vector store. It stays
	// nil when no embedding function is configured.
	var addVector func(id, content string, metadata map[string]any, allowUsers []string)
	if db.embedding != nil {
		newVStore, err = newVectorStore(Config{StorageDir: newVPath})
		if err != nil {
			return fmt.Errorf("rebuild new vector err: %w", err)
		}
		// Resolve the collection once up front. A failure aborts the rebuild
		// instead of calling AddRecord on a nil collection for every document.
		coll, err := newVStore.GetOrCreateCollection("main")
		if err != nil {
			return fmt.Errorf("rebuild vector collection err: %w", err)
		}
		addVector = func(id, content string, metadata map[string]any, allowUsers []string) {
			vec, err := db.embedding(content)
			if err != nil {
				db.logger.Println("Rebuild vector embed err:", err)
				return
			}
			if err := coll.AddRecord(id, vec, metadata, allowUsers); err != nil {
				db.logger.Println("Rebuild vector add err:", err)
			}
		}
	}

	// 3. Copy every stored document out of the active fulltext index.
	// Sorting by _id gives From/Size paging a stable order, so no document is
	// skipped or copied twice between pages.
	systemField := "U-" + SystemUserID
	req := bleve.NewSearchRequest(query.NewMatchAllQuery())
	req.Fields = []string{"*"}
	req.Size = 1000
	req.SortBy([]string{"_id"})
	for from := 0; ; {
		req.From = from
		res, err := db.fulltext.index.Search(req)
		if err != nil {
			return fmt.Errorf("rebuild search fulltext err: %w", err)
		}
		if len(res.Hits) == 0 {
			break
		}
		for _, hit := range res.Hits {
			docID, _ := hit.Fields["id"].(string)
			if docID == "" {
				docID = hit.ID
			}
			content, _ := hit.Fields["content"].(string)
			// Rebuild metadata from "metadata.<key>" fields and allowUsers from
			// "U-<user>" flags set to "1". The system flag is excluded.
			// NOTE(review): nested metadata presumably comes back flattened
			// ("a.b") — confirm against the fulltext engine's mapping.
			metadata := make(map[string]any)
			allowUsers := []string{}
			for field, v := range hit.Fields {
				switch {
				case strings.HasPrefix(field, "metadata."):
					metadata[strings.TrimPrefix(field, "metadata.")] = v
				case strings.HasPrefix(field, "U-") && field != systemField:
					if cast.To[string](v) == "1" {
						allowUsers = append(allowUsers, strings.TrimPrefix(field, "U-"))
					}
				}
			}
			if err := newFEngine.AddDocument(docID, content, metadata, allowUsers); err != nil {
				db.logger.Println("Rebuild fulltext add err:", err)
			}
			if addVector != nil {
				addVector(docID, content, metadata, allowUsers)
			}
		}
		from += len(res.Hits)
		if from >= int(res.Total) {
			break
		}
	}

	// 4. Commit. The version files are updated before anything old is
	// touched, so a restart always finds an existing index directory.
	// writeVersion goes through a temp file + rename so a crash mid-write
	// cannot leave an empty file (which would read back as version 1).
	writeVersion := func(path string, v int) error {
		tmp := path + ".tmp"
		if err := os.WriteFile(tmp, []byte(cast.To[string](v)), 0644); err != nil {
			return err
		}
		return os.Rename(tmp, path)
	}
	if err := writeVersion(fvPath, newFVersion); err != nil {
		return fmt.Errorf("rebuild write fulltext version err: %w", err)
	}
	if err := writeVersion(vvPath, newVVersion); err != nil {
		// Point the fulltext version back at the old index so disk state
		// stays consistent with the engines that remain active.
		if rbErr := writeVersion(fvPath, fVersion); rbErr != nil {
			// The pointer still names the new directory; keep it so a restart
			// does not open an empty index.
			keepNewFulltext = true
			return fmt.Errorf("rebuild write vector version err: %w (restoring fulltext version also failed: %v)", err, rbErr)
		}
		return fmt.Errorf("rebuild write vector version err: %w", err)
	}

	oldFEngine, oldVStore := db.fulltext, db.vector
	db.fulltext, db.vector = newFEngine, newVStore
	committed = true

	// No reader can still hold the old engines (the write lock is held until
	// return), so they are closed and deleted without blocking the caller.
	oldFPath := versionDir("fulltextV", fVersion)
	oldVPath := versionDir("vectorV", vVersion)
	logger := db.logger
	go func() {
		oldFEngine.Close()
		if err := os.RemoveAll(oldFPath); err != nil {
			logger.Println("Rebuild remove old fulltext err:", err)
		}
		// NOTE(review): the old vector store is not closed before its
		// directory is removed; confirm Store holds no open handles.
		if oldVStore != nil {
			if err := os.RemoveAll(oldVPath); err != nil {
				logger.Println("Rebuild remove old vector err:", err)
			}
		}
	}()
	return nil
}