package openai

import (
	"bytes"
	"context"
	"encoding/binary"
	"strings"
	"time"

	"apigo.cc/ai"
	"github.com/sashabaranov/go-openai"
	"github.com/ssgo/log"
	"github.com/ssgo/u"
)

// getClient builds a go-openai client from the shared AI config, honoring a custom endpoint if one is set.
func getClient(aiConf *ai.AIConfig) *openai.Client {
	openaiConf := openai.DefaultConfig(aiConf.ApiKey)
	if aiConf.Endpoint != "" {
		openaiConf.BaseURL = aiConf.Endpoint
	}
	return openai.NewClientWithConfig(openaiConf)
}

// Chat sends the conversation to the chat-completion API. When callback is non-nil the request is
// streamed and callback receives each content delta; otherwise a single blocking request is made.
func Chat(aiConf *ai.AIConfig, messages []ai.ChatMessage, callback func(string), conf ai.ChatConfig) (ai.ChatResult, error) {
	// Convert the generic ai messages into go-openai messages.
	chatMessages := make([]openai.ChatCompletionMessage, len(messages))
	for i, msg := range messages {
		var contents []openai.ChatMessagePart
		if msg.Contents != nil {
			contents = make([]openai.ChatMessagePart, len(msg.Contents))
			for j, inPart := range msg.Contents {
				part := openai.ChatMessagePart{}
				switch inPart.Type {
				case ai.TypeText:
					part.Type = openai.ChatMessagePartTypeText
					part.Text = inPart.Content
				case ai.TypeImage:
					part.Type = openai.ChatMessagePartTypeImageURL
					part.ImageURL = &openai.ChatMessageImageURL{
						URL:    inPart.Content,
						Detail: openai.ImageURLDetailAuto,
					}
				default:
					part.Type = openai.ChatMessagePartType(inPart.Type)
					part.Text = inPart.Content
				}
				contents[j] = part
			}
		}
		if len(contents) == 1 && contents[0].Type == openai.ChatMessagePartTypeText {
			// A single text part can be sent as plain string content.
			chatMessages[i] = openai.ChatCompletionMessage{
				Role:    msg.Role,
				Content: contents[0].Text,
			}
		} else {
			chatMessages[i] = openai.ChatCompletionMessage{
				Role:         msg.Role,
				MultiContent: contents,
			}
		}
	}

	// Prepend the system prompt, if configured.
	if conf.SystemPrompt != "" {
		chatMessages = append([]openai.ChatCompletionMessage{{
			Role:    openai.ChatMessageRoleSystem,
			Content: conf.SystemPrompt,
		}}, chatMessages...)
	}

	opt := openai.ChatCompletionRequest{
		Model:       conf.Model,
		Messages:    chatMessages,
		MaxTokens:   conf.MaxTokens,
		Temperature: float32(conf.Temperature),
		TopP:        float32(conf.TopP),
	}

	for name, toolConf := range conf.Tools {
		switch name {
		case ai.ToolCodeInterpreter:
			opt.Tools = append(opt.Tools, openai.Tool{Type: "code_interpreter"})
		case ai.ToolFunction:
			funcDef := openai.FunctionDefinition{}
			u.Convert(toolConf, &funcDef)
			opt.Tools = append(opt.Tools, openai.Tool{Type: openai.ToolTypeFunction, Function: &funcDef})
		}
	}

	c := getClient(aiConf)
	if callback != nil {
		// Streaming request: stream_options is only accepted together with stream=true,
		// so both are set here rather than on the shared request above.
		opt.Stream = true
		opt.StreamOptions = &openai.StreamOptions{IncludeUsage: true}
		t1 := time.Now().UnixMilli()
		r, err := c.CreateChatCompletionStream(context.Background(), opt)
		if err == nil {
			results := make([]string, 0)
			out := ai.ChatResult{}
			for {
				if r2, err := r.Recv(); err == nil {
					if r2.Choices != nil {
						for _, ch := range r2.Choices {
							text := ch.Delta.Content
							callback(text)
							results = append(results, text)
						}
					}
					// With IncludeUsage set, the final chunk carries the token usage.
					if r2.Usage != nil {
						out.AskTokens += int64(r2.Usage.PromptTokens)
						out.AnswerTokens += int64(r2.Usage.CompletionTokens)
						out.TotalTokens += int64(r2.Usage.TotalTokens)
					}
				} else {
					// io.EOF marks the normal end of the stream; any other receive error also ends it.
					break
				}
			}
			_ = r.Close()
			out.Result = strings.Join(results, "")
			out.UsedTime = time.Now().UnixMilli() - t1
			return out, nil
		} else {
			log.DefaultLogger.Error(err.Error())
			return ai.ChatResult{}, err
		}
	} else {
		// Blocking request.
		t1 := time.Now().UnixMilli()
		if r, err := c.CreateChatCompletion(context.Background(), opt); err == nil {
			t2 := time.Now().UnixMilli() - t1
			results := make([]string, 0)
			if r.Choices != nil {
				for _, ch := range r.Choices {
					results = append(results, ch.Message.Content)
				}
			}
			return ai.ChatResult{
				Result:       strings.Join(results, ""),
				AskTokens:    int64(r.Usage.PromptTokens),
				AnswerTokens: int64(r.Usage.CompletionTokens),
				TotalTokens:  int64(r.Usage.TotalTokens),
				UsedTime:     t2,
			}, nil
		} else {
			return ai.ChatResult{}, err
		}
	}
}

// Embedding requests an embedding for text and returns the vector packed as little-endian float32 bytes.
func Embedding(aiConf *ai.AIConfig, text string, embeddingConf ai.EmbeddingConfig) (ai.EmbeddingResult, error) {
	c := getClient(aiConf)
	req := openai.EmbeddingRequest{
		Input: text,
		Model: openai.EmbeddingModel(embeddingConf.Model),
	}

	t1 := time.Now().UnixMilli()
	if r, err := c.CreateEmbeddings(context.Background(), req); err == nil {
		t2 := time.Now().UnixMilli() - t1
		// Pack every returned float32 component into a single little-endian byte buffer.
		buf := new(bytes.Buffer)
		if r.Data != nil {
			for _, ch := range r.Data {
				for _, v := range ch.Embedding {
					_ = binary.Write(buf, binary.LittleEndian, v)
				}
			}
		}
		return ai.EmbeddingResult{
			Result:       buf.Bytes(),
			AskTokens:    int64(r.Usage.PromptTokens),
			AnswerTokens: int64(r.Usage.CompletionTokens),
			TotalTokens:  int64(r.Usage.TotalTokens),
			UsedTime:     t2,
		}, nil
	} else {
		return ai.EmbeddingResult{}, err
	}
}
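
// The sketch below is illustrative only and not part of the original API: it shows how a caller
// could unpack the byte slice produced by Embedding, assuming each component was written as a
// little-endian float32 by binary.Write above. decodeEmbedding is a hypothetical helper
// introduced here for the example, e.g. vec, err := decodeEmbedding(result.Result).
func decodeEmbedding(b []byte) ([]float32, error) {
	// Four bytes per float32 component, read back in the same little-endian order used when packing.
	out := make([]float32, len(b)/4)
	if err := binary.Read(bytes.NewReader(b), binary.LittleEndian, out); err != nil {
		return nil, err
	}
	return out, nil
}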