231 lines
6.4 KiB
Go
231 lines
6.4 KiB
Go
|
package openai
|
||
|
|
||
|
import (
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"strings"
	"time"

	"apigo.cc/ai/llm/llm"
	"github.com/sashabaranov/go-openai"
	"github.com/ssgo/log"
	"github.com/ssgo/u"
)
|
||
|
|
||
|
func (lm *LLM) FastAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
|
||
|
return lm.Ask(messages, llm.ChatConfig{
|
||
|
Model: ModelGPT_4o_mini_2024_07_18,
|
||
|
}, callback)
|
||
|
}
|
||
|
|
||
|
func (lm *LLM) LongAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
|
||
|
return lm.Ask(messages, llm.ChatConfig{
|
||
|
Model: ModelGPT_4_32k_0613,
|
||
|
}, callback)
|
||
|
}
|
||
|
|
||
|
func (lm *LLM) BatterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
|
||
|
return lm.Ask(messages, llm.ChatConfig{
|
||
|
Model: ModelGPT_4_turbo,
|
||
|
}, callback)
|
||
|
}
|
||
|
|
||
|
func (lm *LLM) BestAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
|
||
|
return lm.Ask(messages, llm.ChatConfig{
|
||
|
Model: ModelGPT_4o_2024_08_06,
|
||
|
}, callback)
|
||
|
}
|
||
|
|
||
|
func (lm *LLM) MultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
|
||
|
return lm.Ask(messages, llm.ChatConfig{
|
||
|
Model: ModelGPT_4o_mini_2024_07_18,
|
||
|
}, callback)
|
||
|
}
|
||
|
|
||
|
func (lm *LLM) BestMultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
|
||
|
return lm.Ask(messages, llm.ChatConfig{
|
||
|
Model: ModelGPT_4o_2024_08_06,
|
||
|
}, callback)
|
||
|
}
|
||
|
|
||
|
func (lm *LLM) CodeInterpreterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
|
||
|
return lm.Ask(messages, llm.ChatConfig{
|
||
|
Model: ModelGPT_4o,
|
||
|
Tools: map[string]any{llm.ToolCodeInterpreter: nil},
|
||
|
}, callback)
|
||
|
}
|
||
|
|
||
|
func (lm *LLM) WebSearchAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
|
||
|
return lm.Ask(messages, llm.ChatConfig{
|
||
|
Model: ModelGPT_4o_mini_2024_07_18,
|
||
|
Tools: map[string]any{llm.ToolWebSearch: nil},
|
||
|
}, callback)
|
||
|
}
|
||
|
|
||
|
func (lm *LLM) Ask(messages []llm.ChatMessage, config llm.ChatConfig, callback func(answer string)) (string, llm.Usage, error) {
|
||
|
openaiConf := openai.DefaultConfig(lm.config.ApiKey)
|
||
|
if lm.config.Endpoint != "" {
|
||
|
openaiConf.BaseURL = lm.config.Endpoint
|
||
|
}
|
||
|
|
||
|
config.SetDefault(&lm.config.ChatConfig)
|
||
|
|
||
|
agentMessages := make([]openai.ChatCompletionMessage, len(messages))
|
||
|
for i, msg := range messages {
|
||
|
var contents []openai.ChatMessagePart
|
||
|
if msg.Contents != nil {
|
||
|
contents = make([]openai.ChatMessagePart, len(msg.Contents))
|
||
|
for j, inPart := range msg.Contents {
|
||
|
part := openai.ChatMessagePart{}
|
||
|
part.Type = TypeMap[inPart.Type]
|
||
|
switch inPart.Type {
|
||
|
case llm.TypeText:
|
||
|
part.Text = inPart.Content
|
||
|
case llm.TypeImage:
|
||
|
part.ImageURL = &openai.ChatMessageImageURL{
|
||
|
URL: inPart.Content,
|
||
|
Detail: openai.ImageURLDetailAuto,
|
||
|
}
|
||
|
}
|
||
|
contents[j] = part
|
||
|
}
|
||
|
}
|
||
|
if len(contents) == 1 && contents[0].Type == llm.TypeText {
|
||
|
agentMessages[i] = openai.ChatCompletionMessage{
|
||
|
Role: RoleMap[msg.Role],
|
||
|
Content: contents[0].Text,
|
||
|
}
|
||
|
} else {
|
||
|
agentMessages[i] = openai.ChatCompletionMessage{
|
||
|
Role: RoleMap[msg.Role],
|
||
|
MultiContent: contents,
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
opt := openai.ChatCompletionRequest{
|
||
|
Model: config.GetModel(),
|
||
|
Messages: agentMessages,
|
||
|
MaxTokens: config.GetMaxTokens(),
|
||
|
Temperature: float32(config.GetTemperature()),
|
||
|
TopP: float32(config.GetTopP()),
|
||
|
StreamOptions: &openai.StreamOptions{
|
||
|
IncludeUsage: true,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for name := range config.GetTools() {
|
||
|
switch name {
|
||
|
case llm.ToolCodeInterpreter:
|
||
|
opt.Tools = append(opt.Tools, openai.Tool{Type: "code_interpreter"})
|
||
|
case llm.ToolWebSearch:
|
||
|
}
|
||
|
}
|
||
|
|
||
|
c := openai.NewClientWithConfig(openaiConf)
|
||
|
if callback != nil {
|
||
|
opt.Stream = true
|
||
|
r, err := c.CreateChatCompletionStream(context.Background(), opt)
|
||
|
if err == nil {
|
||
|
results := make([]string, 0)
|
||
|
usage := llm.Usage{}
|
||
|
for {
|
||
|
if r2, err := r.Recv(); err == nil {
|
||
|
if r2.Choices != nil {
|
||
|
for _, ch := range r2.Choices {
|
||
|
text := ch.Delta.Content
|
||
|
callback(text)
|
||
|
results = append(results, text)
|
||
|
}
|
||
|
}
|
||
|
if r2.Usage != nil {
|
||
|
usage.AskTokens += int64(r2.Usage.PromptTokens)
|
||
|
usage.AnswerTokens += int64(r2.Usage.CompletionTokens)
|
||
|
usage.TotalTokens += int64(r2.Usage.TotalTokens)
|
||
|
}
|
||
|
} else {
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
_ = r.Close()
|
||
|
return strings.Join(results, ""), usage, nil
|
||
|
} else {
|
||
|
log.DefaultLogger.Error(err.Error())
|
||
|
return "", llm.Usage{}, err
|
||
|
}
|
||
|
} else {
|
||
|
t1 := time.Now().UnixMilli()
|
||
|
if r, err := c.CreateChatCompletion(context.Background(), opt); err == nil {
|
||
|
t2 := time.Now().UnixMilli() - t1
|
||
|
results := make([]string, 0)
|
||
|
if r.Choices != nil {
|
||
|
for _, ch := range r.Choices {
|
||
|
results = append(results, ch.Message.Content)
|
||
|
}
|
||
|
}
|
||
|
return strings.Join(results, ""), llm.Usage{
|
||
|
AskTokens: int64(r.Usage.PromptTokens),
|
||
|
AnswerTokens: int64(r.Usage.CompletionTokens),
|
||
|
TotalTokens: int64(r.Usage.TotalTokens),
|
||
|
UsedTime: t2,
|
||
|
}, nil
|
||
|
} else {
|
||
|
//fmt.Println(u.BMagenta(err.Error()), u.BMagenta(u.JsonP(r)))
|
||
|
return "", llm.Usage{}, err
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (lm *LLM) FastEmbedding(text string) ([]byte, llm.Usage, error) {
|
||
|
return lm.Embedding(text, string(openai.AdaEmbeddingV2))
|
||
|
}
|
||
|
|
||
|
func (lm *LLM) BestEmbedding(text string) ([]byte, llm.Usage, error) {
|
||
|
return lm.Embedding(text, string(openai.LargeEmbedding3))
|
||
|
}
|
||
|
|
||
|
func (lm *LLM) Embedding(text, model string) ([]byte, llm.Usage, error) {
|
||
|
fmt.Println(111, model, text)
|
||
|
openaiConf := openai.DefaultConfig(lm.config.ApiKey)
|
||
|
if lm.config.Endpoint != "" {
|
||
|
openaiConf.BaseURL = lm.config.Endpoint
|
||
|
}
|
||
|
|
||
|
c := openai.NewClientWithConfig(openaiConf)
|
||
|
req := openai.EmbeddingRequest{
|
||
|
Input: text,
|
||
|
Model: openai.EmbeddingModel(model),
|
||
|
User: "",
|
||
|
EncodingFormat: "",
|
||
|
Dimensions: 0,
|
||
|
}
|
||
|
|
||
|
if lm.config.Debug {
|
||
|
fmt.Println(u.JsonP(req))
|
||
|
}
|
||
|
|
||
|
t1 := time.Now().UnixMilli()
|
||
|
if r, err := c.CreateEmbeddings(context.Background(), req); err == nil {
|
||
|
t2 := time.Now().UnixMilli() - t1
|
||
|
buf := new(bytes.Buffer)
|
||
|
if r.Data != nil {
|
||
|
for _, ch := range r.Data {
|
||
|
for _, v := range ch.Embedding {
|
||
|
_ = binary.Write(buf, binary.LittleEndian, v)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
fmt.Println(len(buf.Bytes()))
|
||
|
return buf.Bytes(), llm.Usage{
|
||
|
AskTokens: int64(r.Usage.PromptTokens),
|
||
|
AnswerTokens: int64(r.Usage.CompletionTokens),
|
||
|
TotalTokens: int64(r.Usage.TotalTokens),
|
||
|
UsedTime: t2,
|
||
|
}, nil
|
||
|
} else {
|
||
|
fmt.Println(err.Error())
|
||
|
return nil, llm.Usage{}, err
|
||
|
}
|
||
|
}
|