138 lines
3.9 KiB
Go
138 lines
3.9 KiB
Go
package zhipu
|
|
|
|
import (
|
|
"apigo.cc/ai/ai/llm"
|
|
"apigo.cc/ai/ai/llm/zhipu/zhipu"
|
|
"context"
|
|
"strings"
|
|
)
|
|
|
|
func (lm *LLM) FastAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
|
|
return lm.Ask(messages, llm.ChatConfig{
|
|
Model: ModelGLM4Flash,
|
|
}, callback)
|
|
}
|
|
|
|
func (lm *LLM) LongAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
|
|
return lm.Ask(messages, llm.ChatConfig{
|
|
Model: ModelGLM4Long,
|
|
}, callback)
|
|
}
|
|
|
|
func (lm *LLM) BatterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
|
|
return lm.Ask(messages, llm.ChatConfig{
|
|
Model: ModelGLM4Plus,
|
|
}, callback)
|
|
}
|
|
|
|
func (lm *LLM) BestAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
|
|
return lm.Ask(messages, llm.ChatConfig{
|
|
Model: ModelGLM40520,
|
|
}, callback)
|
|
}
|
|
|
|
func (lm *LLM) MultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
|
|
return lm.Ask(messages, llm.ChatConfig{
|
|
Model: ModelGLM4VPlus,
|
|
}, callback)
|
|
}
|
|
|
|
func (lm *LLM) BestMultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
|
|
return lm.Ask(messages, llm.ChatConfig{
|
|
Model: ModelGLM4V,
|
|
}, callback)
|
|
}
|
|
|
|
func (lm *LLM) CodeInterpreterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
|
|
return lm.Ask(messages, llm.ChatConfig{
|
|
Model: ModelGLM4AllTools,
|
|
Tools: map[string]any{llm.ToolCodeInterpreter: nil},
|
|
}, callback)
|
|
}
|
|
|
|
func (lm *LLM) WebSearchAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
|
|
return lm.Ask(messages, llm.ChatConfig{
|
|
Model: ModelGLM4AllTools,
|
|
Tools: map[string]any{llm.ToolWebSearch: nil},
|
|
}, callback)
|
|
}
|
|
|
|
func (lm *LLM) Ask(messages []llm.ChatMessage, config llm.ChatConfig, callback func(answer string)) (string, llm.TokenUsage, error) {
|
|
config.SetDefault(&lm.config.ChatConfig)
|
|
c, err := zhipu.NewClient(zhipu.WithAPIKey(lm.config.ApiKey), zhipu.WithBaseURL(lm.config.Endpoint))
|
|
if err != nil {
|
|
return "", llm.TokenUsage{}, err
|
|
}
|
|
|
|
cc := c.ChatCompletion(config.GetModel())
|
|
for _, msg := range messages {
|
|
var contents []zhipu.ChatCompletionMultiContent
|
|
if msg.Contents != nil {
|
|
contents = make([]zhipu.ChatCompletionMultiContent, len(msg.Contents))
|
|
for j, inPart := range msg.Contents {
|
|
part := zhipu.ChatCompletionMultiContent{}
|
|
part.Type = NameMap[inPart.Type]
|
|
switch inPart.Type {
|
|
case llm.TypeText:
|
|
part.Text = inPart.Content
|
|
case llm.TypeImage:
|
|
part.ImageURL = &zhipu.URLItem{URL: inPart.Content}
|
|
case llm.TypeVideo:
|
|
part.VideoURL = &zhipu.URLItem{URL: inPart.Content}
|
|
}
|
|
contents[j] = part
|
|
}
|
|
}
|
|
cc.AddMessage(zhipu.ChatCompletionMultiMessage{
|
|
Role: NameMap[msg.Role],
|
|
Content: contents,
|
|
})
|
|
}
|
|
|
|
for name := range config.GetTools() {
|
|
switch name {
|
|
case llm.ToolCodeInterpreter:
|
|
cc.AddTool(zhipu.ChatCompletionToolCodeInterpreter{})
|
|
case llm.ToolWebSearch:
|
|
cc.AddTool(zhipu.ChatCompletionToolWebBrowser{})
|
|
}
|
|
}
|
|
|
|
if config.GetMaxTokens() != 0 {
|
|
cc.SetMaxTokens(config.GetMaxTokens())
|
|
}
|
|
if config.GetTemperature() != 0 {
|
|
cc.SetTemperature(config.GetTemperature())
|
|
}
|
|
if config.GetTopP() != 0 {
|
|
cc.SetTopP(config.GetTopP())
|
|
}
|
|
if callback != nil {
|
|
cc.SetStreamHandler(func(r2 zhipu.ChatCompletionResponse) error {
|
|
if r2.Choices != nil {
|
|
for _, ch := range r2.Choices {
|
|
text := ch.Delta.Content
|
|
callback(text)
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
if r, err := cc.Do(context.Background()); err == nil {
|
|
results := make([]string, 0)
|
|
if r.Choices != nil {
|
|
for _, ch := range r.Choices {
|
|
results = append(results, ch.Message.Content)
|
|
}
|
|
}
|
|
return strings.Join(results, ""), llm.TokenUsage{
|
|
AskTokens: r.Usage.PromptTokens,
|
|
AnswerTokens: r.Usage.CompletionTokens,
|
|
TotalTokens: r.Usage.TotalTokens,
|
|
}, nil
|
|
} else {
|
|
return "", llm.TokenUsage{}, err
|
|
}
|
|
}
|