openai/chat.go
2024-10-31 15:27:22 +08:00

179 lines
4.8 KiB
Go

package openai
import (
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"io"
	"strings"
	"time"

	"apigo.cc/ai"
	"github.com/sashabaranov/go-openai"
	"github.com/ssgo/log"
	"github.com/ssgo/u"
)
// getClient builds a go-openai client from aiConf. The API key always comes
// from aiConf.ApiKey; the library's default base URL is kept unless
// aiConf.Endpoint is non-empty, in which case the endpoint replaces it.
func getClient(aiConf *ai.AIConfig) *openai.Client {
	cfg := openai.DefaultConfig(aiConf.ApiKey)
	if endpoint := aiConf.Endpoint; endpoint != "" {
		cfg.BaseURL = endpoint
	}
	return openai.NewClientWithConfig(cfg)
}
// Chat sends messages to the chat completion endpoint configured by aiConf
// and returns the concatenated answer text together with token usage and the
// elapsed time in milliseconds.
//
// When conf.SystemPrompt is non-empty it is prepended as a system message.
// Entries of conf.Tools keyed by ai.ToolCodeInterpreter or ai.ToolFunction are
// forwarded as tools; other keys are ignored.
//
// If callback is nil a single blocking request is made. Otherwise the request
// is streamed and callback is invoked with the content delta of every choice
// in every chunk (the delta may be an empty string). If the stream fails
// before completing, the text and usage collected so far are returned
// together with the error.
func Chat(aiConf *ai.AIConfig, messages []ai.ChatMessage, callback func(string), conf ai.ChatConfig) (ai.ChatResult, error) {
	chatMessages := buildChatMessages(messages)
	if conf.SystemPrompt != "" {
		chatMessages = append([]openai.ChatCompletionMessage{{
			Role:    openai.ChatMessageRoleSystem,
			Content: conf.SystemPrompt,
		}}, chatMessages...)
	}

	opt := openai.ChatCompletionRequest{
		Model:       conf.Model,
		Messages:    chatMessages,
		MaxTokens:   conf.MaxTokens,
		Temperature: float32(conf.Temperature),
		TopP:        float32(conf.TopP),
		Tools:       buildChatTools(conf),
	}

	c := getClient(aiConf)
	start := time.Now()

	// Blocking mode: one request, one response.
	if callback == nil {
		r, err := c.CreateChatCompletion(context.Background(), opt)
		if err != nil {
			return ai.ChatResult{}, err
		}
		usedTime := time.Since(start).Milliseconds()
		var answer strings.Builder
		for _, ch := range r.Choices {
			answer.WriteString(ch.Message.Content)
		}
		return ai.ChatResult{
			Result:       answer.String(),
			AskTokens:    int64(r.Usage.PromptTokens),
			AnswerTokens: int64(r.Usage.CompletionTokens),
			TotalTokens:  int64(r.Usage.TotalTokens),
			UsedTime:     usedTime,
		}, nil
	}

	// Streaming mode. stream_options is only set here because it is only
	// meaningful (and only accepted by the API) for streaming requests.
	opt.Stream = true
	opt.StreamOptions = &openai.StreamOptions{IncludeUsage: true}
	r, err := c.CreateChatCompletionStream(context.Background(), opt)
	if err != nil {
		log.DefaultLogger.Error(err.Error())
		return ai.ChatResult{}, err
	}
	// Closing is best-effort cleanup; its error carries no useful information
	// once the stream has ended.
	defer func() { _ = r.Close() }()

	out := ai.ChatResult{}
	var answer strings.Builder
	for {
		chunk, err := r.Recv()
		if err != nil {
			out.Result = answer.String()
			out.UsedTime = time.Since(start).Milliseconds()
			// io.EOF marks the normal end of the stream; anything else is a
			// real failure that must reach the caller.
			if errors.Is(err, io.EOF) {
				return out, nil
			}
			return out, err
		}
		for _, ch := range chunk.Choices {
			text := ch.Delta.Content
			callback(text)
			answer.WriteString(text)
		}
		if chunk.Usage != nil {
			out.AskTokens += int64(chunk.Usage.PromptTokens)
			out.AnswerTokens += int64(chunk.Usage.CompletionTokens)
			out.TotalTokens += int64(chunk.Usage.TotalTokens)
		}
	}
}

