168 lines
4.2 KiB
Go
168 lines
4.2 KiB
Go
package zhipu
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/binary"
|
|
"strings"
|
|
"time"
|
|
|
|
"apigo.cc/ai"
|
|
"github.com/ssgo/u"
|
|
"github.com/yankeguo/zhipu"
|
|
)
|
|
|
|
func getClient(aiConf *ai.AIConfig) (client *zhipu.Client, err error) {
|
|
opt := []zhipu.ClientOption{zhipu.WithAPIKey(aiConf.ApiKey)}
|
|
if aiConf.Endpoint != "" {
|
|
opt = append(opt, zhipu.WithBaseURL(aiConf.Endpoint))
|
|
}
|
|
return zhipu.NewClient(opt...)
|
|
}
|
|
|
|
func Chat(aiConf *ai.AIConfig, messages []ai.ChatMessage, callback func(string), conf ai.ChatConfig) (ai.ChatResult, error) {
|
|
c, err := getClient(aiConf)
|
|
if err != nil {
|
|
return ai.ChatResult{}, err
|
|
}
|
|
|
|
cc := c.ChatCompletion(conf.Model)
|
|
if conf.SystemPrompt != "" {
|
|
cc.AddMessage(zhipu.ChatCompletionMessage{
|
|
Role: zhipu.RoleSystem,
|
|
Content: conf.SystemPrompt,
|
|
})
|
|
}
|
|
for _, msg := range messages {
|
|
var contents []zhipu.ChatCompletionMultiContent
|
|
if msg.Contents != nil {
|
|
contents = make([]zhipu.ChatCompletionMultiContent, len(msg.Contents))
|
|
for j, inPart := range msg.Contents {
|
|
part := zhipu.ChatCompletionMultiContent{}
|
|
part.Type = inPart.Type
|
|
switch inPart.Type {
|
|
case ai.TypeText:
|
|
part.Text = inPart.Content
|
|
case ai.TypeImage:
|
|
part.ImageURL = &zhipu.URLItem{URL: inPart.Content}
|
|
//case ai.TypeVideo:
|
|
// part.VideoURL = &zhipu.URLItem{URL: inPart.Content}
|
|
}
|
|
contents[j] = part
|
|
}
|
|
}
|
|
if len(contents) == 1 && contents[0].Type == ai.TypeText {
|
|
cc.AddMessage(zhipu.ChatCompletionMessage{
|
|
Role: msg.Role,
|
|
Content: contents[0].Text,
|
|
})
|
|
} else {
|
|
cc.AddMessage(zhipu.ChatCompletionMultiMessage{
|
|
Role: msg.Role,
|
|
Content: contents,
|
|
})
|
|
}
|
|
}
|
|
|
|
for name, toolConf := range conf.Tools {
|
|
switch name {
|
|
case ai.ToolFunction:
|
|
conf := zhipu.ChatCompletionToolFunction{}
|
|
u.Convert(toolConf, &conf)
|
|
cc.AddTool(conf)
|
|
case ai.ToolCodeInterpreter:
|
|
conf := zhipu.ChatCompletionToolCodeInterpreter{}
|
|
u.Convert(toolConf, &conf)
|
|
cc.AddTool(conf)
|
|
case ai.ToolWebSearch:
|
|
conf := zhipu.ChatCompletionToolWebSearch{}
|
|
u.Convert(toolConf, &conf)
|
|
cc.AddTool(conf)
|
|
case ai.ToolWebBrowser:
|
|
conf := zhipu.ChatCompletionToolWebBrowser{}
|
|
u.Convert(toolConf, &conf)
|
|
cc.AddTool(conf)
|
|
case ai.ToolDrawingTool:
|
|
conf := zhipu.ChatCompletionToolDrawingTool{}
|
|
u.Convert(toolConf, &conf)
|
|
cc.AddTool(conf)
|
|
case ai.ToolRetrieval:
|
|
conf := zhipu.ChatCompletionToolRetrieval{}
|
|
u.Convert(toolConf, &conf)
|
|
cc.AddTool(conf)
|
|
}
|
|
}
|
|
|
|
if conf.MaxTokens != 0 {
|
|
cc.SetMaxTokens(conf.MaxTokens)
|
|
}
|
|
if conf.Temperature != 0 {
|
|
cc.SetTemperature(conf.Temperature)
|
|
}
|
|
if conf.TopP != 0 {
|
|
cc.SetTopP(conf.TopP)
|
|
}
|
|
if callback != nil {
|
|
cc.SetStreamHandler(func(r2 zhipu.ChatCompletionResponse) error {
|
|
if r2.Choices != nil {
|
|
for _, ch := range r2.Choices {
|
|
text := ch.Delta.Content
|
|
callback(text)
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
t1 := time.Now().UnixMilli()
|
|
if r, err := cc.Do(context.Background()); err == nil {
|
|
t2 := time.Now().UnixMilli() - t1
|
|
results := make([]string, 0)
|
|
if r.Choices != nil {
|
|
for _, ch := range r.Choices {
|
|
results = append(results, ch.Message.Content)
|
|
}
|
|
}
|
|
return ai.ChatResult{
|
|
Result: strings.Join(results, ""),
|
|
AskTokens: r.Usage.PromptTokens,
|
|
AnswerTokens: r.Usage.CompletionTokens,
|
|
TotalTokens: r.Usage.TotalTokens,
|
|
UsedTime: t2,
|
|
}, nil
|
|
} else {
|
|
return ai.ChatResult{}, err
|
|
}
|
|
}
|
|
|
|
func Embedding(aiConf *ai.AIConfig, text string, embeddingConf ai.EmbeddingConfig) (ai.EmbeddingResult, error) {
|
|
c, err := getClient(aiConf)
|
|
if err != nil {
|
|
return ai.EmbeddingResult{}, err
|
|
}
|
|
|
|
cc := c.Embedding(embeddingConf.Model)
|
|
cc.SetInput(text)
|
|
t1 := time.Now().UnixMilli()
|
|
if r, err := cc.Do(context.Background()); err == nil {
|
|
t2 := time.Now().UnixMilli() - t1
|
|
buf := new(bytes.Buffer)
|
|
if r.Data != nil {
|
|
for _, ch := range r.Data {
|
|
for _, v := range ch.Embedding {
|
|
_ = binary.Write(buf, binary.LittleEndian, float32(v))
|
|
}
|
|
}
|
|
}
|
|
return ai.EmbeddingResult{
|
|
Result: buf.Bytes(),
|
|
AskTokens: r.Usage.PromptTokens,
|
|
AnswerTokens: r.Usage.CompletionTokens,
|
|
TotalTokens: r.Usage.TotalTokens,
|
|
UsedTime: t2,
|
|
}, nil
|
|
} else {
|
|
return ai.EmbeddingResult{}, err
|
|
}
|
|
}
|