openai/chat.go

179 lines
4.8 KiB
Go
Raw Normal View History

2024-09-07 23:13:36 +08:00
package openai
import (
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"io"
	"strings"
	"time"

	"apigo.cc/ai"
	"github.com/sashabaranov/go-openai"
	"github.com/ssgo/log"
	"github.com/ssgo/u"
)
2024-10-31 15:27:22 +08:00
func getClient(aiConf *ai.AIConfig) *openai.Client {
openaiConf := openai.DefaultConfig(aiConf.ApiKey)
if aiConf.Endpoint != "" {
openaiConf.BaseURL = aiConf.Endpoint
2024-09-07 23:13:36 +08:00
}
2024-10-31 15:27:22 +08:00
return openai.NewClientWithConfig(openaiConf)
}
2024-09-07 23:13:36 +08:00
2024-10-31 15:27:22 +08:00
// func (lm *LLM) Ask(messages []ai.ChatMessage, config ai.ChatConfig, callback func(answer string)) (string, ai.Usage, error) {
func Chat(aiConf *ai.AIConfig, messages []ai.ChatMessage, callback func(string), conf ai.ChatConfig) (ai.ChatResult, error) {
chatMessages := make([]openai.ChatCompletionMessage, len(messages))
2024-09-07 23:13:36 +08:00
for i, msg := range messages {
var contents []openai.ChatMessagePart
if msg.Contents != nil {
contents = make([]openai.ChatMessagePart, len(msg.Contents))
for j, inPart := range msg.Contents {
part := openai.ChatMessagePart{}
switch inPart.Type {
2024-10-31 15:27:22 +08:00
case ai.TypeText:
part.Type = openai.ChatMessagePartTypeText
2024-09-07 23:13:36 +08:00
part.Text = inPart.Content
2024-10-31 15:27:22 +08:00
case ai.TypeImage:
part.Type = openai.ChatMessagePartTypeImageURL
2024-09-07 23:13:36 +08:00
part.ImageURL = &openai.ChatMessageImageURL{
URL: inPart.Content,
Detail: openai.ImageURLDetailAuto,
}
2024-10-31 15:27:22 +08:00
default:
part.Type = openai.ChatMessagePartType(inPart.Type)
part.Text = inPart.Content
2024-09-07 23:13:36 +08:00
}
contents[j] = part
}
}
2024-10-31 15:27:22 +08:00
if len(contents) == 1 && contents[0].Type == ai.TypeText {
chatMessages[i] = openai.ChatCompletionMessage{
Role: msg.Role,
2024-10-25 16:04:36 +08:00
Content: contents[0].Text,
}
} else {
2024-10-31 15:27:22 +08:00
chatMessages[i] = openai.ChatCompletionMessage{
Role: msg.Role,
2024-10-25 16:04:36 +08:00
MultiContent: contents,
}
2024-09-07 23:13:36 +08:00
}
}
2024-10-31 15:27:22 +08:00
if conf.SystemPrompt != "" {
chatMessages = append([]openai.ChatCompletionMessage{{
Role: openai.ChatMessageRoleSystem,
Content: conf.SystemPrompt,
}}, chatMessages...)
}
2024-09-07 23:13:36 +08:00
opt := openai.ChatCompletionRequest{
2024-10-31 15:27:22 +08:00
Model: conf.Model,
Messages: chatMessages,
MaxTokens: conf.MaxTokens,
Temperature: float32(conf.Temperature),
TopP: float32(conf.TopP),
2024-09-07 23:13:36 +08:00
StreamOptions: &openai.StreamOptions{
IncludeUsage: true,
},
}
2024-10-31 15:27:22 +08:00
for name, toolConf := range conf.Tools {
2024-09-07 23:13:36 +08:00
switch name {
2024-10-31 15:27:22 +08:00
case ai.ToolCodeInterpreter:
2024-09-07 23:13:36 +08:00
opt.Tools = append(opt.Tools, openai.Tool{Type: "code_interpreter"})
2024-10-31 15:27:22 +08:00
case ai.ToolFunction:
conf := openai.FunctionDefinition{}
u.Convert(toolConf, &conf)
opt.Tools = append(opt.Tools, openai.Tool{Type: openai.ToolTypeFunction, Function: &conf})
2024-09-07 23:13:36 +08:00
}
}
2024-10-31 15:27:22 +08:00
c := getClient(aiConf)
2024-09-07 23:13:36 +08:00
if callback != nil {
opt.Stream = true
r, err := c.CreateChatCompletionStream(context.Background(), opt)
if err == nil {
results := make([]string, 0)
2024-10-31 15:27:22 +08:00
out := ai.ChatResult{}
2024-09-07 23:13:36 +08:00
for {
if r2, err := r.Recv(); err == nil {
if r2.Choices != nil {
for _, ch := range r2.Choices {
text := ch.Delta.Content
callback(text)
results = append(results, text)
}
}
if r2.Usage != nil {
2024-10-31 15:27:22 +08:00
out.AskTokens += int64(r2.Usage.PromptTokens)
out.AnswerTokens += int64(r2.Usage.CompletionTokens)
out.TotalTokens += int64(r2.Usage.TotalTokens)
2024-09-07 23:13:36 +08:00
}
} else {
break
}
}
_ = r.Close()
2024-10-31 15:27:22 +08:00
out.Result = strings.Join(results, "")
return out, nil
2024-09-07 23:13:36 +08:00
} else {
log.DefaultLogger.Error(err.Error())
2024-10-31 15:27:22 +08:00
return ai.ChatResult{}, err
2024-09-07 23:13:36 +08:00
}
} else {
2024-10-25 16:04:36 +08:00
t1 := time.Now().UnixMilli()
if r, err := c.CreateChatCompletion(context.Background(), opt); err == nil {
t2 := time.Now().UnixMilli() - t1
2024-09-07 23:13:36 +08:00
results := make([]string, 0)
if r.Choices != nil {
for _, ch := range r.Choices {
results = append(results, ch.Message.Content)
}
}
2024-10-31 15:27:22 +08:00
return ai.ChatResult{
Result: strings.Join(results, ""),
2024-09-07 23:13:36 +08:00
AskTokens: int64(r.Usage.PromptTokens),
AnswerTokens: int64(r.Usage.CompletionTokens),
TotalTokens: int64(r.Usage.TotalTokens),
2024-10-25 16:04:36 +08:00
UsedTime: t2,
2024-09-07 23:13:36 +08:00
}, nil
} else {
2024-10-31 15:27:22 +08:00
return ai.ChatResult{}, err
2024-10-25 16:04:36 +08:00
}
}
}
2024-10-31 15:27:22 +08:00
func Embedding(aiConf *ai.AIConfig, text string, embeddingConf ai.EmbeddingConfig) (ai.EmbeddingResult, error) {
c := getClient(aiConf)
2024-10-25 16:04:36 +08:00
req := openai.EmbeddingRequest{
Input: text,
2024-10-31 15:27:22 +08:00
Model: openai.EmbeddingModel(embeddingConf.Model),
2024-10-25 16:04:36 +08:00
User: "",
EncodingFormat: "",
Dimensions: 0,
}
t1 := time.Now().UnixMilli()
if r, err := c.CreateEmbeddings(context.Background(), req); err == nil {
t2 := time.Now().UnixMilli() - t1
buf := new(bytes.Buffer)
if r.Data != nil {
for _, ch := range r.Data {
for _, v := range ch.Embedding {
_ = binary.Write(buf, binary.LittleEndian, v)
}
}
2024-09-07 23:13:36 +08:00
}
2024-10-31 15:27:22 +08:00
return ai.EmbeddingResult{
Result: buf.Bytes(),
2024-10-25 16:04:36 +08:00
AskTokens: int64(r.Usage.PromptTokens),
AnswerTokens: int64(r.Usage.CompletionTokens),
TotalTokens: int64(r.Usage.TotalTokens),
UsedTime: t2,
}, nil
} else {
2024-10-31 15:27:22 +08:00
return ai.EmbeddingResult{}, err
2024-09-07 23:13:36 +08:00
}
}