// buildChatMessages converts ai.ChatMessage values into go-openai messages.
// A message consisting of exactly one text part is sent as plain Content;
// every other message is sent as MultiContent.
func buildChatMessages(messages []ai.ChatMessage) []openai.ChatCompletionMessage {
	chatMessages := make([]openai.ChatCompletionMessage, len(messages))
	for i, msg := range messages {
		var contents []openai.ChatMessagePart
		if msg.Contents != nil {
			contents = make([]openai.ChatMessagePart, len(msg.Contents))
			for j, inPart := range msg.Contents {
				part := openai.ChatMessagePart{}
				switch inPart.Type {
				case ai.TypeText:
					part.Type = openai.ChatMessagePartTypeText
					part.Text = inPart.Content
				case ai.TypeImage:
					part.Type = openai.ChatMessagePartTypeImageURL
					part.ImageURL = &openai.ChatMessageImageURL{
						URL:    inPart.Content,
						Detail: openai.ImageURLDetailAuto,
					}
				default:
					// Unknown part types are passed through verbatim with the
					// content carried in the text field.
					part.Type = openai.ChatMessagePartType(inPart.Type)
					part.Text = inPart.Content
				}
				contents[j] = part
			}
		}
		if len(contents) == 1 && contents[0].Type == ai.TypeText {
			chatMessages[i] = openai.ChatCompletionMessage{
				Role:    msg.Role,
				Content: contents[0].Text,
			}
		} else {
			chatMessages[i] = openai.ChatCompletionMessage{
				Role:         msg.Role,
				MultiContent: contents,
			}
		}
	}
	return chatMessages
}

// buildChatTools translates the entries of conf.Tools into go-openai tool
// definitions. It returns nil when no recognised tool is configured.
func buildChatTools(conf ai.ChatConfig) []openai.Tool {
	var tools []openai.Tool
	for name, toolConf := range conf.Tools {
		switch name {
		case ai.ToolCodeInterpreter:
			tools = append(tools, openai.Tool{Type: "code_interpreter"})
		case ai.ToolFunction:
			fn := openai.FunctionDefinition{}
			u.Convert(toolConf, &fn)
			tools = append(tools, openai.Tool{Type: openai.ToolTypeFunction, Function: &fn})
		}
	}
	return tools
}
// Embedding requests an embedding for text using the model named in
// embeddingConf.Model. The returned Result holds every float32 component of
// every returned vector, serialised back to back in little-endian byte order.
// Token usage and the request duration in milliseconds are reported alongside.
func Embedding(aiConf *ai.AIConfig, text string, embeddingConf ai.EmbeddingConfig) (ai.EmbeddingResult, error) {
	client := getClient(aiConf)
	request := openai.EmbeddingRequest{
		Input: text,
		Model: openai.EmbeddingModel(embeddingConf.Model),
	}

	startedAt := time.Now().UnixMilli()
	resp, err := client.CreateEmbeddings(context.Background(), request)
	if err != nil {
		return ai.EmbeddingResult{}, err
	}
	elapsed := time.Now().UnixMilli() - startedAt

	var encoded bytes.Buffer
	for _, item := range resp.Data {
		// Writing the whole vector at once yields the same bytes as writing
		// each component individually; writes to a bytes.Buffer do not fail.
		_ = binary.Write(&encoded, binary.LittleEndian, item.Embedding)
	}

	result := ai.EmbeddingResult{
		Result:       encoded.Bytes(),
		AskTokens:    int64(resp.Usage.PromptTokens),
		AnswerTokens: int64(resp.Usage.CompletionTokens),
		TotalTokens:  int64(resp.Usage.TotalTokens),
		UsedTime:     elapsed,
	}
	return result, nil
}