179 lines
4.8 KiB
Go
179 lines
4.8 KiB
Go
package openai
|
|
|
|
import (
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"io"
	"strings"
	"time"

	"apigo.cc/ai"
	"github.com/sashabaranov/go-openai"
	"github.com/ssgo/log"
	"github.com/ssgo/u"
)
|
|
|
|
func getClient(aiConf *ai.AIConfig) *openai.Client {
|
|
openaiConf := openai.DefaultConfig(aiConf.ApiKey)
|
|
if aiConf.Endpoint != "" {
|
|
openaiConf.BaseURL = aiConf.Endpoint
|
|
}
|
|
return openai.NewClientWithConfig(openaiConf)
|
|
}
|
|
|
|
// func (lm *LLM) Ask(messages []ai.ChatMessage, config ai.ChatConfig, callback func(answer string)) (string, ai.Usage, error) {
|
|
func Chat(aiConf *ai.AIConfig, messages []ai.ChatMessage, callback func(string), conf ai.ChatConfig) (ai.ChatResult, error) {
|
|
chatMessages := make([]openai.ChatCompletionMessage, len(messages))
|
|
for i, msg := range messages {
|
|
var contents []openai.ChatMessagePart
|
|
if msg.Contents != nil {
|
|
contents = make([]openai.ChatMessagePart, len(msg.Contents))
|
|
for j, inPart := range msg.Contents {
|
|
part := openai.ChatMessagePart{}
|
|
switch inPart.Type {
|
|
case ai.TypeText:
|
|
part.Type = openai.ChatMessagePartTypeText
|
|
part.Text = inPart.Content
|
|
case ai.TypeImage:
|
|
part.Type = openai.ChatMessagePartTypeImageURL
|
|
part.ImageURL = &openai.ChatMessageImageURL{
|
|
URL: inPart.Content,
|
|
Detail: openai.ImageURLDetailAuto,
|
|
}
|
|
default:
|
|
part.Type = openai.ChatMessagePartType(inPart.Type)
|
|
part.Text = inPart.Content
|
|
}
|
|
contents[j] = part
|
|
}
|
|
}
|
|
if len(contents) == 1 && contents[0].Type == ai.TypeText {
|
|
chatMessages[i] = openai.ChatCompletionMessage{
|
|
Role: msg.Role,
|
|
Content: contents[0].Text,
|
|
}
|
|
} else {
|
|
chatMessages[i] = openai.ChatCompletionMessage{
|
|
Role: msg.Role,
|
|
MultiContent: contents,
|
|
}
|
|
}
|
|
}
|
|
|
|
if conf.SystemPrompt != "" {
|
|
chatMessages = append([]openai.ChatCompletionMessage{{
|
|
Role: openai.ChatMessageRoleSystem,
|
|
Content: conf.SystemPrompt,
|
|
}}, chatMessages...)
|
|
}
|
|
|
|
opt := openai.ChatCompletionRequest{
|
|
Model: conf.Model,
|
|
Messages: chatMessages,
|
|
MaxTokens: conf.MaxTokens,
|
|
Temperature: float32(conf.Temperature),
|
|
TopP: float32(conf.TopP),
|
|
StreamOptions: &openai.StreamOptions{
|
|
IncludeUsage: true,
|
|
},
|
|
}
|
|
|
|
for name, toolConf := range conf.Tools {
|
|
switch name {
|
|
case ai.ToolCodeInterpreter:
|
|
opt.Tools = append(opt.Tools, openai.Tool{Type: "code_interpreter"})
|
|
case ai.ToolFunction:
|
|
conf := openai.FunctionDefinition{}
|
|
u.Convert(toolConf, &conf)
|
|
opt.Tools = append(opt.Tools, openai.Tool{Type: openai.ToolTypeFunction, Function: &conf})
|
|
}
|
|
}
|
|
|
|
c := getClient(aiConf)
|
|
if callback != nil {
|
|
opt.Stream = true
|
|
r, err := c.CreateChatCompletionStream(context.Background(), opt)
|
|
if err == nil {
|
|
results := make([]string, 0)
|
|
out := ai.ChatResult{}
|
|
for {
|
|
if r2, err := r.Recv(); err == nil {
|
|
if r2.Choices != nil {
|
|
for _, ch := range r2.Choices {
|
|
text := ch.Delta.Content
|
|
callback(text)
|
|
results = append(results, text)
|
|
}
|
|
}
|
|
if r2.Usage != nil {
|
|
out.AskTokens += int64(r2.Usage.PromptTokens)
|
|
out.AnswerTokens += int64(r2.Usage.CompletionTokens)
|
|
out.TotalTokens += int64(r2.Usage.TotalTokens)
|
|
}
|
|
} else {
|
|
break
|
|
}
|
|
}
|
|
_ = r.Close()
|
|
out.Result = strings.Join(results, "")
|
|
return out, nil
|
|
} else {
|
|
log.DefaultLogger.Error(err.Error())
|
|
return ai.ChatResult{}, err
|
|
}
|
|
} else {
|
|
t1 := time.Now().UnixMilli()
|
|
if r, err := c.CreateChatCompletion(context.Background(), opt); err == nil {
|
|
t2 := time.Now().UnixMilli() - t1
|
|
results := make([]string, 0)
|
|
if r.Choices != nil {
|
|
for _, ch := range r.Choices {
|
|
results = append(results, ch.Message.Content)
|
|
}
|
|
}
|
|
return ai.ChatResult{
|
|
Result: strings.Join(results, ""),
|
|
AskTokens: int64(r.Usage.PromptTokens),
|
|
AnswerTokens: int64(r.Usage.CompletionTokens),
|
|
TotalTokens: int64(r.Usage.TotalTokens),
|
|
UsedTime: t2,
|
|
}, nil
|
|
} else {
|
|
return ai.ChatResult{}, err
|
|
}
|
|
}
|
|
}
|
|
|
|
func Embedding(aiConf *ai.AIConfig, text string, embeddingConf ai.EmbeddingConfig) (ai.EmbeddingResult, error) {
|
|
c := getClient(aiConf)
|
|
req := openai.EmbeddingRequest{
|
|
Input: text,
|
|
Model: openai.EmbeddingModel(embeddingConf.Model),
|
|
User: "",
|
|
EncodingFormat: "",
|
|
Dimensions: 0,
|
|
}
|
|
|
|
t1 := time.Now().UnixMilli()
|
|
if r, err := c.CreateEmbeddings(context.Background(), req); err == nil {
|
|
t2 := time.Now().UnixMilli() - t1
|
|
buf := new(bytes.Buffer)
|
|
if r.Data != nil {
|
|
for _, ch := range r.Data {
|
|
for _, v := range ch.Embedding {
|
|
_ = binary.Write(buf, binary.LittleEndian, v)
|
|
}
|
|
}
|
|
}
|
|
return ai.EmbeddingResult{
|
|
Result: buf.Bytes(),
|
|
AskTokens: int64(r.Usage.PromptTokens),
|
|
AnswerTokens: int64(r.Usage.CompletionTokens),
|
|
TotalTokens: int64(r.Usage.TotalTokens),
|
|
UsedTime: t2,
|
|
}, nil
|
|
} else {
|
|
return ai.EmbeddingResult{}, err
|
|
}
|
|
}
|