use new gojs
This commit is contained in:
parent
29b0faf61b
commit
3e64ec275a
6
.gitignore
vendored
Normal file
6
.gitignore
vendored
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
.*
|
||||||
|
!.gitignore
|
||||||
|
go.sum
|
||||||
|
env.yml
|
||||||
|
node_modules
|
||||||
|
package.json
|
83
aigc.go
83
aigc.go
@ -1,83 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
import (
|
|
||||||
"apigo.cc/ai/zhipu/zhipu"
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (ag *Agent) FastMakeImage(prompt, size, refImage string) ([]string, error) {
|
|
||||||
return ag.MakeImage(ModelCogView3Plus, prompt, size, refImage)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ag *Agent) BestMakeImage(prompt, size, refImage string) ([]string, error) {
|
|
||||||
return ag.MakeImage(ModelCogView3, prompt, size, refImage)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ag *Agent) MakeImage(model, prompt, size, refImage string) ([]string, error) {
|
|
||||||
c, err := zhipu.NewClient(zhipu.WithAPIKey(ag.config.ApiKey), zhipu.WithBaseURL(ag.config.Endpoint))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
cc := c.ImageGeneration(model).SetPrompt(prompt)
|
|
||||||
if size != "" {
|
|
||||||
cc.SetSize(size)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r, err := cc.Do(context.Background()); err == nil {
|
|
||||||
results := make([]string, 0)
|
|
||||||
for _, item := range r.Data {
|
|
||||||
results = append(results, item.URL)
|
|
||||||
}
|
|
||||||
return results, nil
|
|
||||||
} else {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ag *Agent) FastMakeVideo(prompt, size, refImage string) ([]string, []string, error) {
|
|
||||||
return ag.MakeVideo(ModelCogVideoX, prompt, size, refImage)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ag *Agent) BestMakeVideo(prompt, size, refImage string) ([]string, []string, error) {
|
|
||||||
return ag.MakeVideo(ModelCogVideoX, prompt, size, refImage)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ag *Agent) MakeVideo(model, prompt, size, refImage string) ([]string, []string, error) {
|
|
||||||
c, err := zhipu.NewClient(zhipu.WithAPIKey(ag.config.ApiKey), zhipu.WithBaseURL(ag.config.Endpoint))
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
cc := c.VideoGeneration(model).SetPrompt(prompt)
|
|
||||||
if refImage != "" {
|
|
||||||
cc.SetImageURL(refImage)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp, err := cc.Do(context.Background()); err == nil {
|
|
||||||
for i := 0; i < 1200; i++ {
|
|
||||||
r, err := c.AsyncResult(resp.ID).Do(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
if r.TaskStatus == zhipu.VideoGenerationTaskStatusSuccess {
|
|
||||||
covers := make([]string, 0)
|
|
||||||
results := make([]string, 0)
|
|
||||||
for _, item := range r.VideoResult {
|
|
||||||
results = append(results, item.URL)
|
|
||||||
covers = append(covers, item.CoverImageURL)
|
|
||||||
}
|
|
||||||
return results, covers, nil
|
|
||||||
}
|
|
||||||
if r.TaskStatus == zhipu.VideoGenerationTaskStatusFail {
|
|
||||||
return nil, nil, errors.New("fail on task " + resp.ID)
|
|
||||||
}
|
|
||||||
time.Sleep(3 * time.Second)
|
|
||||||
}
|
|
||||||
return nil, nil, errors.New("timeout on task " + resp.ID)
|
|
||||||
} else {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
}
|
|
133
chat.go
133
chat.go
@ -1,70 +1,72 @@
|
|||||||
package zhipu
|
package zhipu
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"apigo.cc/ai/agent"
|
"apigo.cc/ai/llm/llm"
|
||||||
"apigo.cc/ai/zhipu/zhipu"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"github.com/ssgo/u"
|
||||||
|
"github.com/yankeguo/zhipu"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (ag *Agent) FastAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) {
|
func (lm *LLM) FastAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
|
||||||
return ag.Ask(messages, &agent.ChatModelConfig{
|
return lm.Ask(messages, llm.ChatConfig{
|
||||||
Model: ModelGLM4Flash,
|
Model: ModelGLM4Flash,
|
||||||
}, callback)
|
}, callback)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ag *Agent) LongAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) {
|
func (lm *LLM) LongAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
|
||||||
return ag.Ask(messages, &agent.ChatModelConfig{
|
return lm.Ask(messages, llm.ChatConfig{
|
||||||
Model: ModelGLM4Long,
|
Model: ModelGLM4Long,
|
||||||
}, callback)
|
}, callback)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ag *Agent) BatterAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) {
|
func (lm *LLM) BatterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
|
||||||
return ag.Ask(messages, &agent.ChatModelConfig{
|
return lm.Ask(messages, llm.ChatConfig{
|
||||||
Model: ModelGLM4Plus,
|
Model: ModelGLM4Plus,
|
||||||
}, callback)
|
}, callback)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ag *Agent) BestAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) {
|
func (lm *LLM) BestAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
|
||||||
return ag.Ask(messages, &agent.ChatModelConfig{
|
return lm.Ask(messages, llm.ChatConfig{
|
||||||
Model: ModelGLM40520,
|
Model: ModelGLM40520,
|
||||||
}, callback)
|
}, callback)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ag *Agent) MultiAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) {
|
func (lm *LLM) MultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
|
||||||
return ag.Ask(messages, &agent.ChatModelConfig{
|
return lm.Ask(messages, llm.ChatConfig{
|
||||||
Model: ModelGLM4VPlus,
|
Model: ModelGLM4VPlus,
|
||||||
}, callback)
|
}, callback)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ag *Agent) BestMultiAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) {
|
func (lm *LLM) BestMultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
|
||||||
return ag.Ask(messages, &agent.ChatModelConfig{
|
return lm.Ask(messages, llm.ChatConfig{
|
||||||
Model: ModelGLM4V,
|
Model: ModelGLM4V,
|
||||||
}, callback)
|
}, callback)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ag *Agent) CodeInterpreterAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) {
|
func (lm *LLM) CodeInterpreterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
|
||||||
return ag.Ask(messages, &agent.ChatModelConfig{
|
return lm.Ask(messages, llm.ChatConfig{
|
||||||
Model: ModelGLM4AllTools,
|
Model: ModelGLM4AllTools,
|
||||||
Tools: map[string]any{agent.ToolCodeInterpreter: nil},
|
Tools: map[string]any{llm.ToolCodeInterpreter: nil},
|
||||||
}, callback)
|
}, callback)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ag *Agent) WebSearchAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) {
|
func (lm *LLM) WebSearchAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
|
||||||
return ag.Ask(messages, &agent.ChatModelConfig{
|
return lm.Ask(messages, llm.ChatConfig{
|
||||||
Model: ModelGLM4AllTools,
|
Model: ModelGLM4AllTools,
|
||||||
Tools: map[string]any{agent.ToolWebSearch: nil},
|
Tools: map[string]any{llm.ToolWebSearch: nil},
|
||||||
}, callback)
|
}, callback)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ag *Agent) Ask(messages []agent.ChatMessage, config *agent.ChatModelConfig, callback func(answer string)) (string, agent.TokenUsage, error) {
|
func (lm *LLM) Ask(messages []llm.ChatMessage, config llm.ChatConfig, callback func(answer string)) (string, llm.Usage, error) {
|
||||||
if config == nil {
|
config.SetDefault(&lm.config.ChatConfig)
|
||||||
config = &agent.ChatModelConfig{}
|
c, err := zhipu.NewClient(zhipu.WithAPIKey(lm.config.ApiKey), zhipu.WithBaseURL(lm.config.Endpoint))
|
||||||
}
|
|
||||||
config.SetDefault(&ag.config.DefaultChatModelConfig)
|
|
||||||
c, err := zhipu.NewClient(zhipu.WithAPIKey(ag.config.ApiKey), zhipu.WithBaseURL(ag.config.Endpoint))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", agent.TokenUsage{}, err
|
return "", llm.Usage{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
cc := c.ChatCompletion(config.GetModel())
|
cc := c.ChatCompletion(config.GetModel())
|
||||||
@ -76,27 +78,34 @@ func (ag *Agent) Ask(messages []agent.ChatMessage, config *agent.ChatModelConfig
|
|||||||
part := zhipu.ChatCompletionMultiContent{}
|
part := zhipu.ChatCompletionMultiContent{}
|
||||||
part.Type = NameMap[inPart.Type]
|
part.Type = NameMap[inPart.Type]
|
||||||
switch inPart.Type {
|
switch inPart.Type {
|
||||||
case agent.TypeText:
|
case llm.TypeText:
|
||||||
part.Text = inPart.Content
|
part.Text = inPart.Content
|
||||||
case agent.TypeImage:
|
case llm.TypeImage:
|
||||||
part.ImageURL = &zhipu.URLItem{URL: inPart.Content}
|
part.ImageURL = &zhipu.URLItem{URL: inPart.Content}
|
||||||
case agent.TypeVideo:
|
//case llm.TypeVideo:
|
||||||
part.VideoURL = &zhipu.URLItem{URL: inPart.Content}
|
// part.VideoURL = &zhipu.URLItem{URL: inPart.Content}
|
||||||
}
|
}
|
||||||
contents[j] = part
|
contents[j] = part
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cc.AddMessage(zhipu.ChatCompletionMultiMessage{
|
if len(contents) == 1 && contents[0].Type == llm.TypeText {
|
||||||
Role: NameMap[msg.Role],
|
cc.AddMessage(zhipu.ChatCompletionMessage{
|
||||||
Content: contents,
|
Role: NameMap[msg.Role],
|
||||||
})
|
Content: contents[0].Text,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
cc.AddMessage(zhipu.ChatCompletionMultiMessage{
|
||||||
|
Role: NameMap[msg.Role],
|
||||||
|
Content: contents,
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for name := range config.GetTools() {
|
for name := range config.GetTools() {
|
||||||
switch name {
|
switch name {
|
||||||
case agent.ToolCodeInterpreter:
|
case llm.ToolCodeInterpreter:
|
||||||
cc.AddTool(zhipu.ChatCompletionToolCodeInterpreter{})
|
cc.AddTool(zhipu.ChatCompletionToolCodeInterpreter{})
|
||||||
case agent.ToolWebSearch:
|
case llm.ToolWebSearch:
|
||||||
cc.AddTool(zhipu.ChatCompletionToolWebBrowser{})
|
cc.AddTool(zhipu.ChatCompletionToolWebBrowser{})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -122,19 +131,65 @@ func (ag *Agent) Ask(messages []agent.ChatMessage, config *agent.ChatModelConfig
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if lm.config.Debug {
|
||||||
|
fmt.Println(cc.BatchMethod(), cc.BatchURL())
|
||||||
|
fmt.Println(u.JsonP(cc.BatchBody()))
|
||||||
|
}
|
||||||
|
|
||||||
|
t1 := time.Now().UnixMilli()
|
||||||
if r, err := cc.Do(context.Background()); err == nil {
|
if r, err := cc.Do(context.Background()); err == nil {
|
||||||
|
t2 := time.Now().UnixMilli() - t1
|
||||||
results := make([]string, 0)
|
results := make([]string, 0)
|
||||||
if r.Choices != nil {
|
if r.Choices != nil {
|
||||||
for _, ch := range r.Choices {
|
for _, ch := range r.Choices {
|
||||||
results = append(results, ch.Message.Content)
|
results = append(results, ch.Message.Content)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return strings.Join(results, ""), agent.TokenUsage{
|
return strings.Join(results, ""), llm.Usage{
|
||||||
AskTokens: r.Usage.PromptTokens,
|
AskTokens: r.Usage.PromptTokens,
|
||||||
AnswerTokens: r.Usage.CompletionTokens,
|
AnswerTokens: r.Usage.CompletionTokens,
|
||||||
TotalTokens: r.Usage.TotalTokens,
|
TotalTokens: r.Usage.TotalTokens,
|
||||||
|
UsedTime: t2,
|
||||||
}, nil
|
}, nil
|
||||||
} else {
|
} else {
|
||||||
return "", agent.TokenUsage{}, err
|
return "", llm.Usage{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lm *LLM) FastEmbedding(text string) ([]byte, llm.Usage, error) {
|
||||||
|
return lm.Embedding(text, ModelEmbedding3)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lm *LLM) BestEmbedding(text string) ([]byte, llm.Usage, error) {
|
||||||
|
return lm.Embedding(text, ModelEmbedding3)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lm *LLM) Embedding(text, model string) ([]byte, llm.Usage, error) {
|
||||||
|
c, err := zhipu.NewClient(zhipu.WithAPIKey(lm.config.ApiKey), zhipu.WithBaseURL(lm.config.Endpoint))
|
||||||
|
if err != nil {
|
||||||
|
return nil, llm.Usage{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cc := c.Embedding(model)
|
||||||
|
cc.SetInput(text)
|
||||||
|
t1 := time.Now().UnixMilli()
|
||||||
|
if r, err := cc.Do(context.Background()); err == nil {
|
||||||
|
t2 := time.Now().UnixMilli() - t1
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
if r.Data != nil {
|
||||||
|
for _, ch := range r.Data {
|
||||||
|
for _, v := range ch.Embedding {
|
||||||
|
_ = binary.Write(buf, binary.LittleEndian, float32(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return buf.Bytes(), llm.Usage{
|
||||||
|
AskTokens: r.Usage.PromptTokens,
|
||||||
|
AnswerTokens: r.Usage.CompletionTokens,
|
||||||
|
TotalTokens: r.Usage.TotalTokens,
|
||||||
|
UsedTime: t2,
|
||||||
|
}, nil
|
||||||
|
} else {
|
||||||
|
return nil, llm.Usage{}, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
30
config.go
30
config.go
@ -1,22 +1,22 @@
|
|||||||
package zhipu
|
package zhipu
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"apigo.cc/ai/agent"
|
"apigo.cc/ai/llm/llm"
|
||||||
"apigo.cc/ai/zhipu/zhipu"
|
"github.com/yankeguo/zhipu"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Agent struct {
|
type LLM struct {
|
||||||
config agent.APIConfig
|
config llm.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
var NameMap = map[string]string{
|
var NameMap = map[string]string{
|
||||||
agent.TypeText: zhipu.MultiContentTypeText,
|
llm.TypeText: zhipu.MultiContentTypeText,
|
||||||
agent.TypeImage: zhipu.MultiContentTypeImageURL,
|
llm.TypeImage: zhipu.MultiContentTypeImageURL,
|
||||||
agent.TypeVideo: zhipu.MultiContentTypeVideoURL,
|
//llm.TypeVideo: zhipu.MultiContentTypeVideoURL,
|
||||||
agent.RoleSystem: zhipu.RoleSystem,
|
llm.RoleSystem: zhipu.RoleSystem,
|
||||||
agent.RoleUser: zhipu.RoleUser,
|
llm.RoleUser: zhipu.RoleUser,
|
||||||
agent.RoleAssistant: zhipu.RoleAssistant,
|
llm.RoleAssistant: zhipu.RoleAssistant,
|
||||||
agent.RoleTool: zhipu.RoleTool,
|
llm.RoleTool: zhipu.RoleTool,
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -40,8 +40,8 @@ const (
|
|||||||
ModelCodeGeeX4 = "CodeGeeX-4"
|
ModelCodeGeeX4 = "CodeGeeX-4"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (ag *Agent) Support() agent.Support {
|
func (lm *LLM) Support() llm.Support {
|
||||||
return agent.Support{
|
return llm.Support{
|
||||||
Ask: true,
|
Ask: true,
|
||||||
AskWithImage: true,
|
AskWithImage: true,
|
||||||
AskWithVideo: true,
|
AskWithVideo: true,
|
||||||
@ -54,7 +54,7 @@ func (ag *Agent) Support() agent.Support {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
agent.RegisterAgentMaker("zhipu", func(config agent.APIConfig) agent.Agent {
|
llm.Register("zhipu", func(config llm.Config) llm.LLM {
|
||||||
return &Agent{config: config}
|
return &LLM{config: config}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
94
gc.go
Normal file
94
gc.go
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
package zhipu
|
||||||
|
|
||||||
|
import (
|
||||||
|
"apigo.cc/ai/llm/llm"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"github.com/yankeguo/zhipu"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (lm *LLM) FastMakeImage(prompt string, config llm.GCConfig) ([]string, llm.Usage, error) {
|
||||||
|
config.Model = ModelCogView3Plus
|
||||||
|
return lm.MakeImage(prompt, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lm *LLM) BestMakeImage(prompt string, config llm.GCConfig) ([]string, llm.Usage, error) {
|
||||||
|
config.Model = ModelCogView3
|
||||||
|
return lm.MakeImage(prompt, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lm *LLM) MakeImage(prompt string, config llm.GCConfig) ([]string, llm.Usage, error) {
|
||||||
|
c, err := zhipu.NewClient(zhipu.WithAPIKey(lm.config.ApiKey), zhipu.WithBaseURL(lm.config.Endpoint))
|
||||||
|
if err != nil {
|
||||||
|
return nil, llm.Usage{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
config.SetDefault(&lm.config.GCConfig)
|
||||||
|
cc := c.ImageGeneration(config.Model).SetPrompt(prompt)
|
||||||
|
//cc.SetSize(config.GetSize())
|
||||||
|
|
||||||
|
t1 := time.Now().UnixMilli()
|
||||||
|
if r, err := cc.Do(context.Background()); err == nil {
|
||||||
|
t2 := time.Now().UnixMilli() - t1
|
||||||
|
results := make([]string, 0)
|
||||||
|
for _, item := range r.Data {
|
||||||
|
results = append(results, item.URL)
|
||||||
|
}
|
||||||
|
return results, llm.Usage{
|
||||||
|
UsedTime: t2,
|
||||||
|
}, nil
|
||||||
|
} else {
|
||||||
|
return nil, llm.Usage{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lm *LLM) FastMakeVideo(prompt string, config llm.GCConfig) ([]string, []string, llm.Usage, error) {
|
||||||
|
config.Model = ModelCogVideoX
|
||||||
|
return lm.MakeVideo(prompt, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lm *LLM) BestMakeVideo(prompt string, config llm.GCConfig) ([]string, []string, llm.Usage, error) {
|
||||||
|
config.Model = ModelCogVideoX
|
||||||
|
return lm.MakeVideo(prompt, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lm *LLM) MakeVideo(prompt string, config llm.GCConfig) ([]string, []string, llm.Usage, error) {
|
||||||
|
c, err := zhipu.NewClient(zhipu.WithAPIKey(lm.config.ApiKey), zhipu.WithBaseURL(lm.config.Endpoint))
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, llm.Usage{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
config.SetDefault(&lm.config.GCConfig)
|
||||||
|
cc := c.VideoGeneration(config.Model).SetPrompt(prompt)
|
||||||
|
cc.SetImageURL(config.GetRef())
|
||||||
|
|
||||||
|
t1 := time.Now().UnixMilli()
|
||||||
|
if resp, err := cc.Do(context.Background()); err == nil {
|
||||||
|
t2 := time.Now().UnixMilli() - t1
|
||||||
|
for i := 0; i < 1200; i++ {
|
||||||
|
r, err := c.AsyncResult(resp.ID).Do(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, llm.Usage{}, err
|
||||||
|
}
|
||||||
|
if r.TaskStatus == zhipu.VideoGenerationTaskStatusSuccess {
|
||||||
|
covers := make([]string, 0)
|
||||||
|
results := make([]string, 0)
|
||||||
|
for _, item := range r.VideoResult {
|
||||||
|
results = append(results, item.URL)
|
||||||
|
covers = append(covers, item.CoverImageURL)
|
||||||
|
}
|
||||||
|
return results, covers, llm.Usage{
|
||||||
|
UsedTime: t2,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
if r.TaskStatus == zhipu.VideoGenerationTaskStatusFail {
|
||||||
|
return nil, nil, llm.Usage{}, errors.New("fail on task " + resp.ID)
|
||||||
|
}
|
||||||
|
time.Sleep(3 * time.Second)
|
||||||
|
}
|
||||||
|
return nil, nil, llm.Usage{}, errors.New("timeout on task " + resp.ID)
|
||||||
|
} else {
|
||||||
|
return nil, nil, llm.Usage{}, err
|
||||||
|
}
|
||||||
|
}
|
13
go.mod
13
go.mod
@ -3,15 +3,14 @@ module apigo.cc/ai/zhipu
|
|||||||
go 1.22
|
go 1.22
|
||||||
|
|
||||||
require (
|
require (
|
||||||
apigo.cc/ai/agent v0.0.1
|
apigo.cc/ai/llm v0.0.4
|
||||||
github.com/go-resty/resty/v2 v2.14.0
|
github.com/ssgo/u v1.7.9
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.1
|
github.com/yankeguo/zhipu v0.1.2
|
||||||
github.com/stretchr/testify v1.9.0
|
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/go-resty/resty/v2 v2.14.0 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/golang-jwt/jwt/v5 v5.2.1 // indirect
|
||||||
golang.org/x/net v0.29.0 // indirect
|
golang.org/x/net v0.30.0 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
)
|
)
|
||||||
|
@ -1,108 +0,0 @@
|
|||||||
# Changelog
|
|
||||||
All notable changes to this project will be documented in this file. See [conventional commits](https://www.conventionalcommits.org/) for commit guidelines.
|
|
||||||
|
|
||||||
- - -
|
|
||||||
## v0.1.2 - 2024-08-15
|
|
||||||
#### Bug Fixes
|
|
||||||
- add FinishReasonStopSequence - (01b4201) - GUO YANKE
|
|
||||||
#### Documentation
|
|
||||||
- update README.md [skip ci] - (e48a88b) - GUO YANKE
|
|
||||||
#### Features
|
|
||||||
- add videos/generations - (7261999) - GUO YANKE
|
|
||||||
#### Miscellaneous Chores
|
|
||||||
- relaxing go version to 1.18 - (6acc17c) - GUO YANKE
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
## v0.1.1 - 2024-07-17
|
|
||||||
#### Documentation
|
|
||||||
- update README.md [skip ci] - (695432a) - GUO YANKE
|
|
||||||
#### Features
|
|
||||||
- add support for GLM-4-AllTools - (9627a36) - GUO YANKE
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
## v0.1.0 - 2024-06-28
|
|
||||||
#### Bug Fixes
|
|
||||||
- rename client function for batch list - (40ac05f) - GUO YANKE
|
|
||||||
#### Documentation
|
|
||||||
- update README.md [skip ci] - (6ce5754) - GUO YANKE
|
|
||||||
#### Features
|
|
||||||
- add knowledge capacity service - (4ce62b3) - GUO YANKE
|
|
||||||
#### Refactoring
|
|
||||||
- update batch service - (b92d438) - GUO YANKE
|
|
||||||
- update chat completion service - (19dd77f) - GUO YANKE
|
|
||||||
- update embedding service - (c1bbc2d) - GUO YANKE
|
|
||||||
- update file services - (7ef4d87) - GUO YANKE
|
|
||||||
- update fine tune services, using APIError - (15aed88) - GUO YANKE
|
|
||||||
- update fine tune services - (664523b) - GUO YANKE
|
|
||||||
- update image generation service - (a18e028) - GUO YANKE
|
|
||||||
- update knowledge services - (c7bfb73) - GUO YANKE
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
## v0.0.6 - 2024-06-28
|
|
||||||
#### Features
|
|
||||||
- add batch support for result reader - (c062095) - GUO YANKE
|
|
||||||
- add fine tune services - (f172f51) - GUO YANKE
|
|
||||||
- add knowledge service - (09792b5) - GUO YANKE
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
## v0.0.5 - 2024-06-28
|
|
||||||
#### Bug Fixes
|
|
||||||
- api error parsing - (60a17f4) - GUO YANKE
|
|
||||||
#### Features
|
|
||||||
- add batch service - (389aec3) - GUO YANKE
|
|
||||||
- add batch support for chat completions, image generations and embeddings - (c017ffd) - GUO YANKE
|
|
||||||
- add file edit/get/delete service - (8a4d309) - GUO YANKE
|
|
||||||
- add file create serivce - (6d2140b) - GUO YANKE
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
## v0.0.4 - 2024-06-26
|
|
||||||
#### Bug Fixes
|
|
||||||
- remove Client.R(), hide resty for future removal - (dc2a4ca) - GUO YANKE
|
|
||||||
#### Features
|
|
||||||
- add meta support for charglm - (fdd20e7) - GUO YANKE
|
|
||||||
- add client option to custom http client - (c62d6a9) - GUO YANKE
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
## v0.0.3 - 2024-06-26
|
|
||||||
#### Features
|
|
||||||
- add image generation service - (9f3f54f) - GUO YANKE
|
|
||||||
- add support for vision models - (2dcd82a) - GUO YANKE
|
|
||||||
- add embedding service - (f57806a) - GUO YANKE
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
## v0.0.2 - 2024-06-26
|
|
||||||
#### Bug Fixes
|
|
||||||
- **(deps)** update golang-jwt/jwt to v5 - (2f76a57) - GUO YANKE
|
|
||||||
#### Features
|
|
||||||
- add constants for roles - (3d08a72) - GUO YANKE
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
## v0.0.1 - 2024-06-26
|
|
||||||
#### Bug Fixes
|
|
||||||
- add json tag "omitempty" to various types - (bf81097) - GUO YANKE
|
|
||||||
#### Continuous Integration
|
|
||||||
- add github action workflows for testing - (5a64987) - GUO YANKE
|
|
||||||
#### Documentation
|
|
||||||
- update README.md [skip ci] - (d504f57) - GUO YANKE
|
|
||||||
#### Features
|
|
||||||
- add chat completion in stream mode - (130fe1d) - GUO YANKE
|
|
||||||
- add chat completion in non-stream mode - (2326e37) - GUO YANKE
|
|
||||||
- support debug option while creating client - (0f104d8) - GUO YANKE
|
|
||||||
- add APIError and APIErrorResponse - (1886d85) - GUO YANKE
|
|
||||||
- add client struct - (710d8e8) - GUO YANKE
|
|
||||||
#### Refactoring
|
|
||||||
- change signature of Client#createJWT since there is no reason to fail - (f0d7887) - GUO YANKE
|
|
||||||
#### Tests
|
|
||||||
- add client_test.go - (a3fc217) - GUO YANKE
|
|
||||||
|
|
||||||
- - -
|
|
||||||
|
|
||||||
Changelog generated by [cocogitto](https://github.com/cocogitto/cocogitto).
|
|
@ -1,21 +0,0 @@
|
|||||||
MIT License
|
|
||||||
|
|
||||||
Copyright (c) 2024 Yanke G.
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
|
||||||
in the Software without restriction, including without limitation the rights
|
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
|
||||||
furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
SOFTWARE.
|
|
280
zhipu/README.md
280
zhipu/README.md
@ -1,280 +0,0 @@
|
|||||||
# zhipu
|
|
||||||
|
|
||||||
[![Go Reference](https://pkg.go.dev/badge/github.com/yankeguo/zhipu.svg)](https://pkg.go.dev/github.com/yankeguo/zhipu)
|
|
||||||
[![go](https://github.com/yankeguo/zhipu/actions/workflows/go.yml/badge.svg)](https://github.com/yankeguo/zhipu/actions/workflows/go.yml)
|
|
||||||
|
|
||||||
[中文文档](README.zh.md)
|
|
||||||
|
|
||||||
A 3rd-Party Golang Client Library for Zhipu AI Platform
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
### Install the package
|
|
||||||
|
|
||||||
```bash
|
|
||||||
go get -u github.com/yankeguo/zhipu
|
|
||||||
```
|
|
||||||
|
|
||||||
### Create a client
|
|
||||||
|
|
||||||
```go
|
|
||||||
// this will use environment variables ZHIPUAI_API_KEY
|
|
||||||
client, err := zhipu.NewClient()
|
|
||||||
// or you can specify the API key
|
|
||||||
client, err = zhipu.NewClient(zhipu.WithAPIKey("your api key"))
|
|
||||||
```
|
|
||||||
|
|
||||||
### Use the client
|
|
||||||
|
|
||||||
**ChatCompletion**
|
|
||||||
|
|
||||||
```go
|
|
||||||
service := client.ChatCompletion("glm-4-flash").
|
|
||||||
AddMessage(zhipu.ChatCompletionMessage{
|
|
||||||
Role: "user",
|
|
||||||
Content: "你好",
|
|
||||||
})
|
|
||||||
|
|
||||||
res, err := service.Do(context.Background())
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
zhipu.GetAPIErrorCode(err) // get the API error code
|
|
||||||
} else {
|
|
||||||
println(res.Choices[0].Message.Content)
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
**ChatCompletion (Stream)**
|
|
||||||
|
|
||||||
```go
|
|
||||||
service := client.ChatCompletion("glm-4-flash").
|
|
||||||
AddMessage(zhipu.ChatCompletionMessage{
|
|
||||||
Role: "user",
|
|
||||||
Content: "你好",
|
|
||||||
}).SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
|
|
||||||
println(chunk.Choices[0].Delta.Content)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
res, err := service.Do(context.Background())
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
zhipu.GetAPIErrorCode(err) // get the API error code
|
|
||||||
} else {
|
|
||||||
// this package will combine the stream chunks and build a final result mimicking the non-streaming API
|
|
||||||
println(res.Choices[0].Message.Content)
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
**ChatCompletion (Stream with GLM-4-AllTools)**
|
|
||||||
|
|
||||||
```go
|
|
||||||
// CodeInterpreter
|
|
||||||
s := client.ChatCompletion("GLM-4-AllTools")
|
|
||||||
s.AddMessage(zhipu.ChatCompletionMultiMessage{
|
|
||||||
Role: "user",
|
|
||||||
Content: []zhipu.ChatCompletionMultiContent{
|
|
||||||
{
|
|
||||||
Type: "text",
|
|
||||||
Text: "计算[5,10,20,700,99,310,978,100]的平均值和方差。",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
s.AddTool(zhipu.ChatCompletionToolCodeInterpreter{
|
|
||||||
Sandbox: zhipu.Ptr(CodeInterpreterSandboxAuto),
|
|
||||||
})
|
|
||||||
s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
|
|
||||||
for _, c := range chunk.Choices {
|
|
||||||
for _, tc := range c.Delta.ToolCalls {
|
|
||||||
if tc.Type == ToolTypeCodeInterpreter && tc.CodeInterpreter != nil {
|
|
||||||
if tc.CodeInterpreter.Input != "" {
|
|
||||||
// DO SOMETHING
|
|
||||||
}
|
|
||||||
if len(tc.CodeInterpreter.Outputs) > 0 {
|
|
||||||
// DO SOMETHING
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
// WebBrowser
|
|
||||||
// CAUTION: NOT 'WebSearch'
|
|
||||||
s := client.ChatCompletion("GLM-4-AllTools")
|
|
||||||
s.AddMessage(zhipu.ChatCompletionMultiMessage{
|
|
||||||
Role: "user",
|
|
||||||
Content: []zhipu.ChatCompletionMultiContent{
|
|
||||||
{
|
|
||||||
Type: "text",
|
|
||||||
Text: "搜索下本周深圳天气如何",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
s.AddTool(zhipu.ChatCompletionToolWebBrowser{})
|
|
||||||
s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
|
|
||||||
for _, c := range chunk.Choices {
|
|
||||||
for _, tc := range c.Delta.ToolCalls {
|
|
||||||
if tc.Type == ToolTypeWebBrowser && tc.WebBrowser != nil {
|
|
||||||
if tc.WebBrowser.Input != "" {
|
|
||||||
// DO SOMETHING
|
|
||||||
}
|
|
||||||
if len(tc.WebBrowser.Outputs) > 0 {
|
|
||||||
// DO SOMETHING
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
s.Do(context.Background())
|
|
||||||
|
|
||||||
// DrawingTool
|
|
||||||
s := client.ChatCompletion("GLM-4-AllTools")
|
|
||||||
s.AddMessage(zhipu.ChatCompletionMultiMessage{
|
|
||||||
Role: "user",
|
|
||||||
Content: []zhipu.ChatCompletionMultiContent{
|
|
||||||
{
|
|
||||||
Type: "text",
|
|
||||||
Text: "画一个正弦函数图像",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
s.AddTool(zhipu.ChatCompletionToolDrawingTool{})
|
|
||||||
s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
|
|
||||||
for _, c := range chunk.Choices {
|
|
||||||
for _, tc := range c.Delta.ToolCalls {
|
|
||||||
if tc.Type == ToolTypeDrawingTool && tc.DrawingTool != nil {
|
|
||||||
if tc.DrawingTool.Input != "" {
|
|
||||||
// DO SOMETHING
|
|
||||||
}
|
|
||||||
if len(tc.DrawingTool.Outputs) > 0 {
|
|
||||||
// DO SOMETHING
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
s.Do(context.Background())
|
|
||||||
```
|
|
||||||
|
|
||||||
**Embedding**
|
|
||||||
|
|
||||||
```go
|
|
||||||
service := client.Embedding("embedding-v2").SetInput("你好呀")
|
|
||||||
service.Do(context.Background())
|
|
||||||
```
|
|
||||||
|
|
||||||
**Image Generation**
|
|
||||||
|
|
||||||
```go
|
|
||||||
service := client.ImageGeneration("cogview-3").SetPrompt("一只可爱的小猫咪")
|
|
||||||
service.Do(context.Background())
|
|
||||||
```
|
|
||||||
|
|
||||||
**Video Generation**
|
|
||||||
|
|
||||||
```go
|
|
||||||
service := client.VideoGeneration("cogvideox").SetPrompt("一只可爱的小猫咪")
|
|
||||||
resp, err := service.Do(context.Background())
|
|
||||||
|
|
||||||
for {
|
|
||||||
result, err := client.AsyncResult(resp.ID).Do(context.Background())
|
|
||||||
|
|
||||||
if result.TaskStatus == zhipu.VideoGenerationTaskStatusSuccess {
|
|
||||||
_ = result.VideoResult[0].URL
|
|
||||||
_ = result.VideoResult[0].CoverImageURL
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.TaskStatus != zhipu.VideoGenerationTaskStatusProcessing {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
**Upload File (Retrieval)**
|
|
||||||
|
|
||||||
```go
|
|
||||||
service := client.FileCreate(zhipu.FilePurposeRetrieval)
|
|
||||||
service.SetLocalFile(filepath.Join("testdata", "test-file.txt"))
|
|
||||||
service.SetKnowledgeID("your-knowledge-id")
|
|
||||||
|
|
||||||
service.Do(context.Background())
|
|
||||||
```
|
|
||||||
|
|
||||||
**Upload File (Fine-Tune)**
|
|
||||||
|
|
||||||
```go
|
|
||||||
service := client.FileCreate(zhipu.FilePurposeFineTune)
|
|
||||||
service.SetLocalFile(filepath.Join("testdata", "test-file.jsonl"))
|
|
||||||
service.Do(context.Background())
|
|
||||||
```
|
|
||||||
|
|
||||||
**Batch Create**
|
|
||||||
|
|
||||||
```go
|
|
||||||
service := client.BatchCreate().
|
|
||||||
SetInputFileID("fileid").
|
|
||||||
SetCompletionWindow(zhipu.BatchCompletionWindow24h).
|
|
||||||
SetEndpoint(BatchEndpointV4ChatCompletions)
|
|
||||||
service.Do(context.Background())
|
|
||||||
```
|
|
||||||
|
|
||||||
**Knowledge Base**
|
|
||||||
|
|
||||||
```go
|
|
||||||
client.KnowledgeCreate("")
|
|
||||||
client.KnowledgeEdit("")
|
|
||||||
```
|
|
||||||
|
|
||||||
**Fine Tune**
|
|
||||||
|
|
||||||
```go
|
|
||||||
client.FineTuneCreate("")
|
|
||||||
```
|
|
||||||
|
|
||||||
### Batch Support
|
|
||||||
|
|
||||||
**Batch File Writer**
|
|
||||||
|
|
||||||
```go
|
|
||||||
f, err := os.OpenFile("batch.jsonl", os.O_CREATE|os.O_WRONLY, 0644)
|
|
||||||
|
|
||||||
bw := zhipu.NewBatchFileWriter(f)
|
|
||||||
|
|
||||||
bw.Add("action_1", client.ChatCompletion("glm-4-flash").
|
|
||||||
AddMessage(zhipu.ChatCompletionMessage{
|
|
||||||
Role: "user",
|
|
||||||
Content: "你好",
|
|
||||||
}))
|
|
||||||
bw.Add("action_2", client.Embedding("embedding-v2").SetInput("你好呀"))
|
|
||||||
bw.Add("action_3", client.ImageGeneration("cogview-3").SetPrompt("一只可爱的小猫咪"))
|
|
||||||
```
|
|
||||||
|
|
||||||
**Batch Result Reader**
|
|
||||||
|
|
||||||
```go
|
|
||||||
br := zhipu.NewBatchResultReader[zhipu.ChatCompletionResponse](r)
|
|
||||||
|
|
||||||
for {
|
|
||||||
var res zhipu.BatchResult[zhipu.ChatCompletionResponse]
|
|
||||||
err := br.Read(&res)
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Donation
|
|
||||||
|
|
||||||
Executing unit tests will actually call the ChatGLM API and consume my quota. Please donate and thank you for your support!
|
|
||||||
|
|
||||||
<img src="./wechat-donation.png" width="180"/>
|
|
||||||
|
|
||||||
## Credits
|
|
||||||
|
|
||||||
GUO YANKE, MIT License
|
|
@ -1,278 +0,0 @@
|
|||||||
# zhipu
|
|
||||||
|
|
||||||
[![Go Reference](https://pkg.go.dev/badge/github.com/yankeguo/zhipu.svg)](https://pkg.go.dev/github.com/yankeguo/zhipu)
|
|
||||||
[![go](https://github.com/yankeguo/zhipu/actions/workflows/go.yml/badge.svg)](https://github.com/yankeguo/zhipu/actions/workflows/go.yml)
|
|
||||||
|
|
||||||
Zhipu AI 平台第三方 Golang 客户端库
|
|
||||||
|
|
||||||
## 用法
|
|
||||||
|
|
||||||
### 安装库
|
|
||||||
|
|
||||||
```bash
|
|
||||||
go get -u github.com/yankeguo/zhipu
|
|
||||||
```
|
|
||||||
|
|
||||||
### 创建客户端
|
|
||||||
|
|
||||||
```go
|
|
||||||
// 默认使用环境变量 ZHIPUAI_API_KEY
|
|
||||||
client, err := zhipu.NewClient()
|
|
||||||
// 或者手动指定密钥
|
|
||||||
client, err = zhipu.NewClient(zhipu.WithAPIKey("your api key"))
|
|
||||||
```
|
|
||||||
|
|
||||||
### 使用客户端
|
|
||||||
|
|
||||||
**ChatCompletion(大语言模型)**
|
|
||||||
|
|
||||||
```go
|
|
||||||
service := client.ChatCompletion("glm-4-flash").
|
|
||||||
AddMessage(zhipu.ChatCompletionMessage{
|
|
||||||
Role: "user",
|
|
||||||
Content: "你好",
|
|
||||||
})
|
|
||||||
|
|
||||||
res, err := service.Do(context.Background())
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
zhipu.GetAPIErrorCode(err) // get the API error code
|
|
||||||
} else {
|
|
||||||
println(res.Choices[0].Message.Content)
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
**ChatCompletion(流式调用大语言模型)**
|
|
||||||
|
|
||||||
```go
|
|
||||||
service := client.ChatCompletion("glm-4-flash").
|
|
||||||
AddMessage(zhipu.ChatCompletionMessage{
|
|
||||||
Role: "user",
|
|
||||||
Content: "你好",
|
|
||||||
}).SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
|
|
||||||
println(chunk.Choices[0].Delta.Content)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
res, err := service.Do(context.Background())
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
zhipu.GetAPIErrorCode(err) // get the API error code
|
|
||||||
} else {
|
|
||||||
// this package will combine the stream chunks and build a final result mimicking the non-streaming API
|
|
||||||
println(res.Choices[0].Message.Content)
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
**ChatCompletion(流式调用大语言工具模型GLM-4-AllTools)**
|
|
||||||
|
|
||||||
```go
|
|
||||||
// CodeInterpreter
|
|
||||||
s := client.ChatCompletion("GLM-4-AllTools")
|
|
||||||
s.AddMessage(zhipu.ChatCompletionMultiMessage{
|
|
||||||
Role: "user",
|
|
||||||
Content: []zhipu.ChatCompletionMultiContent{
|
|
||||||
{
|
|
||||||
Type: "text",
|
|
||||||
Text: "计算[5,10,20,700,99,310,978,100]的平均值和方差。",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
s.AddTool(zhipu.ChatCompletionToolCodeInterpreter{
|
|
||||||
Sandbox: zhipu.Ptr(CodeInterpreterSandboxAuto),
|
|
||||||
})
|
|
||||||
s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
|
|
||||||
for _, c := range chunk.Choices {
|
|
||||||
for _, tc := range c.Delta.ToolCalls {
|
|
||||||
if tc.Type == ToolTypeCodeInterpreter && tc.CodeInterpreter != nil {
|
|
||||||
if tc.CodeInterpreter.Input != "" {
|
|
||||||
// DO SOMETHING
|
|
||||||
}
|
|
||||||
if len(tc.CodeInterpreter.Outputs) > 0 {
|
|
||||||
// DO SOMETHING
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
// WebBrowser
|
|
||||||
// CAUTION: NOT 'WebSearch'
|
|
||||||
s := client.ChatCompletion("GLM-4-AllTools")
|
|
||||||
s.AddMessage(zhipu.ChatCompletionMultiMessage{
|
|
||||||
Role: "user",
|
|
||||||
Content: []zhipu.ChatCompletionMultiContent{
|
|
||||||
{
|
|
||||||
Type: "text",
|
|
||||||
Text: "搜索下本周深圳天气如何",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
s.AddTool(zhipu.ChatCompletionToolWebBrowser{})
|
|
||||||
s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
|
|
||||||
for _, c := range chunk.Choices {
|
|
||||||
for _, tc := range c.Delta.ToolCalls {
|
|
||||||
if tc.Type == ToolTypeWebBrowser && tc.WebBrowser != nil {
|
|
||||||
if tc.WebBrowser.Input != "" {
|
|
||||||
// DO SOMETHING
|
|
||||||
}
|
|
||||||
if len(tc.WebBrowser.Outputs) > 0 {
|
|
||||||
// DO SOMETHING
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
s.Do(context.Background())
|
|
||||||
|
|
||||||
// DrawingTool
|
|
||||||
s := client.ChatCompletion("GLM-4-AllTools")
|
|
||||||
s.AddMessage(zhipu.ChatCompletionMultiMessage{
|
|
||||||
Role: "user",
|
|
||||||
Content: []zhipu.ChatCompletionMultiContent{
|
|
||||||
{
|
|
||||||
Type: "text",
|
|
||||||
Text: "画一个正弦函数图像",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
s.AddTool(zhipu.ChatCompletionToolDrawingTool{})
|
|
||||||
s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
|
|
||||||
for _, c := range chunk.Choices {
|
|
||||||
for _, tc := range c.Delta.ToolCalls {
|
|
||||||
if tc.Type == ToolTypeDrawingTool && tc.DrawingTool != nil {
|
|
||||||
if tc.DrawingTool.Input != "" {
|
|
||||||
// DO SOMETHING
|
|
||||||
}
|
|
||||||
if len(tc.DrawingTool.Outputs) > 0 {
|
|
||||||
// DO SOMETHING
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
s.Do(context.Background())
|
|
||||||
```
|
|
||||||
|
|
||||||
**Embedding**
|
|
||||||
|
|
||||||
```go
|
|
||||||
service := client.Embedding("embedding-v2").SetInput("你好呀")
|
|
||||||
service.Do(context.Background())
|
|
||||||
```
|
|
||||||
|
|
||||||
**ImageGeneration(图像生成)**
|
|
||||||
|
|
||||||
```go
|
|
||||||
service := client.ImageGeneration("cogview-3").SetPrompt("一只可爱的小猫咪")
|
|
||||||
service.Do(context.Background())
|
|
||||||
```
|
|
||||||
|
|
||||||
**VideoGeneration(视频生成)**
|
|
||||||
|
|
||||||
```go
|
|
||||||
service := client.VideoGeneration("cogvideox").SetPrompt("一只可爱的小猫咪")
|
|
||||||
resp, err := service.Do(context.Background())
|
|
||||||
|
|
||||||
for {
|
|
||||||
result, err := client.AsyncResult(resp.ID).Do(context.Background())
|
|
||||||
|
|
||||||
if result.TaskStatus == zhipu.VideoGenerationTaskStatusSuccess {
|
|
||||||
_ = result.VideoResult[0].URL
|
|
||||||
_ = result.VideoResult[0].CoverImageURL
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.TaskStatus != zhipu.VideoGenerationTaskStatusProcessing {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
**UploadFile(上传文件用于取回)**
|
|
||||||
|
|
||||||
```go
|
|
||||||
service := client.FileCreate(zhipu.FilePurposeRetrieval)
|
|
||||||
service.SetLocalFile(filepath.Join("testdata", "test-file.txt"))
|
|
||||||
service.SetKnowledgeID("your-knowledge-id")
|
|
||||||
|
|
||||||
service.Do(context.Background())
|
|
||||||
```
|
|
||||||
|
|
||||||
**UploadFile(上传文件用于微调)**
|
|
||||||
|
|
||||||
```go
|
|
||||||
service := client.FileCreate(zhipu.FilePurposeFineTune)
|
|
||||||
service.SetLocalFile(filepath.Join("testdata", "test-file.jsonl"))
|
|
||||||
service.Do(context.Background())
|
|
||||||
```
|
|
||||||
|
|
||||||
**BatchCreate(创建批量任务)**
|
|
||||||
|
|
||||||
```go
|
|
||||||
service := client.BatchCreate().
|
|
||||||
SetInputFileID("fileid").
|
|
||||||
SetCompletionWindow(zhipu.BatchCompletionWindow24h).
|
|
||||||
SetEndpoint(BatchEndpointV4ChatCompletions)
|
|
||||||
service.Do(context.Background())
|
|
||||||
```
|
|
||||||
|
|
||||||
**KnowledgeBase(知识库)**
|
|
||||||
|
|
||||||
```go
|
|
||||||
client.KnowledgeCreate("")
|
|
||||||
client.KnowledgeEdit("")
|
|
||||||
```
|
|
||||||
|
|
||||||
**FineTune(微调)**
|
|
||||||
|
|
||||||
```go
|
|
||||||
client.FineTuneCreate("")
|
|
||||||
```
|
|
||||||
|
|
||||||
### 批量任务辅助工具
|
|
||||||
|
|
||||||
**批量任务文件创建**
|
|
||||||
|
|
||||||
```go
|
|
||||||
f, err := os.OpenFile("batch.jsonl", os.O_CREATE|os.O_WRONLY, 0644)
|
|
||||||
|
|
||||||
bw := zhipu.NewBatchFileWriter(f)
|
|
||||||
|
|
||||||
bw.Add("action_1", client.ChatCompletion("glm-4-flash").
|
|
||||||
AddMessage(zhipu.ChatCompletionMessage{
|
|
||||||
Role: "user",
|
|
||||||
Content: "你好",
|
|
||||||
}))
|
|
||||||
bw.Add("action_2", client.Embedding("embedding-v2").SetInput("你好呀"))
|
|
||||||
bw.Add("action_3", client.ImageGeneration("cogview-3").SetPrompt("一只可爱的小猫咪"))
|
|
||||||
```
|
|
||||||
|
|
||||||
**批量任务结果解析**
|
|
||||||
|
|
||||||
```go
|
|
||||||
br := zhipu.NewBatchResultReader[zhipu.ChatCompletionResponse](r)
|
|
||||||
|
|
||||||
for {
|
|
||||||
var res zhipu.BatchResult[zhipu.ChatCompletionResponse]
|
|
||||||
err := br.Read(&res)
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## 赞助
|
|
||||||
|
|
||||||
执行单元测试会真实调用GLM接口,消耗我充值的额度,开发不易,请微信扫码捐赠,感谢您的支持!
|
|
||||||
|
|
||||||
<img src="./wechat-donation.png" width="180"/>
|
|
||||||
|
|
||||||
## 许可证
|
|
||||||
|
|
||||||
GUO YANKE, MIT License
|
|
@ -1,63 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/go-resty/resty/v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
// AsyncResultService creates a new async result get service
|
|
||||||
type AsyncResultService struct {
|
|
||||||
client *Client
|
|
||||||
|
|
||||||
id string
|
|
||||||
}
|
|
||||||
|
|
||||||
// AsyncResultVideo is the video result of the AsyncResultService
|
|
||||||
type AsyncResultVideo struct {
|
|
||||||
URL string `json:"url"`
|
|
||||||
CoverImageURL string `json:"cover_image_url"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// AsyncResultResponse is the response of the AsyncResultService
|
|
||||||
type AsyncResultResponse struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
TaskStatus string `json:"task_status"`
|
|
||||||
RequestID string `json:"request_id"`
|
|
||||||
ID string `json:"id"`
|
|
||||||
VideoResult []AsyncResultVideo `json:"video_result"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewAsyncResultService creates a new async result get service
|
|
||||||
func NewAsyncResultService(client *Client) *AsyncResultService {
|
|
||||||
return &AsyncResultService{
|
|
||||||
client: client,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetID sets the id parameter
|
|
||||||
func (s *AsyncResultService) SetID(id string) *AsyncResultService {
|
|
||||||
s.id = id
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *AsyncResultService) Do(ctx context.Context) (res AsyncResultResponse, err error) {
|
|
||||||
var (
|
|
||||||
resp *resty.Response
|
|
||||||
apiError APIErrorResponse
|
|
||||||
)
|
|
||||||
|
|
||||||
if resp, err = s.client.request(ctx).
|
|
||||||
SetResult(&res).
|
|
||||||
SetError(&apiError).
|
|
||||||
Get("async-result/" + s.id); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.IsError() {
|
|
||||||
err = apiError
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
@ -1 +0,0 @@
|
|||||||
package zhipu
|
|
258
zhipu/batch.go
258
zhipu/batch.go
@ -1,258 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/go-resty/resty/v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
BatchEndpointV4ChatCompletions = "/v4/chat/completions"
|
|
||||||
BatchEndpointV4ImagesGenerations = "/v4/images/generations"
|
|
||||||
BatchEndpointV4Embeddings = "/v4/embeddings"
|
|
||||||
BatchEndpointV4VideosGenerations = "/v4/videos/generations"
|
|
||||||
|
|
||||||
BatchCompletionWindow24h = "24h"
|
|
||||||
)
|
|
||||||
|
|
||||||
// BatchRequestCounts represents the counts of the batch requests.
|
|
||||||
type BatchRequestCounts struct {
|
|
||||||
Total int64 `json:"total"`
|
|
||||||
Completed int64 `json:"completed"`
|
|
||||||
Failed int64 `json:"failed"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchItem represents a batch item.
|
|
||||||
type BatchItem struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Object any `json:"object"`
|
|
||||||
Endpoint string `json:"endpoint"`
|
|
||||||
InputFileID string `json:"input_file_id"`
|
|
||||||
CompletionWindow string `json:"completion_window"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
OutputFileID string `json:"output_file_id"`
|
|
||||||
ErrorFileID string `json:"error_file_id"`
|
|
||||||
CreatedAt int64 `json:"created_at"`
|
|
||||||
InProgressAt int64 `json:"in_progress_at"`
|
|
||||||
ExpiresAt int64 `json:"expires_at"`
|
|
||||||
FinalizingAt int64 `json:"finalizing_at"`
|
|
||||||
CompletedAt int64 `json:"completed_at"`
|
|
||||||
FailedAt int64 `json:"failed_at"`
|
|
||||||
ExpiredAt int64 `json:"expired_at"`
|
|
||||||
CancellingAt int64 `json:"cancelling_at"`
|
|
||||||
CancelledAt int64 `json:"cancelled_at"`
|
|
||||||
RequestCounts BatchRequestCounts `json:"request_counts"`
|
|
||||||
Metadata json.RawMessage `json:"metadata"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchCreateService is a service to create a batch.
|
|
||||||
type BatchCreateService struct {
|
|
||||||
client *Client
|
|
||||||
|
|
||||||
inputFileID string
|
|
||||||
endpoint string
|
|
||||||
completionWindow string
|
|
||||||
metadata any
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewBatchCreateService creates a new BatchCreateService.
|
|
||||||
func NewBatchCreateService(client *Client) *BatchCreateService {
|
|
||||||
return &BatchCreateService{client: client}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetInputFileID sets the input file id for the batch.
|
|
||||||
func (s *BatchCreateService) SetInputFileID(inputFileID string) *BatchCreateService {
|
|
||||||
s.inputFileID = inputFileID
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetEndpoint sets the endpoint for the batch.
|
|
||||||
func (s *BatchCreateService) SetEndpoint(endpoint string) *BatchCreateService {
|
|
||||||
s.endpoint = endpoint
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetCompletionWindow sets the completion window for the batch.
|
|
||||||
func (s *BatchCreateService) SetCompletionWindow(window string) *BatchCreateService {
|
|
||||||
s.completionWindow = window
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetMetadata sets the metadata for the batch.
|
|
||||||
func (s *BatchCreateService) SetMetadata(metadata any) *BatchCreateService {
|
|
||||||
s.metadata = metadata
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do executes the batch create service.
|
|
||||||
func (s *BatchCreateService) Do(ctx context.Context) (res BatchItem, err error) {
|
|
||||||
var (
|
|
||||||
resp *resty.Response
|
|
||||||
apiError APIErrorResponse
|
|
||||||
)
|
|
||||||
|
|
||||||
if resp, err = s.client.request(ctx).
|
|
||||||
SetBody(M{
|
|
||||||
"input_file_id": s.inputFileID,
|
|
||||||
"endpoint": s.endpoint,
|
|
||||||
"completion_window": s.completionWindow,
|
|
||||||
"metadata": s.metadata,
|
|
||||||
}).
|
|
||||||
SetResult(&res).
|
|
||||||
SetError(&apiError).
|
|
||||||
Post("batches"); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.IsError() {
|
|
||||||
err = apiError
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchGetService is a service to get a batch.
|
|
||||||
type BatchGetService struct {
|
|
||||||
client *Client
|
|
||||||
batchID string
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchGetResponse represents the response of the batch get service.
|
|
||||||
type BatchGetResponse = BatchItem
|
|
||||||
|
|
||||||
// NewBatchGetService creates a new BatchGetService.
|
|
||||||
func NewBatchGetService(client *Client) *BatchGetService {
|
|
||||||
return &BatchGetService{client: client}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetBatchID sets the batch id for the batch get service.
|
|
||||||
func (s *BatchGetService) SetBatchID(batchID string) *BatchGetService {
|
|
||||||
s.batchID = batchID
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do executes the batch get service.
|
|
||||||
func (s *BatchGetService) Do(ctx context.Context) (res BatchGetResponse, err error) {
|
|
||||||
var (
|
|
||||||
resp *resty.Response
|
|
||||||
apiError APIErrorResponse
|
|
||||||
)
|
|
||||||
|
|
||||||
if resp, err = s.client.request(ctx).
|
|
||||||
SetPathParam("batch_id", s.batchID).
|
|
||||||
SetResult(&res).
|
|
||||||
SetError(&apiError).
|
|
||||||
Get("batches/{batch_id}"); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.IsError() {
|
|
||||||
err = apiError
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchCancelService is a service to cancel a batch.
|
|
||||||
type BatchCancelService struct {
|
|
||||||
client *Client
|
|
||||||
batchID string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewBatchCancelService creates a new BatchCancelService.
|
|
||||||
func NewBatchCancelService(client *Client) *BatchCancelService {
|
|
||||||
return &BatchCancelService{client: client}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetBatchID sets the batch id for the batch cancel service.
|
|
||||||
func (s *BatchCancelService) SetBatchID(batchID string) *BatchCancelService {
|
|
||||||
s.batchID = batchID
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do executes the batch cancel service.
|
|
||||||
func (s *BatchCancelService) Do(ctx context.Context) (err error) {
|
|
||||||
var (
|
|
||||||
resp *resty.Response
|
|
||||||
apiError APIErrorResponse
|
|
||||||
)
|
|
||||||
|
|
||||||
if resp, err = s.client.request(ctx).
|
|
||||||
SetPathParam("batch_id", s.batchID).
|
|
||||||
SetBody(M{}).
|
|
||||||
SetError(&apiError).
|
|
||||||
Post("batches/{batch_id}/cancel"); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.IsError() {
|
|
||||||
err = apiError
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchListService is a service to list batches.
|
|
||||||
type BatchListService struct {
|
|
||||||
client *Client
|
|
||||||
|
|
||||||
after *string
|
|
||||||
limit *int
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchListResponse represents the response of the batch list service.
|
|
||||||
type BatchListResponse struct {
|
|
||||||
Object string `json:"object"`
|
|
||||||
Data []BatchItem `json:"data"`
|
|
||||||
FirstID string `json:"first_id"`
|
|
||||||
LastID string `json:"last_id"`
|
|
||||||
HasMore bool `json:"has_more"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewBatchListService creates a new BatchListService.
|
|
||||||
func NewBatchListService(client *Client) *BatchListService {
|
|
||||||
return &BatchListService{client: client}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetAfter sets the after cursor for the batch list service.
|
|
||||||
func (s *BatchListService) SetAfter(after string) *BatchListService {
|
|
||||||
s.after = &after
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetLimit sets the limit for the batch list service.
|
|
||||||
func (s *BatchListService) SetLimit(limit int) *BatchListService {
|
|
||||||
s.limit = &limit
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do executes the batch list service.
|
|
||||||
func (s *BatchListService) Do(ctx context.Context) (res BatchListResponse, err error) {
|
|
||||||
var (
|
|
||||||
resp *resty.Response
|
|
||||||
apiError APIErrorResponse
|
|
||||||
)
|
|
||||||
|
|
||||||
req := s.client.request(ctx)
|
|
||||||
if s.after != nil {
|
|
||||||
req.SetQueryParam("after", *s.after)
|
|
||||||
}
|
|
||||||
if s.limit != nil {
|
|
||||||
req.SetQueryParam("limit", strconv.Itoa(*s.limit))
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp, err = req.
|
|
||||||
SetResult(&res).
|
|
||||||
SetError(&apiError).
|
|
||||||
Get("batches"); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.IsError() {
|
|
||||||
err = apiError
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
@ -1,63 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"io"
|
|
||||||
)
|
|
||||||
|
|
||||||
// BatchSupport is the interface for services with batch support.
|
|
||||||
type BatchSupport interface {
|
|
||||||
BatchMethod() string
|
|
||||||
BatchURL() string
|
|
||||||
BatchBody() any
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchFileWriter is a writer for batch files.
|
|
||||||
type BatchFileWriter struct {
|
|
||||||
w io.Writer
|
|
||||||
je *json.Encoder
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewBatchFileWriter creates a new BatchFileWriter.
|
|
||||||
func NewBatchFileWriter(w io.Writer) *BatchFileWriter {
|
|
||||||
return &BatchFileWriter{w: w, je: json.NewEncoder(w)}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write writes a batch file.
|
|
||||||
func (b *BatchFileWriter) Write(customID string, s BatchSupport) error {
|
|
||||||
return b.je.Encode(M{
|
|
||||||
"custom_id": customID,
|
|
||||||
"method": s.BatchMethod(),
|
|
||||||
"url": s.BatchURL(),
|
|
||||||
"body": s.BatchBody(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchResultResponse is the response of a batch result.
|
|
||||||
type BatchResultResponse[T any] struct {
|
|
||||||
StatusCode int `json:"status_code"`
|
|
||||||
Body T `json:"body"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchResult is the result of a batch.
|
|
||||||
type BatchResult[T any] struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
CustomID string `json:"custom_id"`
|
|
||||||
Response BatchResultResponse[T] `json:"response"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchResultReader reads batch results.
|
|
||||||
type BatchResultReader[T any] struct {
|
|
||||||
r io.Reader
|
|
||||||
jd *json.Decoder
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewBatchResultReader creates a new BatchResultReader.
|
|
||||||
func NewBatchResultReader[T any](r io.Reader) *BatchResultReader[T] {
|
|
||||||
return &BatchResultReader[T]{r: r, jd: json.NewDecoder(r)}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read reads a batch result.
|
|
||||||
func (r *BatchResultReader[T]) Read(out *BatchResult[T]) error {
|
|
||||||
return r.jd.Decode(out)
|
|
||||||
}
|
|
@ -1,73 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"io"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestBatchFileWriter(t *testing.T) {
|
|
||||||
client, err := NewClient()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
buf := &bytes.Buffer{}
|
|
||||||
|
|
||||||
w := NewBatchFileWriter(buf)
|
|
||||||
err = w.Write("batch-1", client.ChatCompletion("a").AddMessage(ChatCompletionMessage{
|
|
||||||
Role: "user", Content: "hello",
|
|
||||||
}))
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = w.Write("batch-2", client.Embedding("c").SetInput("whoa"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = w.Write("batch-3", client.ImageGeneration("d").SetPrompt("whoa"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.Equal(t, `{"body":{"messages":[{"role":"user","content":"hello"}],"model":"a"},"custom_id":"batch-1","method":"POST","url":"/v4/chat/completions"}
|
|
||||||
{"body":{"input":"whoa","model":"c"},"custom_id":"batch-2","method":"POST","url":"/v4/embeddings"}
|
|
||||||
{"body":{"model":"d","prompt":"whoa"},"custom_id":"batch-3","method":"POST","url":"/v4/images/generations"}
|
|
||||||
`, buf.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBatchResultReader(t *testing.T) {
|
|
||||||
result := `
|
|
||||||
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":26,"prompt_tokens":89,"total_tokens":115},"model":"glm-4","id":"8668357533850320547","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"订单处理慢\"\n}\n'''"}}],"request_id":"615-request-1"}},"custom_id":"request-1","id":"batch_1791490810192076800"}
|
|
||||||
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":22,"prompt_tokens":94,"total_tokens":116},"model":"glm-4","id":"8668368425887509080","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"产品缺陷\"\n}\n'''"}}],"request_id":"616-request-2"}},"custom_id":"request-2","id":"batch_1791490810192076800"}
|
|
||||||
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":25,"prompt_tokens":86,"total_tokens":111},"model":"glm-4","id":"8668355815863214980","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"正面\",\n \"特定问题标注\": \"性价比\"\n}\n'''"}}],"request_id":"617-request-3"}},"custom_id":"request-3","id":"batch_1791490810192076800"}
|
|
||||||
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":28,"prompt_tokens":89,"total_tokens":117},"model":"glm-4","id":"8668355815863214981","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"说明文档不清晰\"\n}\n'''"}}],"request_id":"618-request-4"}},"custom_id":"request-4","id":"batch_1791490810192076800"}
|
|
||||||
|
|
||||||
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":26,"prompt_tokens":88,"total_tokens":114},"model":"glm-4","id":"8668357533850320546","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"中性\",\n \"特定问题标注\": \"价格问题\"\n}\n'''"}}],"request_id":"619-request-5"}},"custom_id":"request-5","id":"batch_1791490810192076800"}
|
|
||||||
|
|
||||||
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":26,"prompt_tokens":90,"total_tokens":116},"model":"glm-4","id":"8668356159460662846","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"配送延迟\"\n}\n'''"}}],"request_id":"620-request-6"}},"custom_id":"request-6","id":"batch_1791490810192076800"}
|
|
||||||
|
|
||||||
|
|
||||||
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":27,"prompt_tokens":88,"total_tokens":115},"model":"glm-4","id":"8668357671289274638","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"产品描述不符\"\n}\n'''"}}],"request_id":"621-request-7"}},"custom_id":"request-7","id":"batch_1791490810192076800"}
|
|
||||||
{"response":{"status_code":200,"body":{"created":1715959702,"usage":{"completion_tokens":26,"prompt_tokens":87,"total_tokens":113},"model":"glm-4","id":"8668355644064514872","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"正面\",\n \"特定问题标注\": \"客服态度\"\n}\n'''"}}],"request_id":"622-request-8"}},"custom_id":"request-8","id":"batch_1791490810192076800"}
|
|
||||||
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":29,"prompt_tokens":90,"total_tokens":119},"model":"glm-4","id":"8668357671289274639","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"包装问题, 产品损坏\"\n}\n'''"}}],"request_id":"623-request-9"}},"custom_id":"request-9","id":"batch_1791490810192076800"}
|
|
||||||
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":27,"prompt_tokens":87,"total_tokens":114},"model":"glm-4","id":"8668355644064514871","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"正面\",\n \"特定问题标注\": \"产品描述不符\"\n}\n'''"}}],"request_id":"624-request-10"}},"custom_id":"request-10","id":"batch_1791490810192076800"}
|
|
||||||
`
|
|
||||||
|
|
||||||
brr := NewBatchResultReader[ChatCompletionResponse](bytes.NewReader([]byte(result)))
|
|
||||||
|
|
||||||
var count int
|
|
||||||
|
|
||||||
for {
|
|
||||||
var res BatchResult[ChatCompletionResponse]
|
|
||||||
|
|
||||||
err := brr.Read(&res)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
if err == io.EOF {
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
require.Equal(t, 10, count)
|
|
||||||
require.NoError(t, err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
require.Equal(t, 200, res.Response.StatusCode)
|
|
||||||
|
|
||||||
count++
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,59 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestBatchServiceAll(t *testing.T) {
|
|
||||||
client, err := NewClient()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
buf := &bytes.Buffer{}
|
|
||||||
|
|
||||||
bfw := NewBatchFileWriter(buf)
|
|
||||||
err = bfw.Write("batch_1", client.ChatCompletion("glm-4-flash").AddMessage(ChatCompletionMessage{
|
|
||||||
Role: RoleUser, Content: "你好呀",
|
|
||||||
}))
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = bfw.Write("batch_2", client.ChatCompletion("glm-4-flash").AddMessage(ChatCompletionMessage{
|
|
||||||
Role: RoleUser, Content: "你叫什么名字",
|
|
||||||
}))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
res, err := client.FileCreate(FilePurposeBatch).SetFile(bytes.NewReader(buf.Bytes()), "batch.jsonl").Do(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
fileID := res.FileCreateFineTuneResponse.ID
|
|
||||||
require.NotEmpty(t, fileID)
|
|
||||||
|
|
||||||
res1, err := client.BatchCreate().
|
|
||||||
SetInputFileID(fileID).
|
|
||||||
SetCompletionWindow(BatchCompletionWindow24h).
|
|
||||||
SetEndpoint(BatchEndpointV4ChatCompletions).Do(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotEmpty(t, res1.ID)
|
|
||||||
|
|
||||||
res2, err := client.BatchGet(res1.ID).Do(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, res2.ID, res1.ID)
|
|
||||||
|
|
||||||
res3, err := client.BatchList().Do(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotEmpty(t, res3.Data)
|
|
||||||
|
|
||||||
err = client.BatchCancel(res1.ID).Do(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBatchListService(t *testing.T) {
|
|
||||||
client, err := NewClient()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
res, err := client.BatchList().Do(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Log(res)
|
|
||||||
}
|
|
@ -1,577 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"github.com/go-resty/resty/v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
RoleSystem = "system"
|
|
||||||
RoleUser = "user"
|
|
||||||
RoleAssistant = "assistant"
|
|
||||||
RoleTool = "tool"
|
|
||||||
|
|
||||||
ToolChoiceAuto = "auto"
|
|
||||||
|
|
||||||
FinishReasonStop = "stop"
|
|
||||||
FinishReasonStopSequence = "stop_sequence"
|
|
||||||
FinishReasonToolCalls = "tool_calls"
|
|
||||||
FinishReasonLength = "length"
|
|
||||||
FinishReasonSensitive = "sensitive"
|
|
||||||
FinishReasonNetworkError = "network_error"
|
|
||||||
|
|
||||||
ToolTypeFunction = "function"
|
|
||||||
ToolTypeWebSearch = "web_search"
|
|
||||||
ToolTypeRetrieval = "retrieval"
|
|
||||||
|
|
||||||
MultiContentTypeText = "text"
|
|
||||||
MultiContentTypeImageURL = "image_url"
|
|
||||||
MultiContentTypeVideoURL = "video_url"
|
|
||||||
|
|
||||||
// New in GLM-4-AllTools
|
|
||||||
ToolTypeCodeInterpreter = "code_interpreter"
|
|
||||||
ToolTypeDrawingTool = "drawing_tool"
|
|
||||||
ToolTypeWebBrowser = "web_browser"
|
|
||||||
|
|
||||||
CodeInterpreterSandboxNone = "none"
|
|
||||||
CodeInterpreterSandboxAuto = "auto"
|
|
||||||
|
|
||||||
ChatCompletionStatusFailed = "failed"
|
|
||||||
ChatCompletionStatusCompleted = "completed"
|
|
||||||
ChatCompletionStatusRequiresAction = "requires_action"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ChatCompletionTool is the interface for chat completion tool
|
|
||||||
type ChatCompletionTool interface {
|
|
||||||
isChatCompletionTool()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatCompletionToolFunction is the function for chat completion
|
|
||||||
type ChatCompletionToolFunction struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Parameters any `json:"parameters"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ChatCompletionToolFunction) isChatCompletionTool() {}
|
|
||||||
|
|
||||||
// ChatCompletionToolRetrieval is the retrieval for chat completion
|
|
||||||
type ChatCompletionToolRetrieval struct {
|
|
||||||
KnowledgeID string `json:"knowledge_id"`
|
|
||||||
PromptTemplate string `json:"prompt_template,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ChatCompletionToolRetrieval) isChatCompletionTool() {}
|
|
||||||
|
|
||||||
// ChatCompletionToolWebSearch is the web search for chat completion
|
|
||||||
type ChatCompletionToolWebSearch struct {
|
|
||||||
Enable *bool `json:"enable,omitempty"`
|
|
||||||
SearchQuery string `json:"search_query,omitempty"`
|
|
||||||
SearchResult bool `json:"search_result,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ChatCompletionToolWebSearch) isChatCompletionTool() {}
|
|
||||||
|
|
||||||
// ChatCompletionToolCodeInterpreter is the code interpreter for chat completion
|
|
||||||
// only in GLM-4-AllTools
|
|
||||||
type ChatCompletionToolCodeInterpreter struct {
|
|
||||||
Sandbox *string `json:"sandbox,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ChatCompletionToolCodeInterpreter) isChatCompletionTool() {}
|
|
||||||
|
|
||||||
// ChatCompletionToolDrawingTool is the drawing tool for chat completion
|
|
||||||
// only in GLM-4-AllTools
|
|
||||||
type ChatCompletionToolDrawingTool struct {
|
|
||||||
// no fields
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ChatCompletionToolDrawingTool) isChatCompletionTool() {}
|
|
||||||
|
|
||||||
// ChatCompletionToolWebBrowser is the web browser for chat completion
|
|
||||||
type ChatCompletionToolWebBrowser struct {
|
|
||||||
// no fields
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ChatCompletionToolWebBrowser) isChatCompletionTool() {}
|
|
||||||
|
|
||||||
// ChatCompletionUsage is the usage for chat completion
|
|
||||||
type ChatCompletionUsage struct {
|
|
||||||
PromptTokens int64 `json:"prompt_tokens"`
|
|
||||||
CompletionTokens int64 `json:"completion_tokens"`
|
|
||||||
TotalTokens int64 `json:"total_tokens"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatCompletionWebSearch is the web search result for chat completion
|
|
||||||
type ChatCompletionWebSearch struct {
|
|
||||||
Icon string `json:"icon"`
|
|
||||||
Title string `json:"title"`
|
|
||||||
Link string `json:"link"`
|
|
||||||
Media string `json:"media"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatCompletionToolCallFunction is the function for chat completion tool call
|
|
||||||
type ChatCompletionToolCallFunction struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Arguments json.RawMessage `json:"arguments"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatCompletionToolCallCodeInterpreterOutput is the output for chat completion tool call code interpreter
|
|
||||||
type ChatCompletionToolCallCodeInterpreterOutput struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Logs string `json:"logs"`
|
|
||||||
File string `json:"file"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatCompletionToolCallCodeInterpreter is the code interpreter for chat completion tool call
|
|
||||||
type ChatCompletionToolCallCodeInterpreter struct {
|
|
||||||
Input string `json:"input"`
|
|
||||||
Outputs []ChatCompletionToolCallCodeInterpreterOutput `json:"outputs"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatCompletionToolCallDrawingToolOutput is the output for chat completion tool call drawing tool
|
|
||||||
type ChatCompletionToolCallDrawingToolOutput struct {
|
|
||||||
Image string `json:"image"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatCompletionToolCallDrawingTool is the drawing tool for chat completion tool call
|
|
||||||
type ChatCompletionToolCallDrawingTool struct {
|
|
||||||
Input string `json:"input"`
|
|
||||||
Outputs []ChatCompletionToolCallDrawingToolOutput `json:"outputs"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatCompletionToolCallWebBrowserOutput is the output for chat completion tool call web browser
|
|
||||||
type ChatCompletionToolCallWebBrowserOutput struct {
|
|
||||||
Title string `json:"title"`
|
|
||||||
Link string `json:"link"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatCompletionToolCallWebBrowser is the web browser for chat completion tool call
|
|
||||||
type ChatCompletionToolCallWebBrowser struct {
|
|
||||||
Input string `json:"input"`
|
|
||||||
Outputs []ChatCompletionToolCallWebBrowserOutput `json:"outputs"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatCompletionToolCall is the tool call for chat completion
|
|
||||||
type ChatCompletionToolCall struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
Function *ChatCompletionToolCallFunction `json:"function,omitempty"`
|
|
||||||
CodeInterpreter *ChatCompletionToolCallCodeInterpreter `json:"code_interpreter,omitempty"`
|
|
||||||
DrawingTool *ChatCompletionToolCallDrawingTool `json:"drawing_tool,omitempty"`
|
|
||||||
WebBrowser *ChatCompletionToolCallWebBrowser `json:"web_browser,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletionMessageType interface {
|
|
||||||
isChatCompletionMessageType()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatCompletionMessage is the message for chat completion
|
|
||||||
type ChatCompletionMessage struct {
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content string `json:"content,omitempty"`
|
|
||||||
ToolCalls []ChatCompletionToolCall `json:"tool_calls,omitempty"`
|
|
||||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ChatCompletionMessage) isChatCompletionMessageType() {}
|
|
||||||
|
|
||||||
type ChatCompletionMultiContent struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Text string `json:"text"`
|
|
||||||
ImageURL *URLItem `json:"image_url,omitempty"`
|
|
||||||
VideoURL *URLItem `json:"video_url,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatCompletionMultiMessage is the multi message for chat completion
|
|
||||||
type ChatCompletionMultiMessage struct {
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content []ChatCompletionMultiContent `json:"content"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ChatCompletionMultiMessage) isChatCompletionMessageType() {}
|
|
||||||
|
|
||||||
// ChatCompletionMeta is the meta for chat completion
|
|
||||||
type ChatCompletionMeta struct {
|
|
||||||
UserInfo string `json:"user_info"`
|
|
||||||
BotInfo string `json:"bot_info"`
|
|
||||||
UserName string `json:"user_name"`
|
|
||||||
BotName string `json:"bot_name"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatCompletionChoice is the choice for chat completion
|
|
||||||
type ChatCompletionChoice struct {
|
|
||||||
Index int `json:"index"`
|
|
||||||
FinishReason string `json:"finish_reason"`
|
|
||||||
Delta ChatCompletionMessage `json:"delta"` // stream mode
|
|
||||||
Message ChatCompletionMessage `json:"message"` // non-stream mode
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatCompletionResponse is the response for chat completion
|
|
||||||
type ChatCompletionResponse struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Created int64 `json:"created"`
|
|
||||||
Model string `json:"model"`
|
|
||||||
Choices []ChatCompletionChoice `json:"choices"`
|
|
||||||
Usage ChatCompletionUsage `json:"usage"`
|
|
||||||
WebSearch []ChatCompletionWebSearch `json:"web_search"`
|
|
||||||
// Status is the status of the chat completion, only in GLM-4-AllTools
|
|
||||||
Status string `json:"status"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatCompletionStreamHandler is the handler for chat completion stream
|
|
||||||
type ChatCompletionStreamHandler func(chunk ChatCompletionResponse) error
|
|
||||||
|
|
||||||
var (
|
|
||||||
chatCompletionStreamPrefix = []byte("data:")
|
|
||||||
chatCompletionStreamDone = []byte("[DONE]")
|
|
||||||
)
|
|
||||||
|
|
||||||
// chatCompletionReduceResponse reduce the chunk to the response
|
|
||||||
func chatCompletionReduceResponse(out *ChatCompletionResponse, chunk ChatCompletionResponse) {
|
|
||||||
if len(out.Choices) == 0 {
|
|
||||||
out.Choices = append(out.Choices, ChatCompletionChoice{})
|
|
||||||
}
|
|
||||||
|
|
||||||
// basic
|
|
||||||
out.ID = chunk.ID
|
|
||||||
out.Created = chunk.Created
|
|
||||||
out.Model = chunk.Model
|
|
||||||
|
|
||||||
// choices
|
|
||||||
if len(chunk.Choices) != 0 {
|
|
||||||
oc := &out.Choices[0]
|
|
||||||
cc := chunk.Choices[0]
|
|
||||||
|
|
||||||
oc.Index = cc.Index
|
|
||||||
if cc.Delta.Role != "" {
|
|
||||||
oc.Message.Role = cc.Delta.Role
|
|
||||||
}
|
|
||||||
oc.Message.Content += cc.Delta.Content
|
|
||||||
oc.Message.ToolCalls = append(oc.Message.ToolCalls, cc.Delta.ToolCalls...)
|
|
||||||
if cc.FinishReason != "" {
|
|
||||||
oc.FinishReason = cc.FinishReason
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// usage
|
|
||||||
if chunk.Usage.CompletionTokens != 0 {
|
|
||||||
out.Usage.CompletionTokens = chunk.Usage.CompletionTokens
|
|
||||||
}
|
|
||||||
if chunk.Usage.PromptTokens != 0 {
|
|
||||||
out.Usage.PromptTokens = chunk.Usage.PromptTokens
|
|
||||||
}
|
|
||||||
if chunk.Usage.TotalTokens != 0 {
|
|
||||||
out.Usage.TotalTokens = chunk.Usage.TotalTokens
|
|
||||||
}
|
|
||||||
|
|
||||||
// web search
|
|
||||||
out.WebSearch = append(out.WebSearch, chunk.WebSearch...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// chatCompletionDecodeStream decode the sse stream of chat completion
|
|
||||||
func chatCompletionDecodeStream(r io.Reader, fn func(chunk ChatCompletionResponse) error) (err error) {
|
|
||||||
br := bufio.NewReader(r)
|
|
||||||
|
|
||||||
for {
|
|
||||||
var line []byte
|
|
||||||
|
|
||||||
if line, err = br.ReadBytes('\n'); err != nil {
|
|
||||||
if errors.Is(err, io.EOF) {
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
line = bytes.TrimSpace(line)
|
|
||||||
|
|
||||||
if len(line) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if !bytes.HasPrefix(line, chatCompletionStreamPrefix) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
data := bytes.TrimSpace(line[len(chatCompletionStreamPrefix):])
|
|
||||||
|
|
||||||
if bytes.Equal(data, chatCompletionStreamDone) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(data) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var chunk ChatCompletionResponse
|
|
||||||
if err = json.Unmarshal(data, &chunk); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err = fn(chunk); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatCompletionStreamService is the service for chat completion stream
|
|
||||||
type ChatCompletionService struct {
|
|
||||||
client *Client
|
|
||||||
|
|
||||||
model string
|
|
||||||
requestID *string
|
|
||||||
doSample *bool
|
|
||||||
temperature *float64
|
|
||||||
topP *float64
|
|
||||||
maxTokens *int
|
|
||||||
stop []string
|
|
||||||
toolChoice *string
|
|
||||||
userID *string
|
|
||||||
meta *ChatCompletionMeta
|
|
||||||
|
|
||||||
messages []any
|
|
||||||
tools []any
|
|
||||||
|
|
||||||
streamHandler ChatCompletionStreamHandler
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
_ BatchSupport = &ChatCompletionService{}
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewChatCompletionService creates a new ChatCompletionService.
|
|
||||||
func NewChatCompletionService(client *Client) *ChatCompletionService {
|
|
||||||
return &ChatCompletionService{
|
|
||||||
client: client,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ChatCompletionService) BatchMethod() string {
|
|
||||||
return "POST"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ChatCompletionService) BatchURL() string {
|
|
||||||
return BatchEndpointV4ChatCompletions
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ChatCompletionService) BatchBody() any {
|
|
||||||
return s.buildBody()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetModel set the model of the chat completion
|
|
||||||
func (s *ChatCompletionService) SetModel(model string) *ChatCompletionService {
|
|
||||||
s.model = model
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetMeta set the meta of the chat completion, optional
|
|
||||||
func (s *ChatCompletionService) SetMeta(meta ChatCompletionMeta) *ChatCompletionService {
|
|
||||||
s.meta = &meta
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetRequestID set the request id of the chat completion, optional
|
|
||||||
func (s *ChatCompletionService) SetRequestID(requestID string) *ChatCompletionService {
|
|
||||||
s.requestID = &requestID
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetTemperature set the temperature of the chat completion, optional
|
|
||||||
func (s *ChatCompletionService) SetDoSample(doSample bool) *ChatCompletionService {
|
|
||||||
s.doSample = &doSample
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetTemperature set the temperature of the chat completion, optional
|
|
||||||
func (s *ChatCompletionService) SetTemperature(temperature float64) *ChatCompletionService {
|
|
||||||
s.temperature = &temperature
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetTopP set the top p of the chat completion, optional
|
|
||||||
func (s *ChatCompletionService) SetTopP(topP float64) *ChatCompletionService {
|
|
||||||
s.topP = &topP
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetMaxTokens set the max tokens of the chat completion, optional
|
|
||||||
func (s *ChatCompletionService) SetMaxTokens(maxTokens int) *ChatCompletionService {
|
|
||||||
s.maxTokens = &maxTokens
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetStop set the stop of the chat completion, optional
|
|
||||||
func (s *ChatCompletionService) SetStop(stop ...string) *ChatCompletionService {
|
|
||||||
s.stop = stop
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetToolChoice set the tool choice of the chat completion, optional
|
|
||||||
func (s *ChatCompletionService) SetToolChoice(toolChoice string) *ChatCompletionService {
|
|
||||||
s.toolChoice = &toolChoice
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetUserID set the user id of the chat completion, optional
|
|
||||||
func (s *ChatCompletionService) SetUserID(userID string) *ChatCompletionService {
|
|
||||||
s.userID = &userID
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetStreamHandler set the stream handler of the chat completion, optional
|
|
||||||
// this will enable the stream mode
|
|
||||||
func (s *ChatCompletionService) SetStreamHandler(handler ChatCompletionStreamHandler) *ChatCompletionService {
|
|
||||||
s.streamHandler = handler
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddMessage add the message to the chat completion
|
|
||||||
func (s *ChatCompletionService) AddMessage(messages ...ChatCompletionMessageType) *ChatCompletionService {
|
|
||||||
for _, message := range messages {
|
|
||||||
s.messages = append(s.messages, message)
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddFunction add the function to the chat completion
|
|
||||||
func (s *ChatCompletionService) AddTool(tools ...ChatCompletionTool) *ChatCompletionService {
|
|
||||||
for _, tool := range tools {
|
|
||||||
switch tool := tool.(type) {
|
|
||||||
case ChatCompletionToolFunction:
|
|
||||||
s.tools = append(s.tools, map[string]any{
|
|
||||||
"type": ToolTypeFunction,
|
|
||||||
ToolTypeFunction: tool,
|
|
||||||
})
|
|
||||||
case ChatCompletionToolRetrieval:
|
|
||||||
s.tools = append(s.tools, map[string]any{
|
|
||||||
"type": ToolTypeRetrieval,
|
|
||||||
ToolTypeRetrieval: tool,
|
|
||||||
})
|
|
||||||
case ChatCompletionToolWebSearch:
|
|
||||||
s.tools = append(s.tools, map[string]any{
|
|
||||||
"type": ToolTypeWebSearch,
|
|
||||||
ToolTypeWebSearch: tool,
|
|
||||||
})
|
|
||||||
case ChatCompletionToolCodeInterpreter:
|
|
||||||
s.tools = append(s.tools, map[string]any{
|
|
||||||
"type": ToolTypeCodeInterpreter,
|
|
||||||
ToolTypeCodeInterpreter: tool,
|
|
||||||
})
|
|
||||||
case ChatCompletionToolDrawingTool:
|
|
||||||
s.tools = append(s.tools, map[string]any{
|
|
||||||
"type": ToolTypeDrawingTool,
|
|
||||||
ToolTypeDrawingTool: tool,
|
|
||||||
})
|
|
||||||
case ChatCompletionToolWebBrowser:
|
|
||||||
s.tools = append(s.tools, map[string]any{
|
|
||||||
"type": ToolTypeWebBrowser,
|
|
||||||
ToolTypeWebBrowser: tool,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ChatCompletionService) buildBody() M {
|
|
||||||
body := map[string]any{
|
|
||||||
"model": s.model,
|
|
||||||
"messages": s.messages,
|
|
||||||
}
|
|
||||||
if s.requestID != nil {
|
|
||||||
body["request_id"] = *s.requestID
|
|
||||||
}
|
|
||||||
if s.doSample != nil {
|
|
||||||
body["do_sample"] = *s.doSample
|
|
||||||
}
|
|
||||||
if s.temperature != nil {
|
|
||||||
body["temperature"] = *s.temperature
|
|
||||||
}
|
|
||||||
if s.topP != nil {
|
|
||||||
body["top_p"] = *s.topP
|
|
||||||
}
|
|
||||||
if s.maxTokens != nil {
|
|
||||||
body["max_tokens"] = *s.maxTokens
|
|
||||||
}
|
|
||||||
if len(s.stop) != 0 {
|
|
||||||
body["stop"] = s.stop
|
|
||||||
}
|
|
||||||
if len(s.tools) != 0 {
|
|
||||||
body["tools"] = s.tools
|
|
||||||
}
|
|
||||||
if s.toolChoice != nil {
|
|
||||||
body["tool_choice"] = *s.toolChoice
|
|
||||||
}
|
|
||||||
if s.userID != nil {
|
|
||||||
body["user_id"] = *s.userID
|
|
||||||
}
|
|
||||||
if s.meta != nil {
|
|
||||||
body["meta"] = s.meta
|
|
||||||
}
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do send the request of the chat completion and return the response
|
|
||||||
func (s *ChatCompletionService) Do(ctx context.Context) (res ChatCompletionResponse, err error) {
|
|
||||||
body := s.buildBody()
|
|
||||||
|
|
||||||
streamHandler := s.streamHandler
|
|
||||||
|
|
||||||
if streamHandler == nil {
|
|
||||||
var (
|
|
||||||
resp *resty.Response
|
|
||||||
apiError APIErrorResponse
|
|
||||||
)
|
|
||||||
//fmt.Println(u.BMagenta(u.JsonP(body)), 111)
|
|
||||||
if resp, err = s.client.request(ctx).SetBody(body).SetResult(&res).SetError(&apiError).Post("chat/completions"); err != nil {
|
|
||||||
//fmt.Println(u.BRed(err.Error()), 2221)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if resp.IsError() {
|
|
||||||
err = apiError
|
|
||||||
//fmt.Println(u.BRed(err.Error()), 2222)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
//fmt.Println(u.BGreen(u.JsonP(resp.Result())), resp.Status(), resp.Status(), 333)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// stream mode
|
|
||||||
|
|
||||||
body["stream"] = true
|
|
||||||
|
|
||||||
var resp *resty.Response
|
|
||||||
|
|
||||||
if resp, err = s.client.request(ctx).SetBody(body).SetDoNotParseResponse(true).Post("chat/completions"); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer resp.RawBody().Close()
|
|
||||||
|
|
||||||
if resp.IsError() {
|
|
||||||
err = errors.New(resp.Status())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var choice ChatCompletionChoice
|
|
||||||
|
|
||||||
if err = chatCompletionDecodeStream(resp.RawBody(), func(chunk ChatCompletionResponse) error {
|
|
||||||
// reduce the chunk to the response
|
|
||||||
chatCompletionReduceResponse(&res, chunk)
|
|
||||||
// invoke the stream handler
|
|
||||||
return streamHandler(chunk)
|
|
||||||
}); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
res.Choices = append(res.Choices, choice)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
@ -1,251 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestChatCompletionService(t *testing.T) {
|
|
||||||
client, err := NewClient()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
s := client.ChatCompletion("glm-4-flash")
|
|
||||||
s.AddMessage(ChatCompletionMessage{
|
|
||||||
Role: RoleUser,
|
|
||||||
Content: "你好呀",
|
|
||||||
})
|
|
||||||
res, err := s.Do(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotEmpty(t, res.Choices)
|
|
||||||
choice := res.Choices[0]
|
|
||||||
require.Equal(t, FinishReasonStop, choice.FinishReason)
|
|
||||||
require.NotEmpty(t, choice.Message.Content)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestChatCompletionServiceCharGLM(t *testing.T) {
|
|
||||||
client, err := NewClient()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
s := client.ChatCompletion("charglm-3")
|
|
||||||
s.SetMeta(
|
|
||||||
ChatCompletionMeta{
|
|
||||||
UserName: "啵酱",
|
|
||||||
UserInfo: "啵酱是小少爷",
|
|
||||||
BotName: "塞巴斯酱",
|
|
||||||
BotInfo: "塞巴斯酱是一个冷酷的恶魔管家",
|
|
||||||
},
|
|
||||||
).AddMessage(ChatCompletionMessage{
|
|
||||||
Role: RoleUser,
|
|
||||||
Content: "早上好",
|
|
||||||
})
|
|
||||||
res, err := s.Do(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotEmpty(t, res.Choices)
|
|
||||||
choice := res.Choices[0]
|
|
||||||
require.Contains(t, []string{FinishReasonStop, FinishReasonStopSequence}, choice.FinishReason)
|
|
||||||
require.NotEmpty(t, choice.Message.Content)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestChatCompletionServiceAllToolsCodeInterpreter(t *testing.T) {
|
|
||||||
client, err := NewClient()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
s := client.ChatCompletion("GLM-4-AllTools")
|
|
||||||
s.AddMessage(ChatCompletionMultiMessage{
|
|
||||||
Role: "user",
|
|
||||||
Content: []ChatCompletionMultiContent{
|
|
||||||
{
|
|
||||||
Type: "text",
|
|
||||||
Text: "计算[5,10,20,700,99,310,978,100]的平均值和方差。",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
s.AddTool(ChatCompletionToolCodeInterpreter{
|
|
||||||
Sandbox: Ptr(CodeInterpreterSandboxAuto),
|
|
||||||
})
|
|
||||||
|
|
||||||
foundInterpreterInput := false
|
|
||||||
foundInterpreterOutput := false
|
|
||||||
|
|
||||||
s.SetStreamHandler(func(chunk ChatCompletionResponse) error {
|
|
||||||
for _, c := range chunk.Choices {
|
|
||||||
for _, tc := range c.Delta.ToolCalls {
|
|
||||||
if tc.Type == ToolTypeCodeInterpreter && tc.CodeInterpreter != nil {
|
|
||||||
if tc.CodeInterpreter.Input != "" {
|
|
||||||
foundInterpreterInput = true
|
|
||||||
}
|
|
||||||
if len(tc.CodeInterpreter.Outputs) > 0 {
|
|
||||||
foundInterpreterOutput = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
buf, _ := json.MarshalIndent(chunk, "", " ")
|
|
||||||
t.Log(string(buf))
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
res, err := s.Do(context.Background())
|
|
||||||
require.True(t, foundInterpreterInput)
|
|
||||||
require.True(t, foundInterpreterOutput)
|
|
||||||
require.NotNil(t, res)
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestChatCompletionServiceAllToolsDrawingTool(t *testing.T) {
|
|
||||||
client, err := NewClient()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
s := client.ChatCompletion("GLM-4-AllTools")
|
|
||||||
s.AddMessage(ChatCompletionMultiMessage{
|
|
||||||
Role: "user",
|
|
||||||
Content: []ChatCompletionMultiContent{
|
|
||||||
{
|
|
||||||
Type: "text",
|
|
||||||
Text: "画一个正弦函数图像",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
s.AddTool(ChatCompletionToolDrawingTool{})
|
|
||||||
|
|
||||||
foundInput := false
|
|
||||||
foundOutput := false
|
|
||||||
outputImage := ""
|
|
||||||
|
|
||||||
s.SetStreamHandler(func(chunk ChatCompletionResponse) error {
|
|
||||||
for _, c := range chunk.Choices {
|
|
||||||
for _, tc := range c.Delta.ToolCalls {
|
|
||||||
if tc.Type == ToolTypeDrawingTool && tc.DrawingTool != nil {
|
|
||||||
if tc.DrawingTool.Input != "" {
|
|
||||||
foundInput = true
|
|
||||||
}
|
|
||||||
if len(tc.DrawingTool.Outputs) > 0 {
|
|
||||||
foundOutput = true
|
|
||||||
}
|
|
||||||
for _, output := range tc.DrawingTool.Outputs {
|
|
||||||
if output.Image != "" {
|
|
||||||
outputImage = output.Image
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
buf, _ := json.MarshalIndent(chunk, "", " ")
|
|
||||||
t.Log(string(buf))
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
res, err := s.Do(context.Background())
|
|
||||||
require.True(t, foundInput)
|
|
||||||
require.True(t, foundOutput)
|
|
||||||
require.NotEmpty(t, outputImage)
|
|
||||||
t.Log(outputImage)
|
|
||||||
require.NotNil(t, res)
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestChatCompletionServiceAllToolsWebBrowser(t *testing.T) {
|
|
||||||
client, err := NewClient()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
s := client.ChatCompletion("GLM-4-AllTools")
|
|
||||||
s.AddMessage(ChatCompletionMultiMessage{
|
|
||||||
Role: "user",
|
|
||||||
Content: []ChatCompletionMultiContent{
|
|
||||||
{
|
|
||||||
Type: "text",
|
|
||||||
Text: "搜索下本周深圳天气如何",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
s.AddTool(ChatCompletionToolWebBrowser{})
|
|
||||||
|
|
||||||
foundInput := false
|
|
||||||
foundOutput := false
|
|
||||||
outputContent := ""
|
|
||||||
|
|
||||||
s.SetStreamHandler(func(chunk ChatCompletionResponse) error {
|
|
||||||
for _, c := range chunk.Choices {
|
|
||||||
for _, tc := range c.Delta.ToolCalls {
|
|
||||||
if tc.Type == ToolTypeWebBrowser && tc.WebBrowser != nil {
|
|
||||||
if tc.WebBrowser.Input != "" {
|
|
||||||
foundInput = true
|
|
||||||
}
|
|
||||||
if len(tc.WebBrowser.Outputs) > 0 {
|
|
||||||
foundOutput = true
|
|
||||||
}
|
|
||||||
for _, output := range tc.WebBrowser.Outputs {
|
|
||||||
if output.Content != "" {
|
|
||||||
outputContent = output.Content
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
buf, _ := json.MarshalIndent(chunk, "", " ")
|
|
||||||
t.Log(string(buf))
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
res, err := s.Do(context.Background())
|
|
||||||
require.True(t, foundInput)
|
|
||||||
require.True(t, foundOutput)
|
|
||||||
require.NotEmpty(t, outputContent)
|
|
||||||
t.Log(outputContent)
|
|
||||||
require.NotNil(t, res)
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestChatCompletionServiceStream(t *testing.T) {
|
|
||||||
client, err := NewClient()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
var content string
|
|
||||||
|
|
||||||
s := client.ChatCompletion("glm-4-flash").AddMessage(ChatCompletionMessage{
|
|
||||||
Role: RoleUser,
|
|
||||||
Content: "你好呀",
|
|
||||||
}).SetStreamHandler(func(chunk ChatCompletionResponse) error {
|
|
||||||
content += chunk.Choices[0].Delta.Content
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
res, err := s.Do(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotEmpty(t, res.Choices)
|
|
||||||
choice := res.Choices[0]
|
|
||||||
require.Equal(t, FinishReasonStop, choice.FinishReason)
|
|
||||||
require.NotEmpty(t, choice.Message.Content)
|
|
||||||
require.Equal(t, content, choice.Message.Content)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestChatCompletionServiceVision(t *testing.T) {
|
|
||||||
client, err := NewClient()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
s := client.ChatCompletion("glm-4v")
|
|
||||||
s.AddMessage(ChatCompletionMultiMessage{
|
|
||||||
Role: RoleUser,
|
|
||||||
Content: []ChatCompletionMultiContent{
|
|
||||||
{
|
|
||||||
Type: MultiContentTypeText,
|
|
||||||
Text: "图里有什么",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Type: MultiContentTypeImageURL,
|
|
||||||
ImageURL: &URLItem{
|
|
||||||
URL: "https://img1.baidu.com/it/u=1369931113,3388870256&fm=253&app=138&size=w931&n=0&f=JPEG&fmt=auto?sec=1703696400&t=f3028c7a1dca43a080aeb8239f09cc2f",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
res, err := s.Do(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotEmpty(t, res.Choices)
|
|
||||||
require.NotZero(t, res.Usage.CompletionTokens)
|
|
||||||
choice := res.Choices[0]
|
|
||||||
require.Equal(t, FinishReasonStop, choice.FinishReason)
|
|
||||||
require.NotEmpty(t, choice.Message.Content)
|
|
||||||
}
|
|
291
zhipu/client.go
291
zhipu/client.go
@ -1,291 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/go-resty/resty/v2"
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
envAPIKey = "ZHIPUAI_API_KEY"
|
|
||||||
envBaseURL = "ZHIPUAI_BASE_URL"
|
|
||||||
envDebug = "ZHIPUAI_DEBUG"
|
|
||||||
|
|
||||||
defaultBaseURL = "https://open.bigmodel.cn/api/paas/v4"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// ErrAPIKeyMissing is the error when the api key is missing
|
|
||||||
ErrAPIKeyMissing = errors.New("zhipu: api key is missing")
|
|
||||||
// ErrAPIKeyMalformed is the error when the api key is malformed
|
|
||||||
ErrAPIKeyMalformed = errors.New("zhipu: api key is malformed")
|
|
||||||
)
|
|
||||||
|
|
||||||
type clientOptions struct {
|
|
||||||
baseURL string
|
|
||||||
apiKey string
|
|
||||||
client *http.Client
|
|
||||||
debug *bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClientOption is a function that configures the client
|
|
||||||
type ClientOption func(opts *clientOptions)
|
|
||||||
|
|
||||||
// WithAPIKey set the api key of the client
|
|
||||||
func WithAPIKey(apiKey string) ClientOption {
|
|
||||||
return func(opts *clientOptions) {
|
|
||||||
opts.apiKey = apiKey
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithBaseURL set the base url of the client
|
|
||||||
func WithBaseURL(baseURL string) ClientOption {
|
|
||||||
return func(opts *clientOptions) {
|
|
||||||
opts.baseURL = baseURL
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithHTTPClient set the http client of the client
|
|
||||||
func WithHTTPClient(client *http.Client) ClientOption {
|
|
||||||
return func(opts *clientOptions) {
|
|
||||||
opts.client = client
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithDebug set the debug mode of the client
|
|
||||||
func WithDebug(debug bool) ClientOption {
|
|
||||||
return func(opts *clientOptions) {
|
|
||||||
opts.debug = new(bool)
|
|
||||||
*opts.debug = debug
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Client is the client for zhipu ai platform
|
|
||||||
type Client struct {
|
|
||||||
client *resty.Client
|
|
||||||
debug bool
|
|
||||||
keyID string
|
|
||||||
keySecret []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) createJWT() string {
|
|
||||||
timestamp := time.Now().UnixMilli()
|
|
||||||
exp := timestamp + time.Hour.Milliseconds()*24*7
|
|
||||||
|
|
||||||
t := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
|
||||||
"api_key": c.keyID,
|
|
||||||
"timestamp": timestamp,
|
|
||||||
"exp": exp,
|
|
||||||
})
|
|
||||||
t.Header = map[string]interface{}{
|
|
||||||
"alg": "HS256",
|
|
||||||
"sign_type": "SIGN",
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := t.SignedString(c.keySecret)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
return token
|
|
||||||
}
|
|
||||||
|
|
||||||
// request creates a new resty request with the jwt token and context
|
|
||||||
func (c *Client) request(ctx context.Context) *resty.Request {
|
|
||||||
return c.client.R().SetContext(ctx).SetHeader("Authorization", c.createJWT())
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewClient creates a new client
|
|
||||||
// It will read the api key from the environment variable ZHIPUAI_API_KEY
|
|
||||||
// It will read the base url from the environment variable ZHIPUAI_BASE_URL
|
|
||||||
func NewClient(optFns ...ClientOption) (client *Client, err error) {
|
|
||||||
var opts clientOptions
|
|
||||||
for _, optFn := range optFns {
|
|
||||||
optFn(&opts)
|
|
||||||
}
|
|
||||||
// base url
|
|
||||||
if opts.baseURL == "" {
|
|
||||||
opts.baseURL = strings.TrimSpace(os.Getenv(envBaseURL))
|
|
||||||
}
|
|
||||||
if opts.baseURL == "" {
|
|
||||||
opts.baseURL = defaultBaseURL
|
|
||||||
}
|
|
||||||
// api key
|
|
||||||
if opts.apiKey == "" {
|
|
||||||
opts.apiKey = strings.TrimSpace(os.Getenv(envAPIKey))
|
|
||||||
}
|
|
||||||
if opts.apiKey == "" {
|
|
||||||
err = ErrAPIKeyMissing
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// debug
|
|
||||||
if opts.debug == nil {
|
|
||||||
if debugStr := strings.TrimSpace(os.Getenv(envDebug)); debugStr != "" {
|
|
||||||
if debug, err1 := strconv.ParseBool(debugStr); err1 == nil {
|
|
||||||
opts.debug = &debug
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
keyComponents := strings.SplitN(opts.apiKey, ".", 2)
|
|
||||||
|
|
||||||
if len(keyComponents) != 2 {
|
|
||||||
err = ErrAPIKeyMalformed
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
client = &Client{
|
|
||||||
keyID: keyComponents[0],
|
|
||||||
keySecret: []byte(keyComponents[1]),
|
|
||||||
}
|
|
||||||
|
|
||||||
if opts.client == nil {
|
|
||||||
client.client = resty.New()
|
|
||||||
} else {
|
|
||||||
client.client = resty.NewWithClient(opts.client)
|
|
||||||
}
|
|
||||||
|
|
||||||
client.client = client.client.SetBaseURL(opts.baseURL)
|
|
||||||
|
|
||||||
if opts.debug != nil {
|
|
||||||
client.client.SetDebug(*opts.debug)
|
|
||||||
client.debug = *opts.debug
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchCreate creates a new BatchCreateService.
|
|
||||||
func (c *Client) BatchCreate() *BatchCreateService {
|
|
||||||
return NewBatchCreateService(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchGet creates a new BatchGetService.
|
|
||||||
func (c *Client) BatchGet(batchID string) *BatchGetService {
|
|
||||||
return NewBatchGetService(c).SetBatchID(batchID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchCancel creates a new BatchCancelService.
|
|
||||||
func (c *Client) BatchCancel(batchID string) *BatchCancelService {
|
|
||||||
return NewBatchCancelService(c).SetBatchID(batchID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchList creates a new BatchListService.
|
|
||||||
func (c *Client) BatchList() *BatchListService {
|
|
||||||
return NewBatchListService(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatCompletion creates a new ChatCompletionService.
|
|
||||||
func (c *Client) ChatCompletion(model string) *ChatCompletionService {
|
|
||||||
return NewChatCompletionService(c).SetModel(model)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Embedding embeds a list of text into a vector space.
|
|
||||||
func (c *Client) Embedding(model string) *EmbeddingService {
|
|
||||||
return NewEmbeddingService(c).SetModel(model)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileCreate creates a new FileCreateService.
|
|
||||||
func (c *Client) FileCreate(purpose string) *FileCreateService {
|
|
||||||
return NewFileCreateService(c).SetPurpose(purpose)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileEditService creates a new FileEditService.
|
|
||||||
func (c *Client) FileEdit(documentID string) *FileEditService {
|
|
||||||
return NewFileEditService(c).SetDocumentID(documentID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileList creates a new FileListService.
|
|
||||||
func (c *Client) FileList(purpose string) *FileListService {
|
|
||||||
return NewFileListService(c).SetPurpose(purpose)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileDeleteService creates a new FileDeleteService.
|
|
||||||
func (c *Client) FileDelete(documentID string) *FileDeleteService {
|
|
||||||
return NewFileDeleteService(c).SetDocumentID(documentID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileGetService creates a new FileGetService.
|
|
||||||
func (c *Client) FileGet(documentID string) *FileGetService {
|
|
||||||
return NewFileGetService(c).SetDocumentID(documentID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileDownload creates a new FileDownloadService.
|
|
||||||
func (c *Client) FileDownload(fileID string) *FileDownloadService {
|
|
||||||
return NewFileDownloadService(c).SetFileID(fileID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FineTuneCreate creates a new fine tune create service
|
|
||||||
func (c *Client) FineTuneCreate(model string) *FineTuneCreateService {
|
|
||||||
return NewFineTuneCreateService(c).SetModel(model)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FineTuneEventList creates a new fine tune event list service
|
|
||||||
func (c *Client) FineTuneEventList(jobID string) *FineTuneEventListService {
|
|
||||||
return NewFineTuneEventListService(c).SetJobID(jobID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FineTuneGet creates a new fine tune get service
|
|
||||||
func (c *Client) FineTuneGet(jobID string) *FineTuneGetService {
|
|
||||||
return NewFineTuneGetService(c).SetJobID(jobID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FineTuneList creates a new fine tune list service
|
|
||||||
func (c *Client) FineTuneList() *FineTuneListService {
|
|
||||||
return NewFineTuneListService(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FineTuneDelete creates a new fine tune delete service
|
|
||||||
func (c *Client) FineTuneDelete(jobID string) *FineTuneDeleteService {
|
|
||||||
return NewFineTuneDeleteService(c).SetJobID(jobID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FineTuneCancel creates a new fine tune cancel service
|
|
||||||
func (c *Client) FineTuneCancel(jobID string) *FineTuneCancelService {
|
|
||||||
return NewFineTuneCancelService(c).SetJobID(jobID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ImageGeneration creates a new image generation service
|
|
||||||
func (c *Client) ImageGeneration(model string) *ImageGenerationService {
|
|
||||||
return NewImageGenerationService(c).SetModel(model)
|
|
||||||
}
|
|
||||||
|
|
||||||
// KnowledgeCreate creates a new knowledge create service
|
|
||||||
func (c *Client) KnowledgeCreate() *KnowledgeCreateService {
|
|
||||||
return NewKnowledgeCreateService(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
// KnowledgeEdit creates a new knowledge edit service
|
|
||||||
func (c *Client) KnowledgeEdit(knowledgeID string) *KnowledgeEditService {
|
|
||||||
return NewKnowledgeEditService(c).SetKnowledgeID(knowledgeID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// KnowledgeList list all the knowledge
|
|
||||||
func (c *Client) KnowledgeList() *KnowledgeListService {
|
|
||||||
return NewKnowledgeListService(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
// KnowledgeDelete creates a new knowledge delete service
|
|
||||||
func (c *Client) KnowledgeDelete(knowledgeID string) *KnowledgeDeleteService {
|
|
||||||
return NewKnowledgeDeleteService(c).SetKnowledgeID(knowledgeID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// KnowledgeGet creates a new knowledge get service
|
|
||||||
func (c *Client) KnowledgeCapacity() *KnowledgeCapacityService {
|
|
||||||
return NewKnowledgeCapacityService(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
// VideoGeneration creates a new video generation service
|
|
||||||
func (c *Client) VideoGeneration(model string) *VideoGenerationService {
|
|
||||||
return NewVideoGenerationService(c).SetModel(model)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AsyncResult creates a new async result get service
|
|
||||||
func (c *Client) AsyncResult(id string) *AsyncResultService {
|
|
||||||
return NewAsyncResultService(c).SetID(id)
|
|
||||||
}
|
|
@ -1,17 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestClientR(t *testing.T) {
|
|
||||||
c, err := NewClient()
|
|
||||||
require.NoError(t, err)
|
|
||||||
// the only free api is to list fine-tuning jobs
|
|
||||||
res, err := c.request(context.Background()).Get("fine_tuning/jobs")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.True(t, res.IsSuccess())
|
|
||||||
}
|
|
@ -1,25 +0,0 @@
|
|||||||
from_latest_tag = false
|
|
||||||
ignore_merge_commits = false
|
|
||||||
disable_changelog = false
|
|
||||||
disable_bump_commit = false
|
|
||||||
generate_mono_repository_global_tag = true
|
|
||||||
branch_whitelist = []
|
|
||||||
skip_ci = "[skip ci]"
|
|
||||||
skip_untracked = false
|
|
||||||
pre_bump_hooks = []
|
|
||||||
post_bump_hooks = []
|
|
||||||
pre_package_bump_hooks = []
|
|
||||||
post_package_bump_hooks = []
|
|
||||||
tag_prefix = "v"
|
|
||||||
|
|
||||||
[git_hooks]
|
|
||||||
|
|
||||||
[commit_types]
|
|
||||||
|
|
||||||
[changelog]
|
|
||||||
path = "CHANGELOG.md"
|
|
||||||
authors = []
|
|
||||||
|
|
||||||
[bump_profiles]
|
|
||||||
|
|
||||||
[packages]
|
|
@ -1,87 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/go-resty/resty/v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
// EmbeddingData is the data for each embedding.
|
|
||||||
type EmbeddingData struct {
|
|
||||||
Embedding []float64 `json:"embedding"`
|
|
||||||
Index int `json:"index"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// EmbeddingResponse is the response from the embedding service.
|
|
||||||
type EmbeddingResponse struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Data []EmbeddingData `json:"data"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
Usage ChatCompletionUsage `json:"usage"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// EmbeddingService embeds a list of text into a vector space.
|
|
||||||
type EmbeddingService struct {
|
|
||||||
client *Client
|
|
||||||
|
|
||||||
model string
|
|
||||||
input string
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
_ BatchSupport = &EmbeddingService{}
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewEmbeddingService creates a new EmbeddingService.
|
|
||||||
func NewEmbeddingService(client *Client) *EmbeddingService {
|
|
||||||
return &EmbeddingService{client: client}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *EmbeddingService) BatchMethod() string {
|
|
||||||
return "POST"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *EmbeddingService) BatchURL() string {
|
|
||||||
return BatchEndpointV4Embeddings
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *EmbeddingService) BatchBody() any {
|
|
||||||
return s.buildBody()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetModel sets the model to use for the embedding.
|
|
||||||
func (s *EmbeddingService) SetModel(model string) *EmbeddingService {
|
|
||||||
s.model = model
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetInput sets the input text to embed.
|
|
||||||
func (s *EmbeddingService) SetInput(input string) *EmbeddingService {
|
|
||||||
s.input = input
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *EmbeddingService) buildBody() M {
|
|
||||||
return M{"model": s.model, "input": s.input}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *EmbeddingService) Do(ctx context.Context) (res EmbeddingResponse, err error) {
|
|
||||||
var (
|
|
||||||
resp *resty.Response
|
|
||||||
apiError APIErrorResponse
|
|
||||||
)
|
|
||||||
|
|
||||||
if resp, err = s.client.request(ctx).
|
|
||||||
SetBody(s.buildBody()).
|
|
||||||
SetResult(&res).
|
|
||||||
SetError(&apiError).
|
|
||||||
Post("embeddings"); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if resp.IsError() {
|
|
||||||
err = apiError
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
@ -1,21 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestEmbeddingService(t *testing.T) {
|
|
||||||
client, err := NewClient()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
service := client.Embedding("embedding-2")
|
|
||||||
|
|
||||||
resp, err := service.SetInput("你好").Do(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotZero(t, resp.Usage.TotalTokens)
|
|
||||||
require.NotEmpty(t, resp.Data)
|
|
||||||
require.NotEmpty(t, resp.Data[0].Embedding)
|
|
||||||
}
|
|
@ -1,58 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
type APIError struct {
|
|
||||||
Code string `json:"code"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e APIError) Error() string {
|
|
||||||
return e.Message
|
|
||||||
}
|
|
||||||
|
|
||||||
type APIErrorResponse struct {
|
|
||||||
APIError `json:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e APIErrorResponse) Error() string {
|
|
||||||
return e.APIError.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAPIErrorCode returns the error code of an API error.
|
|
||||||
func GetAPIErrorCode(err error) string {
|
|
||||||
if err == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
if e, ok := err.(APIError); ok {
|
|
||||||
return e.Code
|
|
||||||
}
|
|
||||||
if e, ok := err.(APIErrorResponse); ok {
|
|
||||||
return e.Code
|
|
||||||
}
|
|
||||||
if e, ok := err.(*APIError); ok && e != nil {
|
|
||||||
return e.Code
|
|
||||||
}
|
|
||||||
if e, ok := err.(*APIErrorResponse); ok && e != nil {
|
|
||||||
return e.Code
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAPIErrorMessage returns the error message of an API error.
|
|
||||||
func GetAPIErrorMessage(err error) string {
|
|
||||||
if err == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
if e, ok := err.(APIError); ok {
|
|
||||||
return e.Message
|
|
||||||
}
|
|
||||||
if e, ok := err.(APIErrorResponse); ok {
|
|
||||||
return e.Message
|
|
||||||
}
|
|
||||||
if e, ok := err.(*APIError); ok && e != nil {
|
|
||||||
return e.Message
|
|
||||||
}
|
|
||||||
if e, ok := err.(*APIErrorResponse); ok && e != nil {
|
|
||||||
return e.Message
|
|
||||||
}
|
|
||||||
return err.Error()
|
|
||||||
}
|
|
@ -1,38 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestAPIError(t *testing.T) {
|
|
||||||
err := APIError{
|
|
||||||
Code: "code",
|
|
||||||
Message: "message",
|
|
||||||
}
|
|
||||||
require.Equal(t, "message", err.Error())
|
|
||||||
require.Equal(t, "code", GetAPIErrorCode(err))
|
|
||||||
require.Equal(t, "message", GetAPIErrorMessage(err))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAPIErrorResponse(t *testing.T) {
|
|
||||||
err := APIErrorResponse{
|
|
||||||
APIError: APIError{
|
|
||||||
Code: "code",
|
|
||||||
Message: "message",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
require.Equal(t, "message", err.Error())
|
|
||||||
require.Equal(t, "code", GetAPIErrorCode(err))
|
|
||||||
require.Equal(t, "message", GetAPIErrorMessage(err))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAPIErrorResponseFromDoc(t *testing.T) {
|
|
||||||
var res APIErrorResponse
|
|
||||||
err := json.Unmarshal([]byte(`{"error":{"code":"1002","message":"Authorization Token非法,请确认Authorization Token正确传递。"}}`), &res)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, "1002", res.Code)
|
|
||||||
require.Equal(t, "1002", GetAPIErrorCode(res))
|
|
||||||
}
|
|
541
zhipu/file.go
541
zhipu/file.go
@ -1,541 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/go-resty/resty/v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
FilePurposeFineTune = "fine-tune"
|
|
||||||
FilePurposeRetrieval = "retrieval"
|
|
||||||
FilePurposeBatch = "batch"
|
|
||||||
|
|
||||||
KnowledgeTypeArticle = 1
|
|
||||||
KnowledgeTypeQADocument = 2
|
|
||||||
KnowledgeTypeQASpreadsheet = 3
|
|
||||||
KnowledgeTypeProductDatabaseSpreadsheet = 4
|
|
||||||
KnowledgeTypeCustom = 5
|
|
||||||
)
|
|
||||||
|
|
||||||
// FileCreateService is a service to create a file.
|
|
||||||
type FileCreateService struct {
|
|
||||||
client *Client
|
|
||||||
|
|
||||||
purpose string
|
|
||||||
|
|
||||||
localFile string
|
|
||||||
file io.Reader
|
|
||||||
filename string
|
|
||||||
|
|
||||||
customSeparator *string
|
|
||||||
sentenceSize *int
|
|
||||||
knowledgeID *string
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileCreateKnowledgeSuccessInfo is the success info of the FileCreateKnowledgeResponse.
|
|
||||||
type FileCreateKnowledgeSuccessInfo struct {
|
|
||||||
Filename string `json:"fileName"`
|
|
||||||
DocumentID string `json:"documentId"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileCreateKnowledgeFailedInfo is the failed info of the FileCreateKnowledgeResponse.
|
|
||||||
type FileCreateKnowledgeFailedInfo struct {
|
|
||||||
Filename string `json:"fileName"`
|
|
||||||
FailReason string `json:"failReason"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileCreateKnowledgeResponse is the response of the FileCreateService.
|
|
||||||
type FileCreateKnowledgeResponse struct {
|
|
||||||
SuccessInfos []FileCreateKnowledgeSuccessInfo `json:"successInfos"`
|
|
||||||
FailedInfos []FileCreateKnowledgeFailedInfo `json:"failedInfos"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileCreateFineTuneResponse is the response of the FileCreateService.
|
|
||||||
type FileCreateFineTuneResponse struct {
|
|
||||||
Bytes int64 `json:"bytes"`
|
|
||||||
CreatedAt int64 `json:"created_at"`
|
|
||||||
Filename string `json:"filename"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
Purpose string `json:"purpose"`
|
|
||||||
ID string `json:"id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileCreateResponse is the response of the FileCreateService.
|
|
||||||
type FileCreateResponse struct {
|
|
||||||
FileCreateFineTuneResponse
|
|
||||||
FileCreateKnowledgeResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewFileCreateService creates a new FileCreateService.
|
|
||||||
func NewFileCreateService(client *Client) *FileCreateService {
|
|
||||||
return &FileCreateService{client: client}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetLocalFile sets the local_file parameter of the FileCreateService.
|
|
||||||
func (s *FileCreateService) SetLocalFile(localFile string) *FileCreateService {
|
|
||||||
s.localFile = localFile
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetFile sets the file parameter of the FileCreateService.
|
|
||||||
func (s *FileCreateService) SetFile(file io.Reader, filename string) *FileCreateService {
|
|
||||||
s.file = file
|
|
||||||
s.filename = filename
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetPurpose sets the purpose parameter of the FileCreateService.
|
|
||||||
func (s *FileCreateService) SetPurpose(purpose string) *FileCreateService {
|
|
||||||
s.purpose = purpose
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetCustomSeparator sets the custom_separator parameter of the FileCreateService.
|
|
||||||
func (s *FileCreateService) SetCustomSeparator(customSeparator string) *FileCreateService {
|
|
||||||
s.customSeparator = &customSeparator
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSentenceSize sets the sentence_size parameter of the FileCreateService.
|
|
||||||
func (s *FileCreateService) SetSentenceSize(sentenceSize int) *FileCreateService {
|
|
||||||
s.sentenceSize = &sentenceSize
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetKnowledgeID sets the knowledge_id parameter of the FileCreateService.
|
|
||||||
func (s *FileCreateService) SetKnowledgeID(knowledgeID string) *FileCreateService {
|
|
||||||
s.knowledgeID = &knowledgeID
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do makes the request.
|
|
||||||
func (s *FileCreateService) Do(ctx context.Context) (res FileCreateResponse, err error) {
|
|
||||||
var (
|
|
||||||
resp *resty.Response
|
|
||||||
apiError APIErrorResponse
|
|
||||||
)
|
|
||||||
|
|
||||||
body := map[string]string{"purpose": s.purpose}
|
|
||||||
|
|
||||||
if s.customSeparator != nil {
|
|
||||||
body["custom_separator"] = *s.customSeparator
|
|
||||||
}
|
|
||||||
if s.sentenceSize != nil {
|
|
||||||
body["sentence_size"] = strconv.Itoa(*s.sentenceSize)
|
|
||||||
}
|
|
||||||
if s.knowledgeID != nil {
|
|
||||||
body["knowledge_id"] = *s.knowledgeID
|
|
||||||
}
|
|
||||||
|
|
||||||
file, filename := s.file, s.filename
|
|
||||||
|
|
||||||
if file == nil && s.localFile != "" {
|
|
||||||
var f *os.File
|
|
||||||
if f, err = os.Open(s.localFile); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
file = f
|
|
||||||
filename = filepath.Base(s.localFile)
|
|
||||||
}
|
|
||||||
|
|
||||||
if file == nil {
|
|
||||||
err = errors.New("no file specified")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp, err = s.client.request(ctx).
|
|
||||||
SetFileReader("file", filename, file).
|
|
||||||
SetMultipartFormData(body).
|
|
||||||
SetResult(&res).
|
|
||||||
SetError(&apiError).
|
|
||||||
Post("files"); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.IsError() {
|
|
||||||
err = apiError
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileEditService is a service to edit a file.
|
|
||||||
type FileEditService struct {
|
|
||||||
client *Client
|
|
||||||
|
|
||||||
documentID string
|
|
||||||
|
|
||||||
knowledgeType *int
|
|
||||||
customSeparator []string
|
|
||||||
sentenceSize *int
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewFileEditService creates a new FileEditService.
|
|
||||||
func NewFileEditService(client *Client) *FileEditService {
|
|
||||||
return &FileEditService{client: client}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDocumentID sets the document_id parameter of the FileEditService.
|
|
||||||
func (s *FileEditService) SetDocumentID(documentID string) *FileEditService {
|
|
||||||
s.documentID = documentID
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetKnowledgeType sets the knowledge_type parameter of the FileEditService.
|
|
||||||
func (s *FileEditService) SetKnowledgeType(knowledgeType int) *FileEditService {
|
|
||||||
s.knowledgeType = &knowledgeType
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSentenceSize sets the sentence_size parameter of the FileEditService.
|
|
||||||
func (s *FileEditService) SetCustomSeparator(customSeparator ...string) *FileEditService {
|
|
||||||
s.customSeparator = customSeparator
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSentenceSize sets the sentence_size parameter of the FileEditService.
|
|
||||||
func (s *FileEditService) SetSentenceSize(sentenceSize int) *FileEditService {
|
|
||||||
s.sentenceSize = &sentenceSize
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do makes the request.
|
|
||||||
func (s *FileEditService) Do(ctx context.Context) (err error) {
|
|
||||||
var (
|
|
||||||
resp *resty.Response
|
|
||||||
apiError APIErrorResponse
|
|
||||||
)
|
|
||||||
|
|
||||||
body := M{}
|
|
||||||
|
|
||||||
if s.knowledgeType != nil {
|
|
||||||
body["knowledge_type"] = strconv.Itoa(*s.knowledgeType)
|
|
||||||
}
|
|
||||||
if len(s.customSeparator) > 0 {
|
|
||||||
body["custom_separator"] = s.customSeparator
|
|
||||||
}
|
|
||||||
if s.sentenceSize != nil {
|
|
||||||
body["sentence_size"] = strconv.Itoa(*s.sentenceSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp, err = s.client.request(ctx).
|
|
||||||
SetPathParam("document_id", s.documentID).
|
|
||||||
SetBody(body).
|
|
||||||
SetError(&apiError).
|
|
||||||
Put("document/{document_id}"); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.IsError() {
|
|
||||||
err = apiError
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileListService is a service to list files.
|
|
||||||
type FileListService struct {
|
|
||||||
client *Client
|
|
||||||
|
|
||||||
purpose string
|
|
||||||
|
|
||||||
knowledgeID *string
|
|
||||||
page *int
|
|
||||||
limit *int
|
|
||||||
after *string
|
|
||||||
orderAsc *bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileFailInfo is the failed info of the FileListKnowledgeItem.
|
|
||||||
type FileFailInfo struct {
|
|
||||||
EmbeddingCode int `json:"embedding_code"`
|
|
||||||
EmbeddingMsg string `json:"embedding_msg"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileListKnowledgeItem is the item of the FileListKnowledgeResponse.
|
|
||||||
type FileListKnowledgeItem struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
URL string `json:"url"`
|
|
||||||
Length int64 `json:"length"`
|
|
||||||
SentenceSize int64 `json:"sentence_size"`
|
|
||||||
CustomSeparator []string `json:"custom_separator"`
|
|
||||||
EmbeddingStat int `json:"embedding_stat"`
|
|
||||||
FailInfo *FileFailInfo `json:"failInfo"`
|
|
||||||
WordNum int64 `json:"word_num"`
|
|
||||||
ParseImage int `json:"parse_image"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileListKnowledgeResponse is the response of the FileListService.
|
|
||||||
type FileListKnowledgeResponse struct {
|
|
||||||
Total int `json:"total"`
|
|
||||||
List []FileListKnowledgeItem `json:"list"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileListFineTuneItem is the item of the FileListFineTuneResponse.
|
|
||||||
type FileListFineTuneItem struct {
|
|
||||||
Bytes int64 `json:"bytes"`
|
|
||||||
CreatedAt int64 `json:"created_at"`
|
|
||||||
Filename string `json:"filename"`
|
|
||||||
ID string `json:"id"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
Purpose string `json:"purpose"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileListFineTuneResponse is the response of the FileListService.
|
|
||||||
type FileListFineTuneResponse struct {
|
|
||||||
Object string `json:"object"`
|
|
||||||
Data []FileListFineTuneItem `json:"data"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileListResponse is the response of the FileListService.
|
|
||||||
type FileListResponse struct {
|
|
||||||
FileListKnowledgeResponse
|
|
||||||
FileListFineTuneResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewFileListService creates a new FileListService.
|
|
||||||
func NewFileListService(client *Client) *FileListService {
|
|
||||||
return &FileListService{client: client}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetPurpose sets the purpose parameter of the FileListService.
|
|
||||||
func (s *FileListService) SetPurpose(purpose string) *FileListService {
|
|
||||||
s.purpose = purpose
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetKnowledgeID sets the knowledge_id parameter of the FileListService.
|
|
||||||
func (s *FileListService) SetKnowledgeID(knowledgeID string) *FileListService {
|
|
||||||
s.knowledgeID = &knowledgeID
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetPage sets the page parameter of the FileListService.
|
|
||||||
func (s *FileListService) SetPage(page int) *FileListService {
|
|
||||||
s.page = &page
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetLimit sets the limit parameter of the FileListService.
|
|
||||||
func (s *FileListService) SetLimit(limit int) *FileListService {
|
|
||||||
s.limit = &limit
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetAfter sets the after parameter of the FileListService.
|
|
||||||
func (s *FileListService) SetAfter(after string) *FileListService {
|
|
||||||
s.after = &after
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetOrder sets the order parameter of the FileListService.
|
|
||||||
func (s *FileListService) SetOrder(asc bool) *FileListService {
|
|
||||||
s.orderAsc = &asc
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do makes the request.
|
|
||||||
func (s *FileListService) Do(ctx context.Context) (res FileListResponse, err error) {
|
|
||||||
var (
|
|
||||||
resp *resty.Response
|
|
||||||
apiError APIErrorResponse
|
|
||||||
)
|
|
||||||
|
|
||||||
m := map[string]string{
|
|
||||||
"purpose": s.purpose,
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.knowledgeID != nil {
|
|
||||||
m["knowledge_id"] = *s.knowledgeID
|
|
||||||
}
|
|
||||||
if s.page != nil {
|
|
||||||
m["page"] = strconv.Itoa(*s.page)
|
|
||||||
}
|
|
||||||
if s.limit != nil {
|
|
||||||
m["limit"] = strconv.Itoa(*s.limit)
|
|
||||||
}
|
|
||||||
if s.after != nil {
|
|
||||||
m["after"] = *s.after
|
|
||||||
}
|
|
||||||
if s.orderAsc != nil {
|
|
||||||
if *s.orderAsc {
|
|
||||||
m["order"] = "asc"
|
|
||||||
} else {
|
|
||||||
m["order"] = "desc"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp, err = s.client.request(ctx).
|
|
||||||
SetQueryParams(m).
|
|
||||||
SetResult(&res).
|
|
||||||
SetError(&apiError).
|
|
||||||
Get("files"); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.IsError() {
|
|
||||||
err = apiError
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileDeleteService is a service to delete a file.
|
|
||||||
type FileDeleteService struct {
|
|
||||||
client *Client
|
|
||||||
documentID string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewFileDeleteService creates a new FileDeleteService.
|
|
||||||
func NewFileDeleteService(client *Client) *FileDeleteService {
|
|
||||||
return &FileDeleteService{client: client}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDocumentID sets the document_id parameter of the FileDeleteService.
|
|
||||||
func (s *FileDeleteService) SetDocumentID(documentID string) *FileDeleteService {
|
|
||||||
s.documentID = documentID
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do makes the request.
|
|
||||||
func (s *FileDeleteService) Do(ctx context.Context) (err error) {
|
|
||||||
var (
|
|
||||||
resp *resty.Response
|
|
||||||
apiError APIErrorResponse
|
|
||||||
)
|
|
||||||
|
|
||||||
if resp, err = s.client.request(ctx).
|
|
||||||
SetPathParam("document_id", s.documentID).
|
|
||||||
SetError(&apiError).
|
|
||||||
Delete("document/{document_id}"); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.IsError() {
|
|
||||||
err = apiError
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileGetService is a service to get a file.
|
|
||||||
type FileGetService struct {
|
|
||||||
client *Client
|
|
||||||
documentID string
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileGetResponse is the response of the FileGetService.
|
|
||||||
type FileGetResponse = FileListKnowledgeItem
|
|
||||||
|
|
||||||
// NewFileGetService creates a new FileGetService.
|
|
||||||
func NewFileGetService(client *Client) *FileGetService {
|
|
||||||
return &FileGetService{client: client}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDocumentID sets the document_id parameter of the FileGetService.
|
|
||||||
func (s *FileGetService) SetDocumentID(documentID string) *FileGetService {
|
|
||||||
s.documentID = documentID
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do makes the request.
|
|
||||||
func (s *FileGetService) Do(ctx context.Context) (res FileGetResponse, err error) {
|
|
||||||
var (
|
|
||||||
resp *resty.Response
|
|
||||||
apiError APIErrorResponse
|
|
||||||
)
|
|
||||||
|
|
||||||
if resp, err = s.client.request(ctx).
|
|
||||||
SetPathParam("document_id", s.documentID).
|
|
||||||
SetResult(&res).
|
|
||||||
SetError(&apiError).
|
|
||||||
Get("document/{document_id}"); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.IsError() {
|
|
||||||
err = apiError
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileDownloadService is a service to download a file.
|
|
||||||
type FileDownloadService struct {
|
|
||||||
client *Client
|
|
||||||
|
|
||||||
fileID string
|
|
||||||
|
|
||||||
writer io.Writer
|
|
||||||
filename string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewFileDownloadService creates a new FileDownloadService.
|
|
||||||
func NewFileDownloadService(client *Client) *FileDownloadService {
|
|
||||||
return &FileDownloadService{client: client}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetFileID sets the file_id parameter of the FileDownloadService.
|
|
||||||
func (s *FileDownloadService) SetFileID(fileID string) *FileDownloadService {
|
|
||||||
s.fileID = fileID
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetOutput sets the output parameter of the FileDownloadService.
|
|
||||||
func (s *FileDownloadService) SetOutput(w io.Writer) *FileDownloadService {
|
|
||||||
s.writer = w
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetOutputFile sets the output_file parameter of the FileDownloadService.
|
|
||||||
func (s *FileDownloadService) SetOutputFile(filename string) *FileDownloadService {
|
|
||||||
s.filename = filename
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do makes the request.
|
|
||||||
func (s *FileDownloadService) Do(ctx context.Context) (err error) {
|
|
||||||
var resp *resty.Response
|
|
||||||
|
|
||||||
writer := s.writer
|
|
||||||
|
|
||||||
if writer == nil && s.filename != "" {
|
|
||||||
var f *os.File
|
|
||||||
if f, err = os.Create(s.filename); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
writer = f
|
|
||||||
}
|
|
||||||
|
|
||||||
if writer == nil {
|
|
||||||
return errors.New("no output specified")
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp, err = s.client.request(ctx).
|
|
||||||
SetDoNotParseResponse(true).
|
|
||||||
SetPathParam("file_id", s.fileID).
|
|
||||||
Get("files/{file_id}/content"); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer resp.RawBody().Close()
|
|
||||||
|
|
||||||
_, err = io.Copy(writer, resp.RawBody())
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
@ -1,71 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestFileServiceFineTune(t *testing.T) {
|
|
||||||
client, err := NewClient()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
s := client.FileCreate(FilePurposeFineTune)
|
|
||||||
s.SetLocalFile(filepath.Join("testdata", "test-file.jsonl"))
|
|
||||||
|
|
||||||
res, err := s.Do(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotZero(t, res.Bytes)
|
|
||||||
require.NotZero(t, res.CreatedAt)
|
|
||||||
require.NotEmpty(t, res.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFileServiceKnowledge(t *testing.T) {
|
|
||||||
client, err := NewClient()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
s := client.FileCreate(FilePurposeRetrieval)
|
|
||||||
s.SetKnowledgeID(os.Getenv("TEST_KNOWLEDGE_ID"))
|
|
||||||
s.SetLocalFile(filepath.Join("testdata", "test-file.txt"))
|
|
||||||
|
|
||||||
res, err := s.Do(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotEmpty(t, res.SuccessInfos)
|
|
||||||
require.NotEmpty(t, res.SuccessInfos[0].DocumentID)
|
|
||||||
require.NotEmpty(t, res.SuccessInfos[0].Filename)
|
|
||||||
|
|
||||||
documentID := res.SuccessInfos[0].DocumentID
|
|
||||||
|
|
||||||
res2, err := client.FileGet(documentID).Do(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotEmpty(t, res2.ID)
|
|
||||||
|
|
||||||
err = client.FileEdit(documentID).SetKnowledgeType(KnowledgeTypeCustom).Do(context.Background())
|
|
||||||
require.True(t, err == nil || GetAPIErrorCode(err) == "10019")
|
|
||||||
|
|
||||||
err = client.FileDelete(res.SuccessInfos[0].DocumentID).Do(context.Background())
|
|
||||||
require.True(t, err == nil || GetAPIErrorCode(err) == "10019")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFileListServiceKnowledge(t *testing.T) {
|
|
||||||
client, err := NewClient()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
s := client.FileList(FilePurposeRetrieval).SetKnowledgeID(os.Getenv("TEST_KNOWLEDGE_ID"))
|
|
||||||
res, err := s.Do(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotEmpty(t, res.List)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFileListServiceFineTune(t *testing.T) {
|
|
||||||
client, err := NewClient()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
s := client.FileList(FilePurposeFineTune)
|
|
||||||
res, err := s.Do(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotEmpty(t, res.Data)
|
|
||||||
}
|
|
@ -1,456 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/go-resty/resty/v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
HyperParameterAuto = "auto"
|
|
||||||
|
|
||||||
FineTuneStatusCreate = "create"
|
|
||||||
FineTuneStatusValidatingFiles = "validating_files"
|
|
||||||
FineTuneStatusQueued = "queued"
|
|
||||||
FineTuneStatusRunning = "running"
|
|
||||||
FineTuneStatusSucceeded = "succeeded"
|
|
||||||
FineTuneStatusFailed = "failed"
|
|
||||||
FineTuneStatusCancelled = "cancelled"
|
|
||||||
)
|
|
||||||
|
|
||||||
// FineTuneItem is the item of the FineTune
|
|
||||||
type FineTuneItem struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
RequestID string `json:"request_id"`
|
|
||||||
FineTunedModel string `json:"fine_tuned_model"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
TrainingFile string `json:"training_file"`
|
|
||||||
ValidationFile string `json:"validation_file"`
|
|
||||||
Error APIError `json:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// FineTuneCreateService creates a new fine tune
|
|
||||||
type FineTuneCreateService struct {
|
|
||||||
client *Client
|
|
||||||
|
|
||||||
model string
|
|
||||||
trainingFile string
|
|
||||||
validationFile *string
|
|
||||||
|
|
||||||
learningRateMultiplier *StringOr[float64]
|
|
||||||
batchSize *StringOr[int]
|
|
||||||
nEpochs *StringOr[int]
|
|
||||||
|
|
||||||
suffix *string
|
|
||||||
requestID *string
|
|
||||||
}
|
|
||||||
|
|
||||||
// FineTuneCreateResponse is the response of the FineTuneCreateService
|
|
||||||
type FineTuneCreateResponse = FineTuneItem
|
|
||||||
|
|
||||||
// NewFineTuneCreateService creates a new FineTuneCreateService
|
|
||||||
func NewFineTuneCreateService(client *Client) *FineTuneCreateService {
|
|
||||||
return &FineTuneCreateService{
|
|
||||||
client: client,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetModel sets the model parameter
|
|
||||||
func (s *FineTuneCreateService) SetModel(model string) *FineTuneCreateService {
|
|
||||||
s.model = model
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetTrainingFile sets the trainingFile parameter
|
|
||||||
func (s *FineTuneCreateService) SetTrainingFile(trainingFile string) *FineTuneCreateService {
|
|
||||||
s.trainingFile = trainingFile
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetValidationFile sets the validationFile parameter
|
|
||||||
func (s *FineTuneCreateService) SetValidationFile(validationFile string) *FineTuneCreateService {
|
|
||||||
s.validationFile = &validationFile
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetLearningRateMultiplier sets the learningRateMultiplier parameter
|
|
||||||
func (s *FineTuneCreateService) SetLearningRateMultiplier(learningRateMultiplier float64) *FineTuneCreateService {
|
|
||||||
s.learningRateMultiplier = &StringOr[float64]{}
|
|
||||||
s.learningRateMultiplier.SetValue(learningRateMultiplier)
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetLearningRateMultiplierAuto sets the learningRateMultiplier parameter to auto
|
|
||||||
func (s *FineTuneCreateService) SetLearningRateMultiplierAuto() *FineTuneCreateService {
|
|
||||||
s.learningRateMultiplier = &StringOr[float64]{}
|
|
||||||
s.learningRateMultiplier.SetString(HyperParameterAuto)
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetBatchSize sets the batchSize parameter
|
|
||||||
func (s *FineTuneCreateService) SetBatchSize(batchSize int) *FineTuneCreateService {
|
|
||||||
s.batchSize = &StringOr[int]{}
|
|
||||||
s.batchSize.SetValue(batchSize)
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetBatchSizeAuto sets the batchSize parameter to auto
|
|
||||||
func (s *FineTuneCreateService) SetBatchSizeAuto() *FineTuneCreateService {
|
|
||||||
s.batchSize = &StringOr[int]{}
|
|
||||||
s.batchSize.SetString(HyperParameterAuto)
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNEpochs sets the nEpochs parameter
|
|
||||||
func (s *FineTuneCreateService) SetNEpochs(nEpochs int) *FineTuneCreateService {
|
|
||||||
s.nEpochs = &StringOr[int]{}
|
|
||||||
s.nEpochs.SetValue(nEpochs)
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNEpochsAuto sets the nEpochs parameter to auto
|
|
||||||
func (s *FineTuneCreateService) SetNEpochsAuto() *FineTuneCreateService {
|
|
||||||
s.nEpochs = &StringOr[int]{}
|
|
||||||
s.nEpochs.SetString(HyperParameterAuto)
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSuffix sets the suffix parameter
|
|
||||||
func (s *FineTuneCreateService) SetSuffix(suffix string) *FineTuneCreateService {
|
|
||||||
s.suffix = &suffix
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetRequestID sets the requestID parameter
|
|
||||||
func (s *FineTuneCreateService) SetRequestID(requestID string) *FineTuneCreateService {
|
|
||||||
s.requestID = &requestID
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do makes the request
|
|
||||||
func (s *FineTuneCreateService) Do(ctx context.Context) (res FineTuneCreateResponse, err error) {
|
|
||||||
var (
|
|
||||||
resp *resty.Response
|
|
||||||
apiError APIErrorResponse
|
|
||||||
)
|
|
||||||
|
|
||||||
body := M{
|
|
||||||
"model": s.model,
|
|
||||||
"training_file": s.trainingFile,
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.validationFile != nil {
|
|
||||||
body["validation_file"] = *s.validationFile
|
|
||||||
}
|
|
||||||
if s.suffix != nil {
|
|
||||||
body["suffix"] = *s.suffix
|
|
||||||
}
|
|
||||||
if s.requestID != nil {
|
|
||||||
body["request_id"] = *s.requestID
|
|
||||||
}
|
|
||||||
if s.learningRateMultiplier != nil || s.batchSize != nil || s.nEpochs != nil {
|
|
||||||
hp := M{}
|
|
||||||
if s.learningRateMultiplier != nil {
|
|
||||||
hp["learning_rate_multiplier"] = s.learningRateMultiplier
|
|
||||||
}
|
|
||||||
if s.batchSize != nil {
|
|
||||||
hp["batch_size"] = s.batchSize
|
|
||||||
}
|
|
||||||
if s.nEpochs != nil {
|
|
||||||
hp["n_epochs"] = s.nEpochs
|
|
||||||
}
|
|
||||||
body["hyperparameters"] = hp
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp, err = s.client.request(ctx).
|
|
||||||
SetBody(body).
|
|
||||||
SetResult(&res).
|
|
||||||
SetError(&apiError).
|
|
||||||
Post("fine_tuning/jobs"); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if resp.IsError() {
|
|
||||||
err = apiError
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// FineTuneEventListService creates a new fine tune event list
|
|
||||||
type FineTuneEventListService struct {
|
|
||||||
client *Client
|
|
||||||
|
|
||||||
jobID string
|
|
||||||
|
|
||||||
limit *int
|
|
||||||
after *string
|
|
||||||
}
|
|
||||||
|
|
||||||
// FineTuneEventData is the data of the FineTuneEventItem
|
|
||||||
type FineTuneEventData struct {
|
|
||||||
Acc float64 `json:"acc"`
|
|
||||||
Loss float64 `json:"loss"`
|
|
||||||
CurrentSteps int64 `json:"current_steps"`
|
|
||||||
RemainingTime string `json:"remaining_time"`
|
|
||||||
ElapsedTime string `json:"elapsed_time"`
|
|
||||||
TotalSteps int64 `json:"total_steps"`
|
|
||||||
Epoch int64 `json:"epoch"`
|
|
||||||
TrainedTokens int64 `json:"trained_tokens"`
|
|
||||||
LearningRate float64 `json:"learning_rate"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// FineTuneEventItem is the item of the FineTuneEventListResponse
|
|
||||||
type FineTuneEventItem struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
Level string `json:"level"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
CreatedAt int64 `json:"created_at"`
|
|
||||||
Data FineTuneEventData `json:"data"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// FineTuneEventListResponse is the response of the FineTuneEventListService
|
|
||||||
type FineTuneEventListResponse struct {
|
|
||||||
Data []FineTuneEventItem `json:"data"`
|
|
||||||
HasMore bool `json:"has_more"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewFineTuneEventListService creates a new FineTuneEventListService
|
|
||||||
func NewFineTuneEventListService(client *Client) *FineTuneEventListService {
|
|
||||||
return &FineTuneEventListService{
|
|
||||||
client: client,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetJobID sets the jobID parameter
|
|
||||||
func (s *FineTuneEventListService) SetJobID(jobID string) *FineTuneEventListService {
|
|
||||||
s.jobID = jobID
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetLimit sets the limit parameter
|
|
||||||
func (s *FineTuneEventListService) SetLimit(limit int) *FineTuneEventListService {
|
|
||||||
s.limit = &limit
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetAfter sets the after parameter
|
|
||||||
func (s *FineTuneEventListService) SetAfter(after string) *FineTuneEventListService {
|
|
||||||
s.after = &after
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do makes the request
|
|
||||||
func (s *FineTuneEventListService) Do(ctx context.Context) (res FineTuneEventListResponse, err error) {
|
|
||||||
var (
|
|
||||||
resp *resty.Response
|
|
||||||
apiError APIErrorResponse
|
|
||||||
)
|
|
||||||
|
|
||||||
req := s.client.request(ctx)
|
|
||||||
|
|
||||||
if s.limit != nil {
|
|
||||||
req.SetQueryParam("limit", strconv.Itoa(*s.limit))
|
|
||||||
}
|
|
||||||
if s.after != nil {
|
|
||||||
req.SetQueryParam("after", *s.after)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp, err = req.
|
|
||||||
SetPathParam("job_id", s.jobID).
|
|
||||||
SetResult(&res).
|
|
||||||
SetError(&apiError).
|
|
||||||
Get("fine_tuning/jobs/{job_id}/events"); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if resp.IsError() {
|
|
||||||
err = apiError
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// FineTuneGetService creates a new fine tune get
|
|
||||||
type FineTuneGetService struct {
|
|
||||||
client *Client
|
|
||||||
jobID string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewFineTuneGetService creates a new FineTuneGetService
|
|
||||||
func NewFineTuneGetService(client *Client) *FineTuneGetService {
|
|
||||||
return &FineTuneGetService{
|
|
||||||
client: client,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetJobID sets the jobID parameter
|
|
||||||
func (s *FineTuneGetService) SetJobID(jobID string) *FineTuneGetService {
|
|
||||||
s.jobID = jobID
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do makes the request
|
|
||||||
func (s *FineTuneGetService) Do(ctx context.Context) (res FineTuneItem, err error) {
|
|
||||||
var (
|
|
||||||
resp *resty.Response
|
|
||||||
apiError APIErrorResponse
|
|
||||||
)
|
|
||||||
|
|
||||||
if resp, err = s.client.request(ctx).
|
|
||||||
SetPathParam("job_id", s.jobID).
|
|
||||||
SetResult(&res).
|
|
||||||
SetError(&apiError).
|
|
||||||
Get("fine_tuning/jobs/{job_id}"); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if resp.IsError() {
|
|
||||||
err = apiError
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// FineTuneListService creates a new fine tune list
|
|
||||||
type FineTuneListService struct {
|
|
||||||
client *Client
|
|
||||||
|
|
||||||
limit *int
|
|
||||||
after *string
|
|
||||||
}
|
|
||||||
|
|
||||||
// FineTuneListResponse is the response of the FineTuneListService
|
|
||||||
type FineTuneListResponse struct {
|
|
||||||
Data []FineTuneItem `json:"data"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewFineTuneListService creates a new FineTuneListService
|
|
||||||
func NewFineTuneListService(client *Client) *FineTuneListService {
|
|
||||||
return &FineTuneListService{
|
|
||||||
client: client,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetLimit sets the limit parameter
|
|
||||||
func (s *FineTuneListService) SetLimit(limit int) *FineTuneListService {
|
|
||||||
s.limit = &limit
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetAfter sets the after parameter
|
|
||||||
func (s *FineTuneListService) SetAfter(after string) *FineTuneListService {
|
|
||||||
s.after = &after
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do makes the request
|
|
||||||
func (s *FineTuneListService) Do(ctx context.Context) (res FineTuneListResponse, err error) {
|
|
||||||
var (
|
|
||||||
resp *resty.Response
|
|
||||||
apiError APIErrorResponse
|
|
||||||
)
|
|
||||||
|
|
||||||
req := s.client.request(ctx)
|
|
||||||
if s.limit != nil {
|
|
||||||
req.SetQueryParam("limit", strconv.Itoa(*s.limit))
|
|
||||||
}
|
|
||||||
if s.after != nil {
|
|
||||||
req.SetQueryParam("after", *s.after)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp, err = req.
|
|
||||||
SetResult(&res).
|
|
||||||
SetError(&apiError).
|
|
||||||
Get("fine_tuning/jobs"); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if resp.IsError() {
|
|
||||||
err = apiError
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// FineTuneDeleteService creates a new fine tune delete
|
|
||||||
type FineTuneDeleteService struct {
|
|
||||||
client *Client
|
|
||||||
jobID string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewFineTuneDeleteService creates a new FineTuneDeleteService
|
|
||||||
func NewFineTuneDeleteService(client *Client) *FineTuneDeleteService {
|
|
||||||
return &FineTuneDeleteService{
|
|
||||||
client: client,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetJobID sets the jobID parameter
|
|
||||||
func (s *FineTuneDeleteService) SetJobID(jobID string) *FineTuneDeleteService {
|
|
||||||
s.jobID = jobID
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do makes the request
|
|
||||||
func (s *FineTuneDeleteService) Do(ctx context.Context) (res FineTuneItem, err error) {
|
|
||||||
var (
|
|
||||||
resp *resty.Response
|
|
||||||
apiError APIErrorResponse
|
|
||||||
)
|
|
||||||
|
|
||||||
if resp, err = s.client.request(ctx).
|
|
||||||
SetPathParam("job_id", s.jobID).
|
|
||||||
SetResult(&res).
|
|
||||||
SetError(&apiError).
|
|
||||||
Delete("fine_tuning/jobs/{job_id}"); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if resp.IsError() {
|
|
||||||
err = apiError
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// FineTuneCancelService creates a new fine tune cancel
|
|
||||||
type FineTuneCancelService struct {
|
|
||||||
client *Client
|
|
||||||
jobID string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewFineTuneCancelService creates a new FineTuneCancelService
|
|
||||||
func NewFineTuneCancelService(client *Client) *FineTuneCancelService {
|
|
||||||
return &FineTuneCancelService{
|
|
||||||
client: client,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetJobID sets the jobID parameter
|
|
||||||
func (s *FineTuneCancelService) SetJobID(jobID string) *FineTuneCancelService {
|
|
||||||
s.jobID = jobID
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do makes the request
|
|
||||||
func (s *FineTuneCancelService) Do(ctx context.Context) (res FineTuneItem, err error) {
|
|
||||||
var (
|
|
||||||
resp *resty.Response
|
|
||||||
apiError APIErrorResponse
|
|
||||||
)
|
|
||||||
|
|
||||||
if resp, err = s.client.request(ctx).
|
|
||||||
SetPathParam("job_id", s.jobID).
|
|
||||||
SetResult(&res).
|
|
||||||
SetError(&apiError).
|
|
||||||
Post("fine_tuning/jobs/{job_id}/cancel"); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if resp.IsError() {
|
|
||||||
err = apiError
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
@ -1,3 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
// tests not available since lack of budget to test it
|
|
@ -1,110 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/go-resty/resty/v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ImageGenerationService creates a new image generation
|
|
||||||
type ImageGenerationService struct {
|
|
||||||
client *Client
|
|
||||||
|
|
||||||
model string
|
|
||||||
prompt string
|
|
||||||
size string
|
|
||||||
userID string
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
_ BatchSupport = &ImageGenerationService{}
|
|
||||||
)
|
|
||||||
|
|
||||||
// ImageGenerationResponse is the response of the ImageGenerationService
|
|
||||||
type ImageGenerationResponse struct {
|
|
||||||
Created int64 `json:"created"`
|
|
||||||
Data []URLItem `json:"data"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewImageGenerationService creates a new ImageGenerationService
|
|
||||||
func NewImageGenerationService(client *Client) *ImageGenerationService {
|
|
||||||
return &ImageGenerationService{
|
|
||||||
client: client,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ImageGenerationService) BatchMethod() string {
|
|
||||||
return "POST"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ImageGenerationService) BatchURL() string {
|
|
||||||
return BatchEndpointV4ImagesGenerations
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ImageGenerationService) BatchBody() any {
|
|
||||||
return s.buildBody()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetModel sets the model parameter
|
|
||||||
func (s *ImageGenerationService) SetModel(model string) *ImageGenerationService {
|
|
||||||
s.model = model
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetPrompt sets the prompt parameter
|
|
||||||
func (s *ImageGenerationService) SetPrompt(prompt string) *ImageGenerationService {
|
|
||||||
s.prompt = prompt
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ImageGenerationService) SetSize(size string) *ImageGenerationService {
|
|
||||||
s.size = size
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetUserID sets the userID parameter
|
|
||||||
func (s *ImageGenerationService) SetUserID(userID string) *ImageGenerationService {
|
|
||||||
s.userID = userID
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ImageGenerationService) buildBody() M {
|
|
||||||
body := M{
|
|
||||||
"model": s.model,
|
|
||||||
"prompt": s.prompt,
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.userID != "" {
|
|
||||||
body["user_id"] = s.userID
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.size != "" {
|
|
||||||
body["size"] = s.size
|
|
||||||
}
|
|
||||||
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ImageGenerationService) Do(ctx context.Context) (res ImageGenerationResponse, err error) {
|
|
||||||
var (
|
|
||||||
resp *resty.Response
|
|
||||||
apiError APIErrorResponse
|
|
||||||
)
|
|
||||||
|
|
||||||
body := s.buildBody()
|
|
||||||
|
|
||||||
if resp, err = s.client.request(ctx).
|
|
||||||
SetBody(body).
|
|
||||||
SetResult(&res).
|
|
||||||
SetError(&apiError).
|
|
||||||
Post("images/generations"); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.IsError() {
|
|
||||||
err = apiError
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
@ -1,21 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestImageGenerationService(t *testing.T) {
|
|
||||||
client, err := NewClient()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
s := client.ImageGeneration("cogview-3")
|
|
||||||
s.SetPrompt("一只可爱的小猫")
|
|
||||||
|
|
||||||
res, err := s.Do(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotEmpty(t, res.Data)
|
|
||||||
t.Log(res.Data[0].URL)
|
|
||||||
}
|
|
@ -1,299 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/go-resty/resty/v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
KnowledgeEmbeddingIDEmbedding2 = 3
|
|
||||||
)
|
|
||||||
|
|
||||||
// KnowledgeCreateService creates a new knowledge
|
|
||||||
type KnowledgeCreateService struct {
|
|
||||||
client *Client
|
|
||||||
|
|
||||||
embeddingID int
|
|
||||||
name string
|
|
||||||
description *string
|
|
||||||
}
|
|
||||||
|
|
||||||
// KnowledgeCreateResponse is the response of the KnowledgeCreateService
|
|
||||||
type KnowledgeCreateResponse = IDItem
|
|
||||||
|
|
||||||
// NewKnowledgeCreateService creates a new KnowledgeCreateService
|
|
||||||
func NewKnowledgeCreateService(client *Client) *KnowledgeCreateService {
|
|
||||||
return &KnowledgeCreateService{
|
|
||||||
client: client,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetEmbeddingID sets the embedding id of the knowledge
|
|
||||||
func (s *KnowledgeCreateService) SetEmbeddingID(embeddingID int) *KnowledgeCreateService {
|
|
||||||
s.embeddingID = embeddingID
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetName sets the name of the knowledge
|
|
||||||
func (s *KnowledgeCreateService) SetName(name string) *KnowledgeCreateService {
|
|
||||||
s.name = name
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDescription sets the description of the knowledge
|
|
||||||
func (s *KnowledgeCreateService) SetDescription(description string) *KnowledgeCreateService {
|
|
||||||
s.description = &description
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do creates the knowledge
|
|
||||||
func (s *KnowledgeCreateService) Do(ctx context.Context) (res KnowledgeCreateResponse, err error) {
|
|
||||||
var (
|
|
||||||
resp *resty.Response
|
|
||||||
apiError APIErrorResponse
|
|
||||||
)
|
|
||||||
body := M{
|
|
||||||
"name": s.name,
|
|
||||||
"embedding_id": s.embeddingID,
|
|
||||||
}
|
|
||||||
if s.description != nil {
|
|
||||||
body["description"] = *s.description
|
|
||||||
}
|
|
||||||
if resp, err = s.client.request(ctx).
|
|
||||||
SetBody(body).
|
|
||||||
SetResult(&res).
|
|
||||||
SetError(&apiError).
|
|
||||||
Post("knowledge"); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if resp.IsError() {
|
|
||||||
err = apiError
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// KnowledgeEditService edits a knowledge
|
|
||||||
type KnowledgeEditService struct {
|
|
||||||
client *Client
|
|
||||||
|
|
||||||
knowledgeID string
|
|
||||||
|
|
||||||
embeddingID *int
|
|
||||||
name *string
|
|
||||||
description *string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewKnowledgeEditService creates a new KnowledgeEditService
|
|
||||||
func NewKnowledgeEditService(client *Client) *KnowledgeEditService {
|
|
||||||
return &KnowledgeEditService{
|
|
||||||
client: client,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetKnowledgeID sets the knowledge id
|
|
||||||
func (s *KnowledgeEditService) SetKnowledgeID(knowledgeID string) *KnowledgeEditService {
|
|
||||||
s.knowledgeID = knowledgeID
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetName sets the name of the knowledge
|
|
||||||
func (s *KnowledgeEditService) SetName(name string) *KnowledgeEditService {
|
|
||||||
s.name = &name
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetEmbeddingID sets the embedding id of the knowledge
|
|
||||||
func (s *KnowledgeEditService) SetEmbeddingID(embeddingID int) *KnowledgeEditService {
|
|
||||||
s.embeddingID = &embeddingID
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDescription sets the description of the knowledge
|
|
||||||
func (s *KnowledgeEditService) SetDescription(description string) *KnowledgeEditService {
|
|
||||||
s.description = &description
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do edits the knowledge
|
|
||||||
func (s *KnowledgeEditService) Do(ctx context.Context) (err error) {
|
|
||||||
var (
|
|
||||||
resp *resty.Response
|
|
||||||
apiError APIErrorResponse
|
|
||||||
)
|
|
||||||
body := M{}
|
|
||||||
if s.name != nil {
|
|
||||||
body["name"] = *s.name
|
|
||||||
}
|
|
||||||
if s.description != nil {
|
|
||||||
body["description"] = *s.description
|
|
||||||
}
|
|
||||||
if s.embeddingID != nil {
|
|
||||||
body["embedding_id"] = *s.embeddingID
|
|
||||||
}
|
|
||||||
if resp, err = s.client.request(ctx).
|
|
||||||
SetPathParam("knowledge_id", s.knowledgeID).
|
|
||||||
SetBody(body).
|
|
||||||
SetError(&apiError).
|
|
||||||
Put("knowledge/{knowledge_id}"); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if resp.IsError() {
|
|
||||||
err = apiError
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// KnowledgeListService lists the knowledge
|
|
||||||
type KnowledgeListService struct {
|
|
||||||
client *Client
|
|
||||||
|
|
||||||
page *int
|
|
||||||
size *int
|
|
||||||
}
|
|
||||||
|
|
||||||
// KnowledgeItem is an item in the knowledge list
|
|
||||||
type KnowledgeItem struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Icon string `json:"icon"`
|
|
||||||
Background string `json:"background"`
|
|
||||||
EmbeddingID int `json:"embedding_id"`
|
|
||||||
CustomIdentifier string `json:"custom_identifier"`
|
|
||||||
WordNum int64 `json:"word_num"`
|
|
||||||
Length int64 `json:"length"`
|
|
||||||
DocumentSize int64 `json:"document_size"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// KnowledgeListResponse is the response of the KnowledgeListService
|
|
||||||
type KnowledgeListResponse struct {
|
|
||||||
List []KnowledgeItem `json:"list"`
|
|
||||||
Total int `json:"total"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewKnowledgeListService creates a new KnowledgeListService
|
|
||||||
func NewKnowledgeListService(client *Client) *KnowledgeListService {
|
|
||||||
return &KnowledgeListService{client: client}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetPage sets the page of the knowledge list
|
|
||||||
func (s *KnowledgeListService) SetPage(page int) *KnowledgeListService {
|
|
||||||
s.page = &page
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSize sets the size of the knowledge list
|
|
||||||
func (s *KnowledgeListService) SetSize(size int) *KnowledgeListService {
|
|
||||||
s.size = &size
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do lists the knowledge
|
|
||||||
func (s *KnowledgeListService) Do(ctx context.Context) (res KnowledgeListResponse, err error) {
|
|
||||||
var (
|
|
||||||
resp *resty.Response
|
|
||||||
apiError APIErrorResponse
|
|
||||||
)
|
|
||||||
req := s.client.request(ctx)
|
|
||||||
if s.page != nil {
|
|
||||||
req.SetQueryParam("page", strconv.Itoa(*s.page))
|
|
||||||
}
|
|
||||||
if s.size != nil {
|
|
||||||
req.SetQueryParam("size", strconv.Itoa(*s.size))
|
|
||||||
}
|
|
||||||
if resp, err = req.
|
|
||||||
SetResult(&res).
|
|
||||||
SetError(&apiError).
|
|
||||||
Get("knowledge"); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if resp.IsError() {
|
|
||||||
err = apiError
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// KnowledgeDeleteService deletes a knowledge
|
|
||||||
type KnowledgeDeleteService struct {
|
|
||||||
client *Client
|
|
||||||
|
|
||||||
knowledgeID string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewKnowledgeDeleteService creates a new KnowledgeDeleteService
|
|
||||||
func NewKnowledgeDeleteService(client *Client) *KnowledgeDeleteService {
|
|
||||||
return &KnowledgeDeleteService{
|
|
||||||
client: client,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetKnowledgeID sets the knowledge id
|
|
||||||
func (s *KnowledgeDeleteService) SetKnowledgeID(knowledgeID string) *KnowledgeDeleteService {
|
|
||||||
s.knowledgeID = knowledgeID
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do deletes the knowledge
|
|
||||||
func (s *KnowledgeDeleteService) Do(ctx context.Context) (err error) {
|
|
||||||
var (
|
|
||||||
resp *resty.Response
|
|
||||||
apiError APIErrorResponse
|
|
||||||
)
|
|
||||||
if resp, err = s.client.request(ctx).
|
|
||||||
SetPathParam("knowledge_id", s.knowledgeID).
|
|
||||||
SetError(&apiError).
|
|
||||||
Delete("knowledge/{knowledge_id}"); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if resp.IsError() {
|
|
||||||
err = apiError
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// KnowledgeCapacityService query the capacity of the knowledge
|
|
||||||
type KnowledgeCapacityService struct {
|
|
||||||
client *Client
|
|
||||||
}
|
|
||||||
|
|
||||||
// KnowledgeCapacityItem is an item in the knowledge capacity
|
|
||||||
type KnowledgeCapacityItem struct {
|
|
||||||
WordNum int64 `json:"word_num"`
|
|
||||||
Length int64 `json:"length"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// KnowledgeCapacityResponse is the response of the KnowledgeCapacityService
|
|
||||||
type KnowledgeCapacityResponse struct {
|
|
||||||
Used KnowledgeCapacityItem `json:"used"`
|
|
||||||
Total KnowledgeCapacityItem `json:"total"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetKnowledgeID sets the knowledge id
|
|
||||||
func NewKnowledgeCapacityService(client *Client) *KnowledgeCapacityService {
|
|
||||||
return &KnowledgeCapacityService{client: client}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do query the capacity of the knowledge
|
|
||||||
func (s *KnowledgeCapacityService) Do(ctx context.Context) (res KnowledgeCapacityResponse, err error) {
|
|
||||||
var (
|
|
||||||
resp *resty.Response
|
|
||||||
apiError APIErrorResponse
|
|
||||||
)
|
|
||||||
if resp, err = s.client.request(ctx).
|
|
||||||
SetResult(&res).
|
|
||||||
SetError(&apiError).
|
|
||||||
Get("knowledge/capacity"); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if resp.IsError() {
|
|
||||||
err = apiError
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
@ -1,50 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestKnowledgeCapacity(t *testing.T) {
|
|
||||||
client, err := NewClient()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
s := client.KnowledgeCapacity()
|
|
||||||
res, err := s.Do(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotEmpty(t, res.Total.Length)
|
|
||||||
require.NotEmpty(t, res.Total.WordNum)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestKnowledgeServiceAll(t *testing.T) {
|
|
||||||
client, err := NewClient()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
s := client.KnowledgeCreate()
|
|
||||||
s.SetName("test")
|
|
||||||
s.SetDescription("test description")
|
|
||||||
s.SetEmbeddingID(KnowledgeEmbeddingIDEmbedding2)
|
|
||||||
|
|
||||||
res, err := s.Do(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotEmpty(t, res.ID)
|
|
||||||
|
|
||||||
s2 := client.KnowledgeList()
|
|
||||||
res2, err := s2.Do(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotEmpty(t, res2.List)
|
|
||||||
require.Equal(t, res.ID, res2.List[0].ID)
|
|
||||||
|
|
||||||
s3 := client.KnowledgeEdit(res.ID)
|
|
||||||
s3.SetDescription("test description 2")
|
|
||||||
s3.SetName("test 2")
|
|
||||||
s3.SetEmbeddingID(KnowledgeEmbeddingIDEmbedding2)
|
|
||||||
err = s3.Do(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
s4 := client.KnowledgeDelete(res.ID)
|
|
||||||
err = s4.Do(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
@ -1,54 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
// StringOr is a struct that can be either a string or a value of type T.
|
|
||||||
type StringOr[T any] struct {
|
|
||||||
String *string
|
|
||||||
Value *T
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
_ json.Marshaler = StringOr[float64]{}
|
|
||||||
_ json.Unmarshaler = &StringOr[float64]{}
|
|
||||||
)
|
|
||||||
|
|
||||||
// SetString sets the string value of the struct.
|
|
||||||
func (f *StringOr[T]) SetString(v string) {
|
|
||||||
f.String = &v
|
|
||||||
f.Value = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetValue sets the value of the struct.
|
|
||||||
func (f *StringOr[T]) SetValue(v T) {
|
|
||||||
f.String = nil
|
|
||||||
f.Value = &v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f StringOr[T]) MarshalJSON() ([]byte, error) {
|
|
||||||
if f.Value != nil {
|
|
||||||
return json.Marshal(f.Value)
|
|
||||||
}
|
|
||||||
return json.Marshal(f.String)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *StringOr[T]) UnmarshalJSON(data []byte) error {
|
|
||||||
if len(data) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if bytes.Equal(data, []byte("null")) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if data[0] == '"' {
|
|
||||||
f.String = new(string)
|
|
||||||
f.Value = nil
|
|
||||||
return json.Unmarshal(data, f.String)
|
|
||||||
} else {
|
|
||||||
f.Value = new(T)
|
|
||||||
f.String = nil
|
|
||||||
return json.Unmarshal(data, f.Value)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,37 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestStringOr(t *testing.T) {
|
|
||||||
data := struct {
|
|
||||||
Item *StringOr[int] `json:"item,omitempty"`
|
|
||||||
}{}
|
|
||||||
data.Item = &StringOr[int]{}
|
|
||||||
data.Item.SetString("test")
|
|
||||||
|
|
||||||
b, err := json.Marshal(data)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, `{"item":"test"}`, string(b))
|
|
||||||
|
|
||||||
data.Item.SetValue(1)
|
|
||||||
b, err = json.Marshal(data)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, `{"item":1}`, string(b))
|
|
||||||
|
|
||||||
err = json.Unmarshal([]byte(`{"item":"test2"}`), &data)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, data.Item.String)
|
|
||||||
require.Nil(t, data.Item.Value)
|
|
||||||
require.Equal(t, "test2", *data.Item.String)
|
|
||||||
|
|
||||||
err = json.Unmarshal([]byte(`{"item":2}`), &data)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Nil(t, data.Item.String)
|
|
||||||
require.NotNil(t, data.Item.Value)
|
|
||||||
require.Equal(t, 2, *data.Item.Value)
|
|
||||||
}
|
|
3
zhipu/testdata/test-file.jsonl
vendored
3
zhipu/testdata/test-file.jsonl
vendored
@ -1,3 +0,0 @@
|
|||||||
{"messages": [{"role": "system", "content": "你是一位乐于助人,知识渊博的全能AI助手。"},{"role": "user", "content": "请介绍一下AlphaGo的基本原理。"}, {"role": "assistant", "content": "AlphaGo利用深度神经网络与树搜索算法相结合,通过对弈数据进行训练,逐步提升游戏水平,能够对围棋局面进行判断并选择最佳落子位置。"}]}
|
|
||||||
{"messages": [{"role": "system", "content": "你是一位乐于助人,知识渊博的全能AI助手。"},{"role": "user", "content": "请列举一些健康饮食的重要原则。"}, {"role": "assistant", "content": "健康饮食的重要原则包括:膳食应营养均衡,合理摄入各种营养素;多吃新鲜蔬果,适量饮水;少吃糖、油、盐,忌吃过量快餐;养成规律生活作息习惯,适量运动。"}]}
|
|
||||||
{"messages": [{"role": "system", "content": "你是一位乐于助人,知识渊博的全能AI助手。"},{"role": "user", "content": "请给出一些改善睡眠质量的建议。"}, {"role": "assistant", "content": "改善睡眠质量的建议包括:保证睡眠时间,避免睡眠不足;睡前放松身心,远离刺激;养成规律作息时间,不要频繁熬夜;适量运动,但不要运动过度;睡前可以喝一杯热牛奶等温和饮料。"}]}
|
|
1
zhipu/testdata/test-file.txt
vendored
1
zhipu/testdata/test-file.txt
vendored
@ -1 +0,0 @@
|
|||||||
The quick brown fox jumps over the lazy dog
|
|
@ -1,22 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
// URLItem is a struct that contains a URL.
|
|
||||||
type URLItem struct {
|
|
||||||
URL string `json:"url,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// IDItem is a struct that contains an ID.
|
|
||||||
type IDItem struct {
|
|
||||||
ID string `json:"id,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ptr returns a pointer to the value passed in.
|
|
||||||
// Example:
|
|
||||||
//
|
|
||||||
// web_search_enable = zhipu.Ptr(false)
|
|
||||||
func Ptr[T any](v T) *T {
|
|
||||||
return &v
|
|
||||||
}
|
|
||||||
|
|
||||||
// M is a shorthand for map[string]any.
|
|
||||||
type M = map[string]any
|
|
@ -1,3 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
// nothing to test
|
|
@ -1,125 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/go-resty/resty/v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
VideoGenerationTaskStatusProcessing = "PROCESSING"
|
|
||||||
VideoGenerationTaskStatusSuccess = "SUCCESS"
|
|
||||||
VideoGenerationTaskStatusFail = "FAIL"
|
|
||||||
)
|
|
||||||
|
|
||||||
// VideoGenerationService creates a new video generation
|
|
||||||
type VideoGenerationService struct {
|
|
||||||
client *Client
|
|
||||||
|
|
||||||
model string
|
|
||||||
prompt string
|
|
||||||
userID string
|
|
||||||
imageURL string
|
|
||||||
requestID string
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
_ BatchSupport = &VideoGenerationService{}
|
|
||||||
)
|
|
||||||
|
|
||||||
// VideoGenerationResponse is the response of the VideoGenerationService
|
|
||||||
type VideoGenerationResponse struct {
|
|
||||||
RequestID string `json:"request_id"`
|
|
||||||
ID string `json:"id"`
|
|
||||||
Model string `json:"model"`
|
|
||||||
TaskStatus string `json:"task_status"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewVideoGenerationService(client *Client) *VideoGenerationService {
|
|
||||||
return &VideoGenerationService{
|
|
||||||
client: client,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *VideoGenerationService) BatchMethod() string {
|
|
||||||
return "POST"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *VideoGenerationService) BatchURL() string {
|
|
||||||
return BatchEndpointV4VideosGenerations
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *VideoGenerationService) BatchBody() any {
|
|
||||||
return s.buildBody()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetModel sets the model parameter
|
|
||||||
func (s *VideoGenerationService) SetModel(model string) *VideoGenerationService {
|
|
||||||
s.model = model
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetPrompt sets the prompt parameter
|
|
||||||
func (s *VideoGenerationService) SetPrompt(prompt string) *VideoGenerationService {
|
|
||||||
s.prompt = prompt
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetUserID sets the userID parameter
|
|
||||||
func (s *VideoGenerationService) SetUserID(userID string) *VideoGenerationService {
|
|
||||||
s.userID = userID
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetImageURL sets the imageURL parameter
|
|
||||||
func (s *VideoGenerationService) SetImageURL(imageURL string) *VideoGenerationService {
|
|
||||||
s.imageURL = imageURL
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetRequestID sets the requestID parameter
|
|
||||||
func (s *VideoGenerationService) SetRequestID(requestID string) *VideoGenerationService {
|
|
||||||
s.requestID = requestID
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *VideoGenerationService) buildBody() M {
|
|
||||||
body := M{
|
|
||||||
"model": s.model,
|
|
||||||
"prompt": s.prompt,
|
|
||||||
}
|
|
||||||
if s.userID != "" {
|
|
||||||
body["user_id"] = s.userID
|
|
||||||
}
|
|
||||||
if s.imageURL != "" {
|
|
||||||
body["image_url"] = s.imageURL
|
|
||||||
}
|
|
||||||
if s.requestID != "" {
|
|
||||||
body["request_id"] = s.requestID
|
|
||||||
}
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *VideoGenerationService) Do(ctx context.Context) (res VideoGenerationResponse, err error) {
|
|
||||||
var (
|
|
||||||
resp *resty.Response
|
|
||||||
apiError APIErrorResponse
|
|
||||||
)
|
|
||||||
|
|
||||||
body := s.buildBody()
|
|
||||||
|
|
||||||
if resp, err = s.client.request(ctx).
|
|
||||||
SetBody(body).
|
|
||||||
SetResult(&res).
|
|
||||||
SetError(&apiError).
|
|
||||||
Post("videos/generations"); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.IsError() {
|
|
||||||
err = apiError
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
@ -1,38 +0,0 @@
|
|||||||
package zhipu
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestVideoGeneration(t *testing.T) {
|
|
||||||
client, err := NewClient()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
s := client.VideoGeneration("cogvideox")
|
|
||||||
s.SetPrompt("一只可爱的小猫")
|
|
||||||
|
|
||||||
res, err := s.Do(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotEmpty(t, res.TaskStatus)
|
|
||||||
require.NotEmpty(t, res.ID)
|
|
||||||
t.Log(res.ID)
|
|
||||||
|
|
||||||
for {
|
|
||||||
res, err := client.AsyncResult(res.ID).Do(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotEmpty(t, res.TaskStatus)
|
|
||||||
if res.TaskStatus == VideoGenerationTaskStatusSuccess {
|
|
||||||
require.NotEmpty(t, res.VideoResult)
|
|
||||||
t.Log(res.VideoResult[0].URL)
|
|
||||||
t.Log(res.VideoResult[0].CoverImageURL)
|
|
||||||
}
|
|
||||||
if res.TaskStatus != VideoGenerationTaskStatusProcessing {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second * 5)
|
|
||||||
}
|
|
||||||
}
|
|
Binary file not shown.
Before Width: | Height: | Size: 46 KiB |
Loading…
Reference in New Issue
Block a user