131 lines
2.7 KiB
Go
131 lines
2.7 KiB
Go
|
|
package indexDB
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"encoding/json"
|
||
|
|
"fmt"
|
||
|
|
"path/filepath"
|
||
|
|
"runtime"
|
||
|
|
|
||
|
|
"github.com/philippgille/chromem-go"
|
||
|
|
)
|
||
|
|
|
||
|
|
type Store struct {
|
||
|
|
db *chromem.DB
|
||
|
|
}
|
||
|
|
|
||
|
|
type Config struct {
|
||
|
|
StorageDir string
|
||
|
|
}
|
||
|
|
|
||
|
|
func newVectorStore(cfg Config) (*Store, error) {
|
||
|
|
var db *chromem.DB
|
||
|
|
var err error
|
||
|
|
|
||
|
|
if cfg.StorageDir == "" {
|
||
|
|
db = chromem.NewDB()
|
||
|
|
} else {
|
||
|
|
path := filepath.Clean(cfg.StorageDir)
|
||
|
|
db, err = chromem.NewPersistentDB(path, false)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to open persistent vector DB: %w", err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
return &Store{db: db}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
type Collection struct {
|
||
|
|
coll *chromem.Collection
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *Store) GetOrCreateCollection(name string) (*Collection, error) {
|
||
|
|
coll := s.db.GetCollection(name, nil)
|
||
|
|
if coll == nil {
|
||
|
|
var err error
|
||
|
|
coll, err = s.db.CreateCollection(name, nil, nil)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to create collection: %w", err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return &Collection{coll: coll}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *Collection) AddRecord(id string, vector []float32, metadata map[string]any, allowUsers []string) error {
|
||
|
|
rawJson, _ := json.Marshal(metadata)
|
||
|
|
metaStr := map[string]string{
|
||
|
|
"raw_json": string(rawJson),
|
||
|
|
}
|
||
|
|
for _, u := range allowUsers {
|
||
|
|
metaStr["U-"+u] = "1"
|
||
|
|
}
|
||
|
|
|
||
|
|
doc := chromem.Document{
|
||
|
|
ID: id,
|
||
|
|
Embedding: vector,
|
||
|
|
Metadata: metaStr,
|
||
|
|
}
|
||
|
|
|
||
|
|
err := c.coll.AddDocuments(context.Background(), []chromem.Document{doc}, runtime.NumCPU())
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("failed to add records: %w", err)
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *Collection) Search(queryVector []float32, topK int, userID string, idPrefix string, filter []Condition) ([]SearchResult, error) {
|
||
|
|
var vFilter map[string]string
|
||
|
|
if userID != "" && userID != SystemUserID {
|
||
|
|
vFilter = map[string]string{
|
||
|
|
"U-" + userID: "1",
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
nResults := topK * 10
|
||
|
|
count := c.coll.Count()
|
||
|
|
if nResults > count {
|
||
|
|
nResults = count
|
||
|
|
}
|
||
|
|
|
||
|
|
if nResults == 0 {
|
||
|
|
return nil, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// We query more and filter in memory since Chromem-go only supports exact map filter.
|
||
|
|
res, err := c.coll.QueryEmbedding(context.Background(), queryVector, nResults, vFilter, nil)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("vector search failed: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
var results []SearchResult
|
||
|
|
for _, r := range res {
|
||
|
|
var metaIfc map[string]any
|
||
|
|
if raw, ok := r.Metadata["raw_json"]; ok {
|
||
|
|
_ = json.Unmarshal([]byte(raw), &metaIfc)
|
||
|
|
} else {
|
||
|
|
metaIfc = make(map[string]any)
|
||
|
|
}
|
||
|
|
|
||
|
|
if !evalCondition(metaIfc, r.ID, idPrefix, filter) {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
|
||
|
|
results = append(results, SearchResult{
|
||
|
|
ID: r.ID,
|
||
|
|
Score: r.Similarity,
|
||
|
|
Content: r.Content,
|
||
|
|
Metadata: metaIfc,
|
||
|
|
})
|
||
|
|
|
||
|
|
if len(results) >= topK {
|
||
|
|
break
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
return results, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *Collection) DeleteRecord(id string) error {
|
||
|
|
return c.coll.Delete(context.Background(), nil, nil, id)
|
||
|
|
}
|