zhipu/chat.go
2024-10-31 15:25:04 +08:00

168 lines
4.2 KiB
Go

package zhipu
import (
"bytes"
"context"
"encoding/binary"
"strings"
"time"
"apigo.cc/ai"
"github.com/ssgo/u"
"github.com/yankeguo/zhipu"
)
// getClient constructs a Zhipu API client authenticated with aiConf.ApiKey.
// When aiConf.Endpoint is non-empty it is installed as the client's base URL;
// otherwise the library default is used. Any construction error from
// zhipu.NewClient is returned unchanged.
func getClient(aiConf *ai.AIConfig) (*zhipu.Client, error) {
	options := make([]zhipu.ClientOption, 0, 2)
	options = append(options, zhipu.WithAPIKey(aiConf.ApiKey))
	if endpoint := aiConf.Endpoint; endpoint != "" {
		options = append(options, zhipu.WithBaseURL(endpoint))
	}
	return zhipu.NewClient(options...)
}
// Chat sends messages to the Zhipu chat-completion API using conf.Model and
// returns the concatenated content of all returned choices. Token usage and
// elapsed time in milliseconds are copied from the response.
//
// Request construction:
//   - A non-empty conf.SystemPrompt is sent first as a system message.
//   - Text parts are mapped to Text and image parts to ImageURL. Any other part
//     type is passed through with only its Type set.
//   - A message whose Contents holds exactly one text part is sent as a plain
//     text message. Every other shape is sent as a multi-part message.
//   - Each conf.Tools entry whose key names a supported tool is converted with
//     u.Convert into the matching zhipu tool type. Unknown keys are ignored.
//   - MaxTokens, Temperature and TopP are only applied when non-zero.
//
// When callback is non-nil the request is streamed, and callback receives the
// delta content of every choice in every streamed chunk, including empty
// deltas. Any client-construction or request error is returned unchanged with
// a zero ChatResult.
func Chat(aiConf *ai.AIConfig, messages []ai.ChatMessage, callback func(string), conf ai.ChatConfig) (ai.ChatResult, error) {
	c, err := getClient(aiConf)
	if err != nil {
		return ai.ChatResult{}, err
	}
	cc := c.ChatCompletion(conf.Model)

	if conf.SystemPrompt != "" {
		cc.AddMessage(zhipu.ChatCompletionMessage{
			Role:    zhipu.RoleSystem,
			Content: conf.SystemPrompt,
		})
	}

	for _, msg := range messages {
		var contents []zhipu.ChatCompletionMultiContent
		if msg.Contents != nil {
			contents = make([]zhipu.ChatCompletionMultiContent, len(msg.Contents))
			for j, inPart := range msg.Contents {
				part := zhipu.ChatCompletionMultiContent{Type: inPart.Type}
				switch inPart.Type {
				case ai.TypeText:
					part.Text = inPart.Content
				case ai.TypeImage:
					part.ImageURL = &zhipu.URLItem{URL: inPart.Content}
					//case ai.TypeVideo:
					//	part.VideoURL = &zhipu.URLItem{URL: inPart.Content}
				}
				contents[j] = part
			}
		}
		if len(contents) == 1 && contents[0].Type == ai.TypeText {
			cc.AddMessage(zhipu.ChatCompletionMessage{
				Role:    msg.Role,
				Content: contents[0].Text,
			})
			continue
		}
		// NOTE(review): a message with nil/empty Contents also lands here and is
		// sent as a multi-part message with no content — confirm the API accepts it.
		cc.AddMessage(zhipu.ChatCompletionMultiMessage{
			Role:    msg.Role,
			Content: contents,
		})
	}

	// The per-tool local is named `tool` rather than `conf` so it does not
	// shadow the ChatConfig parameter.
	for name, toolConf := range conf.Tools {
		switch name {
		case ai.ToolFunction:
			tool := zhipu.ChatCompletionToolFunction{}
			u.Convert(toolConf, &tool)
			cc.AddTool(tool)
		case ai.ToolCodeInterpreter:
			tool := zhipu.ChatCompletionToolCodeInterpreter{}
			u.Convert(toolConf, &tool)
			cc.AddTool(tool)
		case ai.ToolWebSearch:
			tool := zhipu.ChatCompletionToolWebSearch{}
			u.Convert(toolConf, &tool)
			cc.AddTool(tool)
		case ai.ToolWebBrowser:
			tool := zhipu.ChatCompletionToolWebBrowser{}
			u.Convert(toolConf, &tool)
			cc.AddTool(tool)
		case ai.ToolDrawingTool:
			tool := zhipu.ChatCompletionToolDrawingTool{}
			u.Convert(toolConf, &tool)
			cc.AddTool(tool)
		case ai.ToolRetrieval:
			tool := zhipu.ChatCompletionToolRetrieval{}
			u.Convert(toolConf, &tool)
			cc.AddTool(tool)
		}
	}

	// Zero means "not set": leave the server-side default in place.
	if conf.MaxTokens != 0 {
		cc.SetMaxTokens(conf.MaxTokens)
	}
	if conf.Temperature != 0 {
		cc.SetTemperature(conf.Temperature)
	}
	if conf.TopP != 0 {
		cc.SetTopP(conf.TopP)
	}

	if callback != nil {
		cc.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
			for _, ch := range chunk.Choices {
				callback(ch.Delta.Content)
			}
			return nil
		})
	}

	// time.Since uses the monotonic clock, so the duration is immune to
	// wall-clock adjustments during the request.
	start := time.Now()
	r, err := cc.Do(context.Background())
	if err != nil {
		return ai.ChatResult{}, err
	}
	usedTime := time.Since(start).Milliseconds()

	var answer strings.Builder
	for _, ch := range r.Choices {
		answer.WriteString(ch.Message.Content)
	}
	return ai.ChatResult{
		Result:       answer.String(),
		AskTokens:    r.Usage.PromptTokens,
		AnswerTokens: r.Usage.CompletionTokens,
		TotalTokens:  r.Usage.TotalTokens,
		UsedTime:     usedTime,
	}, nil
}
// Embedding requests an embedding of text from the Zhipu embedding API using
// embeddingConf.Model.
//
// The vectors of all returned data items are concatenated in response order.
// They are encoded into Result as consecutive little-endian float32 values,
// and Result is nil when the response carries no values. Token usage and
// elapsed request time in milliseconds are copied from the response. Any
// client-construction, request or encoding error is returned with a zero
// EmbeddingResult.
func Embedding(aiConf *ai.AIConfig, text string, embeddingConf ai.EmbeddingConfig) (ai.EmbeddingResult, error) {
	c, err := getClient(aiConf)
	if err != nil {
		return ai.EmbeddingResult{}, err
	}
	cc := c.Embedding(embeddingConf.Model)
	cc.SetInput(text)

	// time.Since uses the monotonic clock, so the duration is immune to
	// wall-clock adjustments during the request.
	start := time.Now()
	r, err := cc.Do(context.Background())
	if err != nil {
		return ai.EmbeddingResult{}, err
	}
	usedTime := time.Since(start).Milliseconds()

	// Flatten into one []float32 and encode it with a single binary.Write.
	// The slice fast path avoids boxing each element into an interface and
	// allocating a scratch buffer per value, as the per-element calls did.
	total := 0
	for _, d := range r.Data {
		total += len(d.Embedding)
	}
	vals := make([]float32, 0, total)
	for _, d := range r.Data {
		for _, v := range d.Embedding {
			vals = append(vals, float32(v))
		}
	}

	var buf bytes.Buffer
	if len(vals) > 0 {
		// Skipped entirely when empty so buf.Bytes() stays nil, as before.
		buf.Grow(4 * len(vals))
		if err := binary.Write(&buf, binary.LittleEndian, vals); err != nil {
			return ai.EmbeddingResult{}, err
		}
	}

	return ai.EmbeddingResult{
		Result:       buf.Bytes(),
		AskTokens:    r.Usage.PromptTokens,
		AnswerTokens: r.Usage.CompletionTokens,
		TotalTokens:  r.Usage.TotalTokens,
		UsedTime:     usedTime,
	}, nil
}