ai_old/llm/openai/chat.go

173 lines
5.0 KiB
Go
Raw Normal View History

2024-09-17 18:44:21 +08:00
package openai
import (
	"context"
	"errors"
	"io"
	"strings"

	"apigo.cc/ai/ai/llm"
	"github.com/sashabaranov/go-openai"
	"github.com/ssgo/log"
)
// FastAsk sends messages through Ask using the ModelGPT_4o_mini_2024_07_18
// model; all other settings are left for Ask to fill in from the LLM's
// configured defaults. A non-nil callback puts Ask into streaming mode.
func (lm *LLM) FastAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
	var conf llm.ChatConfig
	conf.Model = ModelGPT_4o_mini_2024_07_18
	return lm.Ask(messages, conf, callback)
}
// LongAsk sends messages through Ask using the ModelGPT_4_32k_0613 model;
// all other settings are left for Ask to fill in from the LLM's configured
// defaults. A non-nil callback puts Ask into streaming mode.
func (lm *LLM) LongAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
	conf := llm.ChatConfig{Model: ModelGPT_4_32k_0613}
	return lm.Ask(messages, conf, callback)
}
// BatterAsk sends messages through Ask using the ModelGPT_4_turbo model;
// all other settings are left for Ask to fill in from the LLM's configured
// defaults. A non-nil callback puts Ask into streaming mode.
//
// NOTE(review): the name presumably means "Better"; it is kept unchanged
// because it is part of the exported API.
func (lm *LLM) BatterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
	var conf llm.ChatConfig
	conf.Model = ModelGPT_4_turbo
	return lm.Ask(messages, conf, callback)
}
// BestAsk sends messages through Ask using the ModelGPT_4o_2024_08_06 model;
// all other settings are left for Ask to fill in from the LLM's configured
// defaults. A non-nil callback puts Ask into streaming mode.
func (lm *LLM) BestAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
	conf := llm.ChatConfig{Model: ModelGPT_4o_2024_08_06}
	return lm.Ask(messages, conf, callback)
}
// MultiAsk sends messages through Ask using the ModelGPT_4o_mini_2024_07_18
// model; all other settings are left for Ask to fill in from the LLM's
// configured defaults. A non-nil callback puts Ask into streaming mode.
func (lm *LLM) MultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
	var conf llm.ChatConfig
	conf.Model = ModelGPT_4o_mini_2024_07_18
	return lm.Ask(messages, conf, callback)
}
// BestMultiAsk sends messages through Ask using the ModelGPT_4o_2024_08_06
// model; all other settings are left for Ask to fill in from the LLM's
// configured defaults. A non-nil callback puts Ask into streaming mode.
func (lm *LLM) BestMultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
	conf := llm.ChatConfig{Model: ModelGPT_4o_2024_08_06}
	return lm.Ask(messages, conf, callback)
}
// CodeInterpreterAsk sends messages through Ask using the ModelGPT_4o model
// with the llm.ToolCodeInterpreter tool enabled. The tool carries no options
// (nil value). A non-nil callback puts Ask into streaming mode.
func (lm *LLM) CodeInterpreterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
	tools := map[string]any{llm.ToolCodeInterpreter: nil}
	conf := llm.ChatConfig{Model: ModelGPT_4o, Tools: tools}
	return lm.Ask(messages, conf, callback)
}
// WebSearchAsk sends messages through Ask using the
// ModelGPT_4o_mini_2024_07_18 model with the llm.ToolWebSearch tool listed.
// The tool carries no options (nil value). A non-nil callback puts Ask into
// streaming mode.
func (lm *LLM) WebSearchAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
	tools := map[string]any{llm.ToolWebSearch: nil}
	conf := llm.ChatConfig{Model: ModelGPT_4o_mini_2024_07_18, Tools: tools}
	return lm.Ask(messages, conf, callback)
}
func (lm *LLM) Ask(messages []llm.ChatMessage, config llm.ChatConfig, callback func(answer string)) (string, llm.TokenUsage, error) {
openaiConf := openai.DefaultConfig(lm.config.ApiKey)
if lm.config.Endpoint != "" {
openaiConf.BaseURL = lm.config.Endpoint
}
config.SetDefault(&lm.config.ChatConfig)
agentMessages := make([]openai.ChatCompletionMessage, len(messages))
for i, msg := range messages {
var contents []openai.ChatMessagePart
if msg.Contents != nil {
contents = make([]openai.ChatMessagePart, len(msg.Contents))
for j, inPart := range msg.Contents {
part := openai.ChatMessagePart{}
part.Type = TypeMap[inPart.Type]
switch inPart.Type {
case llm.TypeText:
part.Text = inPart.Content
case llm.TypeImage:
part.ImageURL = &openai.ChatMessageImageURL{
URL: inPart.Content,
Detail: openai.ImageURLDetailAuto,
}
}
contents[j] = part
}
}
2024-09-23 18:15:02 +08:00
if len(contents) == 1 && contents[0].Type == llm.TypeText {
agentMessages[i] = openai.ChatCompletionMessage{
Role: RoleMap[msg.Role],
Content: contents[0].Text,
}
} else {
agentMessages[i] = openai.ChatCompletionMessage{
Role: RoleMap[msg.Role],
MultiContent: contents,
}
2024-09-17 18:44:21 +08:00
}
}
opt := openai.ChatCompletionRequest{
Model: config.GetModel(),
Messages: agentMessages,
MaxTokens: config.GetMaxTokens(),
Temperature: float32(config.GetTemperature()),
TopP: float32(config.GetTopP()),
StreamOptions: &openai.StreamOptions{
IncludeUsage: true,
},
}
for name := range config.GetTools() {
switch name {
case llm.ToolCodeInterpreter:
opt.Tools = append(opt.Tools, openai.Tool{Type: "code_interpreter"})
case llm.ToolWebSearch:
}
}
c := openai.NewClientWithConfig(openaiConf)
if callback != nil {
opt.Stream = true
r, err := c.CreateChatCompletionStream(context.Background(), opt)
if err == nil {
results := make([]string, 0)
usage := llm.TokenUsage{}
for {
if r2, err := r.Recv(); err == nil {
if r2.Choices != nil {
for _, ch := range r2.Choices {
text := ch.Delta.Content
callback(text)
results = append(results, text)
}
}
if r2.Usage != nil {
usage.AskTokens += int64(r2.Usage.PromptTokens)
usage.AnswerTokens += int64(r2.Usage.CompletionTokens)
usage.TotalTokens += int64(r2.Usage.TotalTokens)
}
} else {
break
}
}
_ = r.Close()
return strings.Join(results, ""), usage, nil
} else {
log.DefaultLogger.Error(err.Error())
return "", llm.TokenUsage{}, err
}
} else {
r, err := c.CreateChatCompletion(context.Background(), opt)
if err == nil {
results := make([]string, 0)
if r.Choices != nil {
for _, ch := range r.Choices {
results = append(results, ch.Message.Content)
}
}
return strings.Join(results, ""), llm.TokenUsage{
AskTokens: int64(r.Usage.PromptTokens),
AnswerTokens: int64(r.Usage.CompletionTokens),
TotalTokens: int64(r.Usage.TotalTokens),
}, nil
} else {
//fmt.Println(u.BMagenta(err.Error()), u.BMagenta(u.JsonP(r)))
return "", llm.TokenUsage{}, err
}
}
}