indexDB/vector.go

137 lines
2.9 KiB
Go
Raw Permalink Normal View History

package indexDB
import (
"context"
"fmt"
"path/filepath"
"runtime"
"github.com/philippgille/chromem-go"
"apigo.cc/go/cast"
)
// Store wraps the chromem-go database that holds this package's vector
// collections. Values are constructed by newVectorStore.
type Store struct {
	// db is either an in-memory database (chromem.NewDB) or a persistent
	// one (chromem.NewPersistentDB), depending on the Config passed to
	// newVectorStore.
	db *chromem.DB
}
// Config controls how newVectorStore opens the underlying vector database.
type Config struct {
	// StorageDir is the directory backing a persistent database. When it is
	// empty, newVectorStore creates an in-memory database (chromem.NewDB)
	// instead.
	StorageDir string
}
// newVectorStore opens the vector database described by cfg.
//
// An empty cfg.StorageDir produces an in-memory database. Otherwise the
// cleaned directory path backs a persistent database; the `false` argument
// is chromem-go's compress flag, so documents are stored uncompressed.
// If the persistent database cannot be opened, the wrapped error is
// returned together with a nil *Store.
func newVectorStore(cfg Config) (*Store, error) {
	if cfg.StorageDir == "" {
		return &Store{db: chromem.NewDB()}, nil
	}

	persistent, openErr := chromem.NewPersistentDB(filepath.Clean(cfg.StorageDir), false)
	if openErr != nil {
		return nil, fmt.Errorf("failed to open persistent vector DB: %w", openErr)
	}
	return &Store{db: persistent}, nil
}
// Collection is a thin wrapper around a single chromem-go collection,
// obtained from Store.GetOrCreateCollection.
type Collection struct {
	coll *chromem.Collection
}
// GetOrCreateCollection returns the collection called name, creating it
// (with no metadata) when the store does not have one yet.
//
// Both the lookup and the creation pass a nil embedding function. This file
// only ever supplies precomputed vectors: AddRecord sets Embedding, and
// Search calls QueryEmbedding.
//
// NOTE(review): the lookup and the creation are separate calls, so two
// concurrent callers may both try to create the same collection — confirm
// whether callers serialize this.
func (s *Store) GetOrCreateCollection(name string) (*Collection, error) {
	if existing := s.db.GetCollection(name, nil); existing != nil {
		return &Collection{coll: existing}, nil
	}

	created, createErr := s.db.CreateCollection(name, nil, nil)
	if createErr != nil {
		return nil, fmt.Errorf("failed to create collection: %w", createErr)
	}
	return &Collection{coll: created}, nil
}
// AddRecord stores one precomputed embedding under id.
//
// metadata is serialized to JSON and kept under the "raw_json" metadata key,
// which Search decodes again. Each entry of allowUsers is recorded as a
// "U-<user>" = "1" metadata key. Search uses that key as an exact-match
// visibility filter for users other than SystemUserID.
//
// It returns an error if metadata cannot be encoded or if the document
// cannot be added to the collection.
func (c *Collection) AddRecord(id string, vector []float32, metadata map[string]any, allowUsers []string) error {
	rawJSON, err := cast.ToJSON(metadata)
	if err != nil {
		// Surface this rather than storing whatever the encoder returned on
		// failure: a record with missing raw_json silently loses its
		// metadata, and Search filters can never match it.
		return fmt.Errorf("failed to encode record metadata: %w", err)
	}

	// One slot for raw_json plus one per allowed user.
	metaStr := make(map[string]string, len(allowUsers)+1)
	metaStr["raw_json"] = rawJSON
	for _, u := range allowUsers {
		metaStr["U-"+u] = "1"
	}

	doc := chromem.Document{
		ID:        id,
		Embedding: vector,
		Metadata:  metaStr,
	}
	if err := c.coll.AddDocuments(context.Background(), []chromem.Document{doc}, runtime.NumCPU()); err != nil {
		return fmt.Errorf("failed to add records: %w", err)
	}
	return nil
}
// Search returns up to topK records whose embeddings are most similar to
// queryVector, in the order returned by QueryEmbedding (presumably
// descending similarity).
//
// Visibility: when userID is non-empty and not SystemUserID, only records
// stored with that user in allowUsers are considered. Those are records
// whose metadata holds "U-<userID>" = "1". idPrefix and filter are then
// applied in memory via evalCondition against each record's decoded
// "raw_json" metadata.
//
// chromem-go only supports exact-match metadata filters, so the query
// over-fetches candidates. It asks for topK*10 normally, capped at the
// collection size. When idPrefix or filter is set, it asks for every
// document, so that in-memory filtering does not exhaust a too-small
// candidate set.
//
// A non-positive topK or an empty collection yields nil results and no
// error.
func (c *Collection) Search(queryVector []float32, topK int, userID string, idPrefix string, filter []Condition) ([]SearchResult, error) {
	// Without this guard, a non-positive topK misbehaves in two ways.
	// Without filters it asks chromem-go for <= 0 results, which errors.
	// With filters it still returns one hit, because the topK check below
	// only runs after an append.
	if topK <= 0 {
		return nil, nil
	}

	var vFilter map[string]string
	if userID != "" && userID != SystemUserID {
		vFilter = map[string]string{
			"U-" + userID: "1",
		}
	}

	// NOTE(review): documents deleted between Count and QueryEmbedding could
	// make nResults exceed the collection size — confirm whether concurrent
	// deletes are possible here.
	count := c.coll.Count()
	if count == 0 {
		return nil, nil
	}

	// Fetch everything when in-memory filters apply. Otherwise fetch
	// topK*10, capped at count. Comparing against count/10 avoids computing
	// topK*10 (and overflowing int) for very large topK.
	nResults := count
	if len(filter) == 0 && idPrefix == "" && topK <= count/10 {
		nResults = topK * 10
	}

	res, err := c.coll.QueryEmbedding(context.Background(), queryVector, nResults, vFilter, nil)
	if err != nil {
		return nil, fmt.Errorf("vector search failed: %w", err)
	}

	var results []SearchResult
	for _, r := range res {
		var metaIfc map[string]any
		if raw, ok := r.Metadata["raw_json"]; ok {
			// NOTE(review): cast.As drops any decode error, so a corrupt
			// raw_json presumably yields a nil map here — confirm intended.
			metaIfc = cast.As(cast.FromJSON[map[string]any](raw))
		} else {
			metaIfc = make(map[string]any)
		}
		if !evalCondition(metaIfc, r.ID, idPrefix, filter) {
			continue
		}
		results = append(results, SearchResult{
			ID:       r.ID,
			Score:    r.Similarity,
			Content:  r.Content,
			Metadata: metaIfc,
		})
		if len(results) >= topK {
			break
		}
	}
	return results, nil
}
// DeleteRecord removes the record with the given id from the collection.
//
// Errors from chromem-go are wrapped with %w, consistent with AddRecord and
// Search, so callers keep errors.Is/errors.As access to the cause.
func (c *Collection) DeleteRecord(id string) error {
	if err := c.coll.Delete(context.Background(), nil, nil, id); err != nil {
		return fmt.Errorf("failed to delete record: %w", err)
	}
	return nil
}