zhipu/chat.go
2024-09-07 23:14:12 +08:00

141 lines
4.1 KiB
Go

package zhipu
import (
"apigo.cc/ai/agent"
"apigo.cc/ai/zhipu/zhipu"
"context"
"strings"
)
// FastAsk forwards messages and callback to Ask with the model fixed to
// ModelGLM4Flash. The answer, token usage and error are returned exactly as
// Ask produced them.
func (ag *Agent) FastAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) {
	cfg := agent.ChatModelConfig{Model: ModelGLM4Flash}
	return ag.Ask(messages, &cfg, callback)
}
// LongAsk forwards messages and callback to Ask with the model fixed to
// ModelGLM4Long. The answer, token usage and error are returned exactly as
// Ask produced them.
func (ag *Agent) LongAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) {
	cfg := agent.ChatModelConfig{Model: ModelGLM4Long}
	return ag.Ask(messages, &cfg, callback)
}
// BatterAsk forwards messages and callback to Ask with the model fixed to
// ModelGLM4Plus. The answer, token usage and error are returned exactly as
// Ask produced them.
//
// NOTE(review): "Batter" looks like a misspelling of "Better"; the exported
// name is left as-is so existing callers keep compiling.
func (ag *Agent) BatterAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) {
	cfg := agent.ChatModelConfig{Model: ModelGLM4Plus}
	return ag.Ask(messages, &cfg, callback)
}
// BestAsk forwards messages and callback to Ask with the model fixed to
// ModelGLM40520. The answer, token usage and error are returned exactly as
// Ask produced them.
func (ag *Agent) BestAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) {
	cfg := agent.ChatModelConfig{Model: ModelGLM40520}
	return ag.Ask(messages, &cfg, callback)
}
// MultiAsk forwards messages and callback to Ask with the model fixed to
// ModelGLM4VPlus. The answer, token usage and error are returned exactly as
// Ask produced them.
func (ag *Agent) MultiAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) {
	cfg := agent.ChatModelConfig{Model: ModelGLM4VPlus}
	return ag.Ask(messages, &cfg, callback)
}
// BestMultiAsk forwards messages and callback to Ask with the model fixed to
// ModelGLM4V. The answer, token usage and error are returned exactly as Ask
// produced them.
func (ag *Agent) BestMultiAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) {
	cfg := agent.ChatModelConfig{Model: ModelGLM4V}
	return ag.Ask(messages, &cfg, callback)
}
// CodeInterpreterAsk forwards messages and callback to Ask with the model
// fixed to ModelGLM4AllTools and a tool set containing only
// agent.ToolCodeInterpreter (mapped to a nil value). The answer, token usage
// and error are returned exactly as Ask produced them.
func (ag *Agent) CodeInterpreterAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) {
	cfg := agent.ChatModelConfig{
		Model: ModelGLM4AllTools,
		Tools: map[string]any{agent.ToolCodeInterpreter: nil},
	}
	return ag.Ask(messages, &cfg, callback)
}
// WebSearchAsk forwards messages and callback to Ask with the model fixed to
// ModelGLM4AllTools and a tool set containing only agent.ToolWebSearch
// (mapped to a nil value). The answer, token usage and error are returned
// exactly as Ask produced them.
func (ag *Agent) WebSearchAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) {
	cfg := agent.ChatModelConfig{
		Model: ModelGLM4AllTools,
		Tools: map[string]any{agent.ToolWebSearch: nil},
	}
	return ag.Ask(messages, &cfg, callback)
}
// Ask runs a chat completion against the Zhipu API and returns the content of
// every returned choice concatenated into one string, along with the token
// usage reported by the response.
//
// A nil config is replaced by an empty one, after which config.SetDefault is
// called with the agent's DefaultChatModelConfig (presumably filling unset
// fields from the agent defaults; confirm against agent.ChatModelConfig).
// Max tokens, temperature and top-p are forwarded only when non-zero. Of the
// configured tools, only agent.ToolCodeInterpreter and agent.ToolWebSearch
// are recognised; other names are ignored.
//
// When callback is non-nil a stream handler is installed that invokes
// callback with the delta content of each choice in each streamed response.
//
// If the client cannot be created or the request fails, the answer is empty,
// the usage is the zero value, and the underlying error is returned as-is.
func (ag *Agent) Ask(messages []agent.ChatMessage, config *agent.ChatModelConfig, callback func(answer string)) (string, agent.TokenUsage, error) {
	if config == nil {
		config = &agent.ChatModelConfig{}
	}
	config.SetDefault(&ag.config.DefaultChatModelConfig)

	client, err := zhipu.NewClient(
		zhipu.WithAPIKey(ag.config.ApiKey),
		zhipu.WithBaseURL(ag.config.Endpoint),
	)
	if err != nil {
		return "", agent.TokenUsage{}, err
	}
	req := client.ChatCompletion(config.GetModel())

	// Translate every agent message into the client's multi-content form.
	for _, m := range messages {
		// parts stays nil when the message has no Contents at all, and
		// becomes a non-nil (possibly empty) slice otherwise.
		var parts []zhipu.ChatCompletionMultiContent
		if m.Contents != nil {
			parts = make([]zhipu.ChatCompletionMultiContent, 0, len(m.Contents))
			for _, in := range m.Contents {
				out := zhipu.ChatCompletionMultiContent{Type: NameMap[in.Type]}
				switch in.Type {
				case agent.TypeText:
					out.Text = in.Content
				case agent.TypeImage:
					out.ImageURL = &zhipu.URLItem{URL: in.Content}
				case agent.TypeVideo:
					out.VideoURL = &zhipu.URLItem{URL: in.Content}
				}
				parts = append(parts, out)
			}
		}
		req.AddMessage(zhipu.ChatCompletionMultiMessage{
			Role:    NameMap[m.Role],
			Content: parts,
		})
	}

	// Only the tool name matters here; the map value is not consulted.
	for toolName := range config.GetTools() {
		switch toolName {
		case agent.ToolCodeInterpreter:
			req.AddTool(zhipu.ChatCompletionToolCodeInterpreter{})
		case agent.ToolWebSearch:
			req.AddTool(zhipu.ChatCompletionToolWebBrowser{})
		}
	}

	// A zero value means "not configured", so the setter is skipped.
	if maxTokens := config.GetMaxTokens(); maxTokens != 0 {
		req.SetMaxTokens(maxTokens)
	}
	if temperature := config.GetTemperature(); temperature != 0 {
		req.SetTemperature(temperature)
	}
	if topP := config.GetTopP(); topP != 0 {
		req.SetTopP(topP)
	}

	if callback != nil {
		req.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
			for _, choice := range chunk.Choices {
				callback(choice.Delta.Content)
			}
			return nil
		})
	}

	resp, err := req.Do(context.Background())
	if err != nil {
		return "", agent.TokenUsage{}, err
	}

	var answer strings.Builder
	for _, choice := range resp.Choices {
		answer.WriteString(choice.Message.Content)
	}
	usage := agent.TokenUsage{
		AskTokens:    resp.Usage.PromptTokens,
		AnswerTokens: resp.Usage.CompletionTokens,
		TotalTokens:  resp.Usage.TotalTokens,
	}
	return answer.String(), usage, nil
}