indexDB/indexDB.go

372 lines
8.3 KiB
Go

package indexDB
import (
"fmt"
"log"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"github.com/blevesearch/bleve/v2"
"github.com/blevesearch/bleve/v2/search/query"
"apigo.cc/go/cast"
)
// SystemUserID is the reserved user ID "_system". ScanDocuments skips the
// matching "U-_system" field when it rebuilds a document's allow-list.
const SystemUserID = "_system"
// IndexDBUnauthorized is the handle returned by GetDB. It owns the optional
// fulltext engine and the optional vector store; call Auth to obtain an
// IndexDB bound to a user ID.
type IndexDBUnauthorized struct {
// fulltextPath is the fulltext location passed to GetDB ("" disables fulltext).
fulltextPath string
// vectorPath is the vector storage directory passed to GetDB.
vectorPath string
// embedding turns a text into a vector; when nil, GetDB opens no vector store.
embedding func(text string) ([]float32, error)
// logger is never nil after GetDB (it falls back to log.Default()).
logger *log.Logger
// fulltext is nil when GetDB was given an empty fulltextPath.
fulltext *Engine
// vector is nil when GetDB was given an empty vectorPath or a nil embedding.
vector *Store
// mu is write-locked by Close and read-locked by the IndexDB operations.
mu sync.RWMutex
}
// GetDB opens the index database.
//
// The fulltext engine is opened when fulltextPath is non-empty. The vector
// store is opened only when vectorPath is non-empty AND embedding is non-nil.
// A nil logger is replaced by log.Default().
//
// The parent directory of each path is created (mode 0755) when it is not
// ".". A failure to create it is reported instead of being ignored, and a
// fulltext engine that was already opened is closed again when the vector
// store cannot be initialised, so a failed GetDB leaves nothing open.
func GetDB(fulltextPath, vectorPath string, embedding func(string) ([]float32, error), logger *log.Logger) (*IndexDBUnauthorized, error) {
	if logger == nil {
		logger = log.Default()
	}
	db := &IndexDBUnauthorized{
		fulltextPath: fulltextPath,
		vectorPath:   vectorPath,
		embedding:    embedding,
		logger:       logger,
	}

	if fulltextPath != "" {
		if parent := filepath.Dir(fulltextPath); parent != "." {
			if err := os.MkdirAll(parent, 0755); err != nil {
				return nil, fmt.Errorf("init fulltext error: %w", err)
			}
		}
		fEngine, err := newFulltextEngine(fulltextPath)
		if err != nil {
			return nil, fmt.Errorf("init fulltext error: %w", err)
		}
		db.fulltext = fEngine
	}

	if vectorPath != "" && embedding != nil {
		// releaseFulltext undoes the fulltext initialisation above so that an
		// error return does not leak the opened engine.
		releaseFulltext := func() {
			if db.fulltext == nil {
				return
			}
			if cerr := db.fulltext.Close(); cerr != nil {
				logger.Printf("indexDB: closing fulltext engine after failed init: %v", cerr)
			}
			db.fulltext = nil
		}

		if parent := filepath.Dir(vectorPath); parent != "." {
			if err := os.MkdirAll(parent, 0755); err != nil {
				releaseFulltext()
				return nil, fmt.Errorf("init vector error: %w", err)
			}
		}
		vStore, err := newVectorStore(Config{StorageDir: vectorPath})
		if err != nil {
			releaseFulltext()
			return nil, fmt.Errorf("init vector error: %w", err)
		}
		db.vector = vStore
	}

	return db, nil
}
// Close takes the write lock and closes the fulltext engine, returning its
// error. It returns nil when no fulltext engine was opened.
// NOTE(review): the vector store is not closed here; confirm whether Store
// needs an explicit shutdown.
func (db *IndexDBUnauthorized) Close() error {
	db.mu.Lock()
	defer db.mu.Unlock()

	engine := db.fulltext
	if engine == nil {
		return nil
	}
	return engine.Close()
}
// Auth returns an IndexDB that operates on db and carries userId; Search
// forwards that ID to the fulltext and vector engines. No validation of
// userId happens here.
func (db *IndexDBUnauthorized) Auth(userId string) *IndexDB {
	scoped := new(IndexDB)
	scoped.db = db
	scoped.userId = userId
	return scoped
}
// IndexDB is a view on an IndexDBUnauthorized bound to one user ID; it is
// created by IndexDBUnauthorized.Auth.
type IndexDB struct {
// db is the shared database handle whose engines and lock are used.
db *IndexDBUnauthorized
// userId is passed to the fulltext and vector Search calls.
userId string
}
// Add indexes a document under id in every configured engine: the fulltext
// engine receives text, the vector collection "main" receives the embedding
// of text. metadata and allowUsers are forwarded to both.
//
// The embedding is computed BEFORE anything is written. The embedding
// callback can fail independently of the stores, and running it first means
// such a failure no longer leaves the document present in the fulltext index
// but missing from the vector store.
//
// Errors are wrapped with the failing stage ("embedding", "fulltext add",
// "vector collection", "vector add").
func (idx *IndexDB) Add(id string, text string, metadata map[string]any, allowUsers []string) error {
	idx.db.mu.RLock()
	defer idx.db.mu.RUnlock()

	useVector := idx.db.embedding != nil && idx.db.vector != nil

	var vec []float32
	if useVector {
		var err error
		vec, err = idx.db.embedding(text)
		if err != nil {
			return fmt.Errorf("embedding: %w", err)
		}
	}

	if idx.db.fulltext != nil {
		if err := idx.db.fulltext.AddDocument(id, text, metadata, allowUsers); err != nil {
			return fmt.Errorf("fulltext add: %w", err)
		}
	}

	if !useVector {
		return nil
	}
	coll, err := idx.db.vector.GetOrCreateCollection("main")
	if err != nil {
		return fmt.Errorf("vector collection: %w", err)
	}
	if err := coll.AddRecord(id, vec, metadata, allowUsers); err != nil {
		return fmt.Errorf("vector add: %w", err)
	}
	return nil
}
// Remove deletes the document id from the fulltext engine and, when a vector
// store is configured, from the vector collection "main".
//
// A fulltext failure is returned (wrapped as "fulltext remove"). The vector
// side stays best-effort: its failures never fail the call, but they are now
// written to the database logger instead of being dropped silently, because
// a record left behind keeps showing up in vector search results.
func (idx *IndexDB) Remove(id string) error {
	idx.db.mu.RLock()
	defer idx.db.mu.RUnlock()

	if idx.db.fulltext != nil {
		if err := idx.db.fulltext.RemoveDocument(id); err != nil {
			return fmt.Errorf("fulltext remove: %w", err)
		}
	}

	if idx.db.vector == nil {
		return nil
	}

	// warn reports a best-effort failure; the logger is nil-checked because
	// the struct can be built without going through GetDB inside the package.
	warn := func(stage string, err error) {
		if l := idx.db.logger; l != nil {
			l.Printf("indexDB: vector remove %q: %s: %v", id, stage, err)
		}
	}

	coll, err := idx.db.vector.GetOrCreateCollection("main")
	if err != nil {
		warn("open collection", err)
		return nil
	}
	if err := coll.DeleteRecord(id); err != nil {
		warn("delete record", err)
	}
	return nil
}
// Search runs the fulltext and the vector search concurrently and fuses the
// two ranked lists with reciprocal rank fusion (see mergeAndRRF), returning
// at most topK results.
//
//   - idPrefix, filter and the handle's user ID are forwarded to both engines.
//   - Each engine is asked for topK*5 candidates so the fusion has more than
//     topK entries per list to work with.
//   - The vector search is skipped when queryStr is empty, when no embedding
//     or vector store is configured, or when the embedding comes back empty.
//   - A non-positive topK yields (nil, nil) without querying anything; a
//     negative value previously caused a slice-bounds panic during merging.
//
// Results that have no Content after the merge are completed from the
// fulltext index by document ID. That lookup requests all stored fields
// ("*"): bleve loads requested fields by exact name, so asking for
// "metadata" never returned the "metadata.<key>" fields this code reads.
// The lookup is best-effort; its errors are logged and the result is
// returned unpatched.
func (idx *IndexDB) Search(idPrefix string, queryStr string, topK int, filter []Condition) ([]SearchResult, error) {
	if topK <= 0 {
		return nil, nil
	}

	idx.db.mu.RLock()
	defer idx.db.mu.RUnlock()

	var (
		wg                 sync.WaitGroup
		fResults, vResults []SearchResult
		fErr, vErr         error
	)

	if idx.db.fulltext != nil {
		wg.Add(1)
		go func() {
			defer wg.Done()
			fResults, fErr = idx.db.fulltext.Search(idPrefix, queryStr, topK*5, idx.userId, filter)
		}()
	}

	if idx.db.vector != nil && idx.db.embedding != nil && queryStr != "" {
		wg.Add(1)
		go func() {
			defer wg.Done()
			vec, err := idx.db.embedding(queryStr)
			if err != nil {
				vErr = fmt.Errorf("embed query: %w", err)
				return
			}
			if len(vec) == 0 {
				return
			}
			coll, err := idx.db.vector.GetOrCreateCollection("main")
			if err != nil {
				vErr = err
				return
			}
			vResults, vErr = coll.Search(vec, topK*5, idx.userId, idPrefix, filter)
		}()
	}

	// Each goroutine writes only its own result/error pair; Wait orders
	// those writes before the reads below.
	wg.Wait()

	if fErr != nil {
		return nil, fmt.Errorf("fulltext search: %w", fErr)
	}
	if vErr != nil {
		return nil, fmt.Errorf("vector search: %w", vErr)
	}

	merged := mergeAndRRF(fResults, vResults, topK)
	if idx.db.fulltext == nil {
		return merged, nil
	}

	// Patch content (and metadata, if absent) for vector-only results.
	for i := range merged {
		if merged[i].Content != "" {
			continue
		}
		req := bleve.NewSearchRequest(query.NewDocIDQuery([]string{merged[i].ID}))
		req.Fields = []string{"*"}
		res, err := idx.db.fulltext.index.Search(req)
		if err != nil {
			if l := idx.db.logger; l != nil {
				l.Printf("indexDB: loading content for %q: %v", merged[i].ID, err)
			}
			continue
		}
		if res == nil || len(res.Hits) == 0 {
			continue
		}

		fields := res.Hits[0].Fields
		merged[i].Content, _ = fields["content"].(string)
		if merged[i].Metadata != nil {
			continue
		}
		meta := make(map[string]any)
		for k, v := range fields {
			if strings.HasPrefix(k, "metadata.") {
				meta[strings.TrimPrefix(k, "metadata.")] = v
			}
		}
		merged[i].Metadata = meta
	}
	return merged, nil
}
// mergeAndRRF fuses the fulltext ranking fResults and the vector ranking
// vResults with reciprocal rank fusion: every list contributes
// 1/(60+rank) (rank starting at 1) to the score of each ID it contains.
//
// When an ID appears in both lists the fulltext entry is the one returned;
// its Score is replaced by the fused score. The result is sorted by score,
// highest first, with ties broken by ascending ID so the output is
// deterministic (equal scores are common: rank r in one list only always
// scores the same as rank r in the other list only). At most topK entries
// are returned; a negative topK is treated as 0 instead of panicking.
// With no input results the function returns nil.
func mergeAndRRF(fResults []SearchResult, vResults []SearchResult, topK int) []SearchResult {
	const k = 60

	capHint := len(fResults) + len(vResults)
	scores := make(map[string]float32, capHint)
	items := make(map[string]SearchResult, capHint)

	for i, r := range fResults {
		items[r.ID] = r
		scores[r.ID] += 1.0 / (k + float32(i+1))
	}
	for i, r := range vResults {
		// Keep the fulltext entry when both engines returned the ID.
		if _, seen := items[r.ID]; !seen {
			items[r.ID] = r
		}
		scores[r.ID] += 1.0 / (k + float32(i+1))
	}

	if len(scores) == 0 {
		return nil
	}

	merged := make([]SearchResult, 0, len(scores))
	for id, score := range scores {
		item := items[id]
		item.Score = score
		merged = append(merged, item)
	}

	// IDs are unique here, so (score desc, ID asc) is a total order and the
	// unstable sort cannot reorder anything between runs.
	sort.Slice(merged, func(i, j int) bool {
		if merged[i].Score != merged[j].Score {
			return merged[i].Score > merged[j].Score
		}
		return merged[i].ID < merged[j].ID
	})

	if topK < 0 {
		topK = 0
	}
	if len(merged) > topK {
		merged = merged[:topK]
	}
	return merged
}
// ScanDocuments pages through the fulltext index in ascending "id" order and
// rebuilds each stored document as a RawDocument.
//
// lastID is the cursor: "" starts from the beginning, otherwise only
// documents whose "id" field is strictly greater than lastID are returned.
// limit is used as the search request size. It returns an error when no
// fulltext engine is configured or when the search fails.
//
// For every hit, fields named "metadata.<key>" become Metadata[<key>], and
// fields named "U-<user>" whose value converts to "1" become AllowUsers
// entries, except the field for SystemUserID. AllowUsers is sorted so the
// output does not depend on map iteration order.
// NOTE(review): the handle's user ID is not consulted here, so every
// document is returned regardless of its allow-list; confirm that callers
// are trusted.
func (idx *IndexDB) ScanDocuments(lastID string, limit int) ([]RawDocument, error) {
	idx.db.mu.RLock()
	defer idx.db.mu.RUnlock()

	if idx.db.fulltext == nil {
		return nil, fmt.Errorf("fulltext engine not initialized")
	}

	const (
		metadataPrefix = "metadata."
		userPrefix     = "U-"
	)
	systemField := userPrefix + SystemUserID

	var q query.Query
	if lastID == "" {
		q = query.NewMatchAllQuery()
	} else {
		// Open-ended range (lastID, +inf) on the "id" field.
		rq := query.NewTermRangeQuery(lastID, "")
		exclusive := false
		rq.InclusiveMin = &exclusive
		rq.SetField("id")
		q = rq
	}

	req := bleve.NewSearchRequest(q)
	req.Size = limit
	req.SortBy([]string{"id"})
	req.Fields = []string{"*"}

	res, err := idx.db.fulltext.index.Search(req)
	if err != nil {
		return nil, fmt.Errorf("scan search error: %w", err)
	}

	var docs []RawDocument
	for _, hit := range res.Hits {
		// Fall back to the hit's document ID when no "id" field is stored.
		docID, _ := hit.Fields["id"].(string)
		if docID == "" {
			docID = hit.ID
		}
		content, _ := hit.Fields["content"].(string)

		meta := make(map[string]any)
		var allowUsers []string
		for k, v := range hit.Fields {
			switch {
			case strings.HasPrefix(k, metadataPrefix):
				meta[strings.TrimPrefix(k, metadataPrefix)] = v
			case strings.HasPrefix(k, userPrefix) && k != systemField:
				if cast.To[string](v) == "1" {
					allowUsers = append(allowUsers, strings.TrimPrefix(k, userPrefix))
				}
			}
		}
		sort.Strings(allowUsers)

		docs = append(docs, RawDocument{
			ID:         docID,
			Text:       content,
			Metadata:   meta,
			AllowUsers: allowUsers,
		})
	}
	return docs, nil
}
// evalCondition reports whether a record passes the ID prefix check and
// every condition in filters (conditions are AND-ed).
//
// A non-empty idPrefix must be a prefix of id. For each condition the field
// must exist in metadata; its value and the condition value are converted
// with cast.To to string and float64 and compared according to the operator:
//
//	eq, =, ==     string equality
//	gt, >         numeric greater than
//	lt, <         numeric less than
//	ge, >=        numeric greater or equal
//	le, <=        numeric less or equal
//	contains      substring match on the string forms
//	in            string form equals one element of a []any or []string
//	between       numeric value within [min, max], given as a two-element
//	              []any or []float64
//
// Anything that cannot be evaluated rejects the record: a missing field, an
// unknown operator, an "in" value of another type, and a "between" value
// that is not a two-element []any/[]float64. The last case used to be
// skipped silently, which made a malformed range match every record.
func evalCondition(metadata map[string]any, id string, idPrefix string, filters []Condition) bool {
	if idPrefix != "" && !strings.HasPrefix(id, idPrefix) {
		return false
	}

	for _, cond := range filters {
		val, ok := metadata[cond.Field]
		if !ok {
			return false
		}

		valStr := cast.To[string](val)
		condValStr := cast.To[string](cond.Value)
		valFloat := cast.To[float64](val)
		condValFloat := cast.To[float64](cond.Value)

		switch cond.Operator {
		case "eq", "=", "==":
			if valStr != condValStr {
				return false
			}
		case "gt", ">":
			if valFloat <= condValFloat {
				return false
			}
		case "lt", "<":
			if valFloat >= condValFloat {
				return false
			}
		case "ge", ">=":
			if valFloat < condValFloat {
				return false
			}
		case "le", "<=":
			if valFloat > condValFloat {
				return false
			}
		case "contains":
			if !strings.Contains(valStr, condValStr) {
				return false
			}
		case "in":
			found := false
			switch candidates := cond.Value.(type) {
			case []any:
				for _, item := range candidates {
					if cast.To[string](item) == valStr {
						found = true
						break
					}
				}
			case []string:
				for _, item := range candidates {
					if item == valStr {
						found = true
						break
					}
				}
			}
			if !found {
				return false
			}
		case "between":
			var lo, hi float64
			switch bounds := cond.Value.(type) {
			case []any:
				if len(bounds) != 2 {
					return false
				}
				lo, hi = cast.To[float64](bounds[0]), cast.To[float64](bounds[1])
			case []float64:
				if len(bounds) != 2 {
					return false
				}
				lo, hi = bounds[0], bounds[1]
			default:
				// Not a usable range: reject rather than ignore the filter.
				return false
			}
			if valFloat < lo || valFloat > hi {
				return false
			}
		default:
			return false
		}
	}
	return true
}