commit 29b0faf61bc8588a3c7e33980ffcdabd2497ad7d Author: STARAI\Star Date: Sat Sep 7 23:14:12 2024 +0800 1 diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..f166f95 --- /dev/null +++ b/LICENSE @@ -0,0 +1,9 @@ +MIT License + +Copyright (c) 2024 apigo + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/aigc.go b/aigc.go new file mode 100644 index 0000000..f9540db --- /dev/null +++ b/aigc.go @@ -0,0 +1,83 @@ +package zhipu + +import ( + "apigo.cc/ai/zhipu/zhipu" + "context" + "errors" + "time" +) + +func (ag *Agent) FastMakeImage(prompt, size, refImage string) ([]string, error) { + return ag.MakeImage(ModelCogView3Plus, prompt, size, refImage) +} + +func (ag *Agent) BestMakeImage(prompt, size, refImage string) ([]string, error) { + return ag.MakeImage(ModelCogView3, prompt, size, refImage) +} + +func (ag *Agent) MakeImage(model, prompt, size, refImage string) ([]string, error) { + c, err := zhipu.NewClient(zhipu.WithAPIKey(ag.config.ApiKey), zhipu.WithBaseURL(ag.config.Endpoint)) + if err != nil { + return nil, err + } + + cc := c.ImageGeneration(model).SetPrompt(prompt) + if size != "" { + cc.SetSize(size) + } + + if r, err := cc.Do(context.Background()); err == nil { + results := make([]string, 0) + for _, item := range r.Data { + results = append(results, item.URL) + } + return results, nil + } else { + return nil, err + } +} + +func (ag *Agent) FastMakeVideo(prompt, size, refImage string) ([]string, []string, error) { + return ag.MakeVideo(ModelCogVideoX, prompt, size, refImage) +} + +func (ag *Agent) BestMakeVideo(prompt, size, refImage string) ([]string, []string, error) { + return ag.MakeVideo(ModelCogVideoX, prompt, size, refImage) +} + +func (ag *Agent) MakeVideo(model, prompt, size, refImage string) ([]string, []string, error) { + c, err := zhipu.NewClient(zhipu.WithAPIKey(ag.config.ApiKey), zhipu.WithBaseURL(ag.config.Endpoint)) + if err != nil { + return nil, nil, err + } + + cc := c.VideoGeneration(model).SetPrompt(prompt) + if refImage != "" { + cc.SetImageURL(refImage) + } + + if resp, err := cc.Do(context.Background()); err == nil { + for i := 0; i < 1200; i++ { + r, err := c.AsyncResult(resp.ID).Do(context.Background()) + if err != nil { + return nil, nil, err + } + if r.TaskStatus == zhipu.VideoGenerationTaskStatusSuccess { + covers := make([]string, 0) + results := make([]string, 0) + for _, item := range r.VideoResult { + results = append(results, item.URL) + covers = append(covers, item.CoverImageURL) + } + return results, covers, nil + } + if r.TaskStatus == zhipu.VideoGenerationTaskStatusFail { + return nil, nil, errors.New("fail on task " + resp.ID) + } + time.Sleep(3 * time.Second) + } + return nil, nil, errors.New("timeout on task " + resp.ID) + } else { + return nil, nil, err + } +} diff --git a/chat.go b/chat.go new file mode 100644 index 0000000..2151f07 --- /dev/null +++ b/chat.go @@ -0,0 +1,140 @@ +package zhipu + +import ( + "apigo.cc/ai/agent" + "apigo.cc/ai/zhipu/zhipu" + "context" + "strings" +) + +func (ag *Agent) FastAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) { + return ag.Ask(messages, &agent.ChatModelConfig{ + Model: ModelGLM4Flash, + }, callback) +} + +func (ag *Agent) LongAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) { + return ag.Ask(messages, &agent.ChatModelConfig{ + Model: ModelGLM4Long, + }, callback) +} + +func (ag *Agent) BatterAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) { + return ag.Ask(messages, &agent.ChatModelConfig{ + Model: ModelGLM4Plus, + }, callback) +} + +func (ag *Agent) BestAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) { + return ag.Ask(messages, &agent.ChatModelConfig{ + Model: ModelGLM40520, + }, callback) +} + +func (ag *Agent) MultiAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) { + return ag.Ask(messages, &agent.ChatModelConfig{ + Model: ModelGLM4VPlus, + }, callback) +} + +func (ag *Agent) BestMultiAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) { + return ag.Ask(messages, &agent.ChatModelConfig{ + Model: ModelGLM4V, + }, callback) +} + +func (ag *Agent) CodeInterpreterAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) { + return ag.Ask(messages, &agent.ChatModelConfig{ + Model: ModelGLM4AllTools, + Tools: map[string]any{agent.ToolCodeInterpreter: nil}, + }, callback) +} + +func (ag *Agent) WebSearchAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) { + return ag.Ask(messages, &agent.ChatModelConfig{ + Model: ModelGLM4AllTools, + Tools: map[string]any{agent.ToolWebSearch: nil}, + }, callback) +} + +func (ag *Agent) Ask(messages []agent.ChatMessage, config *agent.ChatModelConfig, callback func(answer string)) (string, agent.TokenUsage, error) { + if config == nil { + config = &agent.ChatModelConfig{} + } + config.SetDefault(&ag.config.DefaultChatModelConfig) + c, err := zhipu.NewClient(zhipu.WithAPIKey(ag.config.ApiKey), zhipu.WithBaseURL(ag.config.Endpoint)) + if err != nil { + return "", agent.TokenUsage{}, err + } + + cc := c.ChatCompletion(config.GetModel()) + for _, msg := range messages { + var contents []zhipu.ChatCompletionMultiContent + if msg.Contents != nil { + contents = make([]zhipu.ChatCompletionMultiContent, len(msg.Contents)) + for j, inPart := range msg.Contents { + part := zhipu.ChatCompletionMultiContent{} + part.Type = NameMap[inPart.Type] + switch inPart.Type { + case agent.TypeText: + part.Text = inPart.Content + case agent.TypeImage: + part.ImageURL = &zhipu.URLItem{URL: inPart.Content} + case agent.TypeVideo: + part.VideoURL = &zhipu.URLItem{URL: inPart.Content} + } + contents[j] = part + } + } + cc.AddMessage(zhipu.ChatCompletionMultiMessage{ + Role: NameMap[msg.Role], + Content: contents, + }) + } + + for name := range config.GetTools() { + switch name { + case agent.ToolCodeInterpreter: + cc.AddTool(zhipu.ChatCompletionToolCodeInterpreter{}) + case agent.ToolWebSearch: + cc.AddTool(zhipu.ChatCompletionToolWebBrowser{}) + } + } + + if config.GetMaxTokens() != 0 { + cc.SetMaxTokens(config.GetMaxTokens()) + } + if config.GetTemperature() != 0 { + cc.SetTemperature(config.GetTemperature()) + } + if config.GetTopP() != 0 { + cc.SetTopP(config.GetTopP()) + } + if callback != nil { + cc.SetStreamHandler(func(r2 zhipu.ChatCompletionResponse) error { + if r2.Choices != nil { + for _, ch := range r2.Choices { + text := ch.Delta.Content + callback(text) + } + } + return nil + }) + } + + if r, err := cc.Do(context.Background()); err == nil { + results := make([]string, 0) + if r.Choices != nil { + for _, ch := range r.Choices { + results = append(results, ch.Message.Content) + } + } + return strings.Join(results, ""), agent.TokenUsage{ + AskTokens: r.Usage.PromptTokens, + AnswerTokens: r.Usage.CompletionTokens, + TotalTokens: r.Usage.TotalTokens, + }, nil + } else { + return "", agent.TokenUsage{}, err + } +} diff --git a/config.go b/config.go new file mode 100644 index 0000000..a8e57b0 --- /dev/null +++ b/config.go @@ -0,0 +1,60 @@ +package zhipu + +import ( + "apigo.cc/ai/agent" + "apigo.cc/ai/zhipu/zhipu" +) + +type Agent struct { + config agent.APIConfig +} + +var NameMap = map[string]string{ + agent.TypeText: zhipu.MultiContentTypeText, + agent.TypeImage: zhipu.MultiContentTypeImageURL, + agent.TypeVideo: zhipu.MultiContentTypeVideoURL, + agent.RoleSystem: zhipu.RoleSystem, + agent.RoleUser: zhipu.RoleUser, + agent.RoleAssistant: zhipu.RoleAssistant, + agent.RoleTool: zhipu.RoleTool, +} + +const ( + ModelGLM4Plus = "GLM-4-Plus" + ModelGLM40520 = "GLM-4-0520" + ModelGLM4Long = "GLM-4-Long" + ModelGLM4AirX = "GLM-4-AirX" + ModelGLM4Air = "GLM-4-Air" + ModelGLM4Flash = "GLM-4-Flash" + ModelGLM4AllTools = "GLM-4-AllTools" + ModelGLM4 = "GLM-4" + ModelGLM4VPlus = "GLM-4V-Plus" + ModelGLM4V = "GLM-4V" + ModelCogVideoX = "CogVideoX" + ModelCogView3Plus = "CogView-3-Plus" + ModelCogView3 = "CogView-3" + ModelEmbedding3 = "Embedding-3" + ModelEmbedding2 = "Embedding-2" + ModelCharGLM3 = "CharGLM-3" + ModelEmohaa = "Emohaa" + ModelCodeGeeX4 = "CodeGeeX-4" +) + +func (ag *Agent) Support() agent.Support { + return agent.Support{ + Ask: true, + AskWithImage: true, + AskWithVideo: true, + AskWithCodeInterpreter: true, + AskWithWebSearch: true, + MakeImage: true, + MakeVideo: true, + Models: []string{ModelGLM4Plus, ModelGLM40520, ModelGLM4Long, ModelGLM4AirX, ModelGLM4Air, ModelGLM4Flash, ModelGLM4AllTools, ModelGLM4, ModelGLM4VPlus, ModelGLM4V, ModelCogVideoX, ModelCogView3Plus, ModelCogView3, ModelEmbedding3, ModelEmbedding2, ModelCharGLM3, ModelEmohaa, ModelCodeGeeX4}, + } +} + +func init() { + agent.RegisterAgentMaker("zhipu", func(config agent.APIConfig) agent.Agent { + return &Agent{config: config} + }) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..fd5f3e2 --- /dev/null +++ b/go.mod @@ -0,0 +1,17 @@ +module apigo.cc/ai/zhipu + +go 1.22 + +require ( + apigo.cc/ai/agent v0.0.1 + github.com/go-resty/resty/v2 v2.14.0 + github.com/golang-jwt/jwt/v5 v5.2.1 + github.com/stretchr/testify v1.9.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/net v0.29.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/zhipu/CHANGELOG.md b/zhipu/CHANGELOG.md new file mode 100644 index 0000000..8e1033f --- /dev/null +++ b/zhipu/CHANGELOG.md @@ -0,0 +1,108 @@ +# Changelog +All notable changes to this project will be documented in this file. See [conventional commits](https://www.conventionalcommits.org/) for commit guidelines. + +- - - +## v0.1.2 - 2024-08-15 +#### Bug Fixes +- add FinishReasonStopSequence - (01b4201) - GUO YANKE +#### Documentation +- update README.md [skip ci] - (e48a88b) - GUO YANKE +#### Features +- add videos/generations - (7261999) - GUO YANKE +#### Miscellaneous Chores +- relaxing go version to 1.18 - (6acc17c) - GUO YANKE + +- - - + +## v0.1.1 - 2024-07-17 +#### Documentation +- update README.md [skip ci] - (695432a) - GUO YANKE +#### Features +- add support for GLM-4-AllTools - (9627a36) - GUO YANKE + +- - - + +## v0.1.0 - 2024-06-28 +#### Bug Fixes +- rename client function for batch list - (40ac05f) - GUO YANKE +#### Documentation +- update README.md [skip ci] - (6ce5754) - GUO YANKE +#### Features +- add knowledge capacity service - (4ce62b3) - GUO YANKE +#### Refactoring +- update batch service - (b92d438) - GUO YANKE +- update chat completion service - (19dd77f) - GUO YANKE +- update embedding service - (c1bbc2d) - GUO YANKE +- update file services - (7ef4d87) - GUO YANKE +- update fine tune services, using APIError - (15aed88) - GUO YANKE +- update fine tune services - (664523b) - GUO YANKE +- update image generation service - (a18e028) - GUO YANKE +- update knowledge services - (c7bfb73) - GUO YANKE + +- - - + +## v0.0.6 - 2024-06-28 +#### Features +- add batch support for result reader - (c062095) - GUO YANKE +- add fine tune services - (f172f51) - GUO YANKE +- add knowledge service - (09792b5) - GUO YANKE + +- - - + +## v0.0.5 - 2024-06-28 +#### Bug Fixes +- api error parsing - (60a17f4) - GUO YANKE +#### Features +- add batch service - (389aec3) - GUO YANKE +- add batch support for chat completions, image generations and embeddings - (c017ffd) - GUO YANKE +- add file edit/get/delete service - (8a4d309) - GUO YANKE +- add file create serivce - (6d2140b) - GUO YANKE + +- - - + +## v0.0.4 - 2024-06-26 +#### Bug Fixes +- remove Client.R(), hide resty for future removal - (dc2a4ca) - GUO YANKE +#### Features +- add meta support for charglm - (fdd20e7) - GUO YANKE +- add client option to custom http client - (c62d6a9) - GUO YANKE + +- - - + +## v0.0.3 - 2024-06-26 +#### Features +- add image generation service - (9f3f54f) - GUO YANKE +- add support for vision models - (2dcd82a) - GUO YANKE +- add embedding service - (f57806a) - GUO YANKE + +- - - + +## v0.0.2 - 2024-06-26 +#### Bug Fixes +- **(deps)** update golang-jwt/jwt to v5 - (2f76a57) - GUO YANKE +#### Features +- add constants for roles - (3d08a72) - GUO YANKE + +- - - + +## v0.0.1 - 2024-06-26 +#### Bug Fixes +- add json tag "omitempty" to various types - (bf81097) - GUO YANKE +#### Continuous Integration +- add github action workflows for testing - (5a64987) - GUO YANKE +#### Documentation +- update README.md [skip ci] - (d504f57) - GUO YANKE +#### Features +- add chat completion in stream mode - (130fe1d) - GUO YANKE +- add chat completion in non-stream mode - (2326e37) - GUO YANKE +- support debug option while creating client - (0f104d8) - GUO YANKE +- add APIError and APIErrorResponse - (1886d85) - GUO YANKE +- add client struct - (710d8e8) - GUO YANKE +#### Refactoring +- change signature of Client#createJWT since there is no reason to fail - (f0d7887) - GUO YANKE +#### Tests +- add client_test.go - (a3fc217) - GUO YANKE + +- - - + +Changelog generated by [cocogitto](https://github.com/cocogitto/cocogitto). \ No newline at end of file diff --git a/zhipu/LICENSE b/zhipu/LICENSE new file mode 100644 index 0000000..67dc60b --- /dev/null +++ b/zhipu/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Yanke G. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/zhipu/README.md b/zhipu/README.md new file mode 100644 index 0000000..555e351 --- /dev/null +++ b/zhipu/README.md @@ -0,0 +1,280 @@ +# zhipu + +[![Go Reference](https://pkg.go.dev/badge/github.com/yankeguo/zhipu.svg)](https://pkg.go.dev/github.com/yankeguo/zhipu) +[![go](https://github.com/yankeguo/zhipu/actions/workflows/go.yml/badge.svg)](https://github.com/yankeguo/zhipu/actions/workflows/go.yml) + +[中文文档](README.zh.md) + +A 3rd-Party Golang Client Library for Zhipu AI Platform + +## Usage + +### Install the package + +```bash +go get -u github.com/yankeguo/zhipu +``` + +### Create a client + +```go +// this will use environment variables ZHIPUAI_API_KEY +client, err := zhipu.NewClient() +// or you can specify the API key +client, err = zhipu.NewClient(zhipu.WithAPIKey("your api key")) +``` + +### Use the client + +**ChatCompletion** + +```go +service := client.ChatCompletion("glm-4-flash"). + AddMessage(zhipu.ChatCompletionMessage{ + Role: "user", + Content: "你好", + }) + +res, err := service.Do(context.Background()) + +if err != nil { + zhipu.GetAPIErrorCode(err) // get the API error code +} else { + println(res.Choices[0].Message.Content) +} +``` + +**ChatCompletion (Stream)** + +```go +service := client.ChatCompletion("glm-4-flash"). + AddMessage(zhipu.ChatCompletionMessage{ + Role: "user", + Content: "你好", + }).SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error { + println(chunk.Choices[0].Delta.Content) + return nil + }) + +res, err := service.Do(context.Background()) + +if err != nil { + zhipu.GetAPIErrorCode(err) // get the API error code +} else { + // this package will combine the stream chunks and build a final result mimicking the non-streaming API + println(res.Choices[0].Message.Content) +} +``` + +**ChatCompletion (Stream with GLM-4-AllTools)** + +```go +// CodeInterpreter +s := client.ChatCompletion("GLM-4-AllTools") +s.AddMessage(zhipu.ChatCompletionMultiMessage{ + Role: "user", + Content: []zhipu.ChatCompletionMultiContent{ + { + Type: "text", + Text: "计算[5,10,20,700,99,310,978,100]的平均值和方差。", + }, + }, +}) +s.AddTool(zhipu.ChatCompletionToolCodeInterpreter{ + Sandbox: zhipu.Ptr(CodeInterpreterSandboxAuto), +}) +s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error { + for _, c := range chunk.Choices { + for _, tc := range c.Delta.ToolCalls { + if tc.Type == ToolTypeCodeInterpreter && tc.CodeInterpreter != nil { + if tc.CodeInterpreter.Input != "" { + // DO SOMETHING + } + if len(tc.CodeInterpreter.Outputs) > 0 { + // DO SOMETHING + } + } + } + } + return nil +}) + +// WebBrowser +// CAUTION: NOT 'WebSearch' +s := client.ChatCompletion("GLM-4-AllTools") +s.AddMessage(zhipu.ChatCompletionMultiMessage{ + Role: "user", + Content: []zhipu.ChatCompletionMultiContent{ + { + Type: "text", + Text: "搜索下本周深圳天气如何", + }, + }, +}) +s.AddTool(zhipu.ChatCompletionToolWebBrowser{}) +s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error { + for _, c := range chunk.Choices { + for _, tc := range c.Delta.ToolCalls { + if tc.Type == ToolTypeWebBrowser && tc.WebBrowser != nil { + if tc.WebBrowser.Input != "" { + // DO SOMETHING + } + if len(tc.WebBrowser.Outputs) > 0 { + // DO SOMETHING + } + } + } + } + return nil +}) +s.Do(context.Background()) + +// DrawingTool +s := client.ChatCompletion("GLM-4-AllTools") +s.AddMessage(zhipu.ChatCompletionMultiMessage{ + Role: "user", + Content: []zhipu.ChatCompletionMultiContent{ + { + Type: "text", + Text: "画一个正弦函数图像", + }, + }, +}) +s.AddTool(zhipu.ChatCompletionToolDrawingTool{}) +s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error { + for _, c := range chunk.Choices { + for _, tc := range c.Delta.ToolCalls { + if tc.Type == ToolTypeDrawingTool && tc.DrawingTool != nil { + if tc.DrawingTool.Input != "" { + // DO SOMETHING + } + if len(tc.DrawingTool.Outputs) > 0 { + // DO SOMETHING + } + } + } + } + return nil +}) +s.Do(context.Background()) +``` + +**Embedding** + +```go +service := client.Embedding("embedding-v2").SetInput("你好呀") +service.Do(context.Background()) +``` + +**Image Generation** + +```go +service := client.ImageGeneration("cogview-3").SetPrompt("一只可爱的小猫咪") +service.Do(context.Background()) +``` + +**Video Generation** + +```go +service := client.VideoGeneration("cogvideox").SetPrompt("一只可爱的小猫咪") +resp, err := service.Do(context.Background()) + +for { + result, err := client.AsyncResult(resp.ID).Do(context.Background()) + + if result.TaskStatus == zhipu.VideoGenerationTaskStatusSuccess { + _ = result.VideoResult[0].URL + _ = result.VideoResult[0].CoverImageURL + break + } + + if result.TaskStatus != zhipu.VideoGenerationTaskStatusProcessing { + break + } + + time.Sleep(5 * time.Second) +} +``` + +**Upload File (Retrieval)** + +```go +service := client.FileCreate(zhipu.FilePurposeRetrieval) +service.SetLocalFile(filepath.Join("testdata", "test-file.txt")) +service.SetKnowledgeID("your-knowledge-id") + +service.Do(context.Background()) +``` + +**Upload File (Fine-Tune)** + +```go +service := client.FileCreate(zhipu.FilePurposeFineTune) +service.SetLocalFile(filepath.Join("testdata", "test-file.jsonl")) +service.Do(context.Background()) +``` + +**Batch Create** + +```go +service := client.BatchCreate(). + SetInputFileID("fileid"). + SetCompletionWindow(zhipu.BatchCompletionWindow24h). + SetEndpoint(BatchEndpointV4ChatCompletions) +service.Do(context.Background()) +``` + +**Knowledge Base** + +```go +client.KnowledgeCreate("") +client.KnowledgeEdit("") +``` + +**Fine Tune** + +```go +client.FineTuneCreate("") +``` + +### Batch Support + +**Batch File Writer** + +```go +f, err := os.OpenFile("batch.jsonl", os.O_CREATE|os.O_WRONLY, 0644) + +bw := zhipu.NewBatchFileWriter(f) + +bw.Add("action_1", client.ChatCompletion("glm-4-flash"). + AddMessage(zhipu.ChatCompletionMessage{ + Role: "user", + Content: "你好", + })) +bw.Add("action_2", client.Embedding("embedding-v2").SetInput("你好呀")) +bw.Add("action_3", client.ImageGeneration("cogview-3").SetPrompt("一只可爱的小猫咪")) +``` + +**Batch Result Reader** + +```go +br := zhipu.NewBatchResultReader[zhipu.ChatCompletionResponse](r) + +for { + var res zhipu.BatchResult[zhipu.ChatCompletionResponse] + err := br.Read(&res) + if err != nil { + break + } +} +``` + +## Donation + +Executing unit tests will actually call the ChatGLM API and consume my quota. Please donate and thank you for your support! + + + +## Credits + +GUO YANKE, MIT License diff --git a/zhipu/README.zh.md b/zhipu/README.zh.md new file mode 100644 index 0000000..2563f05 --- /dev/null +++ b/zhipu/README.zh.md @@ -0,0 +1,278 @@ +# zhipu + +[![Go Reference](https://pkg.go.dev/badge/github.com/yankeguo/zhipu.svg)](https://pkg.go.dev/github.com/yankeguo/zhipu) +[![go](https://github.com/yankeguo/zhipu/actions/workflows/go.yml/badge.svg)](https://github.com/yankeguo/zhipu/actions/workflows/go.yml) + +Zhipu AI 平台第三方 Golang 客户端库 + +## 用法 + +### 安装库 + +```bash +go get -u github.com/yankeguo/zhipu +``` + +### 创建客户端 + +```go +// 默认使用环境变量 ZHIPUAI_API_KEY +client, err := zhipu.NewClient() +// 或者手动指定密钥 +client, err = zhipu.NewClient(zhipu.WithAPIKey("your api key")) +``` + +### 使用客户端 + +**ChatCompletion(大语言模型)** + +```go +service := client.ChatCompletion("glm-4-flash"). + AddMessage(zhipu.ChatCompletionMessage{ + Role: "user", + Content: "你好", + }) + +res, err := service.Do(context.Background()) + +if err != nil { + zhipu.GetAPIErrorCode(err) // get the API error code +} else { + println(res.Choices[0].Message.Content) +} +``` + +**ChatCompletion(流式调用大语言模型)** + +```go +service := client.ChatCompletion("glm-4-flash"). + AddMessage(zhipu.ChatCompletionMessage{ + Role: "user", + Content: "你好", + }).SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error { + println(chunk.Choices[0].Delta.Content) + return nil + }) + +res, err := service.Do(context.Background()) + +if err != nil { + zhipu.GetAPIErrorCode(err) // get the API error code +} else { + // this package will combine the stream chunks and build a final result mimicking the non-streaming API + println(res.Choices[0].Message.Content) +} +``` + +**ChatCompletion(流式调用大语言工具模型GLM-4-AllTools)** + +```go +// CodeInterpreter +s := client.ChatCompletion("GLM-4-AllTools") +s.AddMessage(zhipu.ChatCompletionMultiMessage{ + Role: "user", + Content: []zhipu.ChatCompletionMultiContent{ + { + Type: "text", + Text: "计算[5,10,20,700,99,310,978,100]的平均值和方差。", + }, + }, +}) +s.AddTool(zhipu.ChatCompletionToolCodeInterpreter{ + Sandbox: zhipu.Ptr(CodeInterpreterSandboxAuto), +}) +s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error { + for _, c := range chunk.Choices { + for _, tc := range c.Delta.ToolCalls { + if tc.Type == ToolTypeCodeInterpreter && tc.CodeInterpreter != nil { + if tc.CodeInterpreter.Input != "" { + // DO SOMETHING + } + if len(tc.CodeInterpreter.Outputs) > 0 { + // DO SOMETHING + } + } + } + } + return nil +}) + +// WebBrowser +// CAUTION: NOT 'WebSearch' +s := client.ChatCompletion("GLM-4-AllTools") +s.AddMessage(zhipu.ChatCompletionMultiMessage{ + Role: "user", + Content: []zhipu.ChatCompletionMultiContent{ + { + Type: "text", + Text: "搜索下本周深圳天气如何", + }, + }, +}) +s.AddTool(zhipu.ChatCompletionToolWebBrowser{}) +s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error { + for _, c := range chunk.Choices { + for _, tc := range c.Delta.ToolCalls { + if tc.Type == ToolTypeWebBrowser && tc.WebBrowser != nil { + if tc.WebBrowser.Input != "" { + // DO SOMETHING + } + if len(tc.WebBrowser.Outputs) > 0 { + // DO SOMETHING + } + } + } + } + return nil +}) +s.Do(context.Background()) + +// DrawingTool +s := client.ChatCompletion("GLM-4-AllTools") +s.AddMessage(zhipu.ChatCompletionMultiMessage{ + Role: "user", + Content: []zhipu.ChatCompletionMultiContent{ + { + Type: "text", + Text: "画一个正弦函数图像", + }, + }, +}) +s.AddTool(zhipu.ChatCompletionToolDrawingTool{}) +s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error { + for _, c := range chunk.Choices { + for _, tc := range c.Delta.ToolCalls { + if tc.Type == ToolTypeDrawingTool && tc.DrawingTool != nil { + if tc.DrawingTool.Input != "" { + // DO SOMETHING + } + if len(tc.DrawingTool.Outputs) > 0 { + // DO SOMETHING + } + } + } + } + return nil +}) +s.Do(context.Background()) +``` + +**Embedding** + +```go +service := client.Embedding("embedding-v2").SetInput("你好呀") +service.Do(context.Background()) +``` + +**ImageGeneration(图像生成)** + +```go +service := client.ImageGeneration("cogview-3").SetPrompt("一只可爱的小猫咪") +service.Do(context.Background()) +``` + +**VideoGeneration(视频生成)** + +```go +service := client.VideoGeneration("cogvideox").SetPrompt("一只可爱的小猫咪") +resp, err := service.Do(context.Background()) + +for { + result, err := client.AsyncResult(resp.ID).Do(context.Background()) + + if result.TaskStatus == zhipu.VideoGenerationTaskStatusSuccess { + _ = result.VideoResult[0].URL + _ = result.VideoResult[0].CoverImageURL + break + } + + if result.TaskStatus != zhipu.VideoGenerationTaskStatusProcessing { + break + } + + time.Sleep(5 * time.Second) +} +``` + +**UploadFile(上传文件用于取回)** + +```go +service := client.FileCreate(zhipu.FilePurposeRetrieval) +service.SetLocalFile(filepath.Join("testdata", "test-file.txt")) +service.SetKnowledgeID("your-knowledge-id") + +service.Do(context.Background()) +``` + +**UploadFile(上传文件用于微调)** + +```go +service := client.FileCreate(zhipu.FilePurposeFineTune) +service.SetLocalFile(filepath.Join("testdata", "test-file.jsonl")) +service.Do(context.Background()) +``` + +**BatchCreate(创建批量任务)** + +```go +service := client.BatchCreate(). + SetInputFileID("fileid"). + SetCompletionWindow(zhipu.BatchCompletionWindow24h). + SetEndpoint(BatchEndpointV4ChatCompletions) +service.Do(context.Background()) +``` + +**KnowledgeBase(知识库)** + +```go +client.KnowledgeCreate("") +client.KnowledgeEdit("") +``` + +**FineTune(微调)** + +```go +client.FineTuneCreate("") +``` + +### 批量任务辅助工具 + +**批量任务文件创建** + +```go +f, err := os.OpenFile("batch.jsonl", os.O_CREATE|os.O_WRONLY, 0644) + +bw := zhipu.NewBatchFileWriter(f) + +bw.Add("action_1", client.ChatCompletion("glm-4-flash"). + AddMessage(zhipu.ChatCompletionMessage{ + Role: "user", + Content: "你好", + })) +bw.Add("action_2", client.Embedding("embedding-v2").SetInput("你好呀")) +bw.Add("action_3", client.ImageGeneration("cogview-3").SetPrompt("一只可爱的小猫咪")) +``` + +**批量任务结果解析** + +```go +br := zhipu.NewBatchResultReader[zhipu.ChatCompletionResponse](r) + +for { + var res zhipu.BatchResult[zhipu.ChatCompletionResponse] + err := br.Read(&res) + if err != nil { + break + } +} +``` + +## 赞助 + +执行单元测试会真实调用GLM接口,消耗我充值的额度,开发不易,请微信扫码捐赠,感谢您的支持! + + + +## 许可证 + +GUO YANKE, MIT License diff --git a/zhipu/async_result.go b/zhipu/async_result.go new file mode 100644 index 0000000..5e51db1 --- /dev/null +++ b/zhipu/async_result.go @@ -0,0 +1,63 @@ +package zhipu + +import ( + "context" + + "github.com/go-resty/resty/v2" +) + +// AsyncResultService creates a new async result get service +type AsyncResultService struct { + client *Client + + id string +} + +// AsyncResultVideo is the video result of the AsyncResultService +type AsyncResultVideo struct { + URL string `json:"url"` + CoverImageURL string `json:"cover_image_url"` +} + +// AsyncResultResponse is the response of the AsyncResultService +type AsyncResultResponse struct { + Model string `json:"model"` + TaskStatus string `json:"task_status"` + RequestID string `json:"request_id"` + ID string `json:"id"` + VideoResult []AsyncResultVideo `json:"video_result"` +} + +// NewAsyncResultService creates a new async result get service +func NewAsyncResultService(client *Client) *AsyncResultService { + return &AsyncResultService{ + client: client, + } +} + +// SetID sets the id parameter +func (s *AsyncResultService) SetID(id string) *AsyncResultService { + s.id = id + return s +} + +func (s *AsyncResultService) Do(ctx context.Context) (res AsyncResultResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + if resp, err = s.client.request(ctx). + SetResult(&res). + SetError(&apiError). + Get("async-result/" + s.id); err != nil { + return + } + + if resp.IsError() { + err = apiError + return + } + + return +} diff --git a/zhipu/async_result_test.go b/zhipu/async_result_test.go new file mode 100644 index 0000000..f4a78b5 --- /dev/null +++ b/zhipu/async_result_test.go @@ -0,0 +1 @@ +package zhipu diff --git a/zhipu/batch.go b/zhipu/batch.go new file mode 100644 index 0000000..d661f56 --- /dev/null +++ b/zhipu/batch.go @@ -0,0 +1,258 @@ +package zhipu + +import ( + "context" + "encoding/json" + "strconv" + + "github.com/go-resty/resty/v2" +) + +const ( + BatchEndpointV4ChatCompletions = "/v4/chat/completions" + BatchEndpointV4ImagesGenerations = "/v4/images/generations" + BatchEndpointV4Embeddings = "/v4/embeddings" + BatchEndpointV4VideosGenerations = "/v4/videos/generations" + + BatchCompletionWindow24h = "24h" +) + +// BatchRequestCounts represents the counts of the batch requests. +type BatchRequestCounts struct { + Total int64 `json:"total"` + Completed int64 `json:"completed"` + Failed int64 `json:"failed"` +} + +// BatchItem represents a batch item. +type BatchItem struct { + ID string `json:"id"` + Object any `json:"object"` + Endpoint string `json:"endpoint"` + InputFileID string `json:"input_file_id"` + CompletionWindow string `json:"completion_window"` + Status string `json:"status"` + OutputFileID string `json:"output_file_id"` + ErrorFileID string `json:"error_file_id"` + CreatedAt int64 `json:"created_at"` + InProgressAt int64 `json:"in_progress_at"` + ExpiresAt int64 `json:"expires_at"` + FinalizingAt int64 `json:"finalizing_at"` + CompletedAt int64 `json:"completed_at"` + FailedAt int64 `json:"failed_at"` + ExpiredAt int64 `json:"expired_at"` + CancellingAt int64 `json:"cancelling_at"` + CancelledAt int64 `json:"cancelled_at"` + RequestCounts BatchRequestCounts `json:"request_counts"` + Metadata json.RawMessage `json:"metadata"` +} + +// BatchCreateService is a service to create a batch. +type BatchCreateService struct { + client *Client + + inputFileID string + endpoint string + completionWindow string + metadata any +} + +// NewBatchCreateService creates a new BatchCreateService. +func NewBatchCreateService(client *Client) *BatchCreateService { + return &BatchCreateService{client: client} +} + +// SetInputFileID sets the input file id for the batch. +func (s *BatchCreateService) SetInputFileID(inputFileID string) *BatchCreateService { + s.inputFileID = inputFileID + return s +} + +// SetEndpoint sets the endpoint for the batch. +func (s *BatchCreateService) SetEndpoint(endpoint string) *BatchCreateService { + s.endpoint = endpoint + return s +} + +// SetCompletionWindow sets the completion window for the batch. +func (s *BatchCreateService) SetCompletionWindow(window string) *BatchCreateService { + s.completionWindow = window + return s +} + +// SetMetadata sets the metadata for the batch. +func (s *BatchCreateService) SetMetadata(metadata any) *BatchCreateService { + s.metadata = metadata + return s +} + +// Do executes the batch create service. +func (s *BatchCreateService) Do(ctx context.Context) (res BatchItem, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + if resp, err = s.client.request(ctx). + SetBody(M{ + "input_file_id": s.inputFileID, + "endpoint": s.endpoint, + "completion_window": s.completionWindow, + "metadata": s.metadata, + }). + SetResult(&res). + SetError(&apiError). + Post("batches"); err != nil { + return + } + + if resp.IsError() { + err = apiError + } + + return +} + +// BatchGetService is a service to get a batch. +type BatchGetService struct { + client *Client + batchID string +} + +// BatchGetResponse represents the response of the batch get service. +type BatchGetResponse = BatchItem + +// NewBatchGetService creates a new BatchGetService. +func NewBatchGetService(client *Client) *BatchGetService { + return &BatchGetService{client: client} +} + +// SetBatchID sets the batch id for the batch get service. +func (s *BatchGetService) SetBatchID(batchID string) *BatchGetService { + s.batchID = batchID + return s +} + +// Do executes the batch get service. +func (s *BatchGetService) Do(ctx context.Context) (res BatchGetResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + if resp, err = s.client.request(ctx). + SetPathParam("batch_id", s.batchID). + SetResult(&res). + SetError(&apiError). + Get("batches/{batch_id}"); err != nil { + return + } + + if resp.IsError() { + err = apiError + } + + return +} + +// BatchCancelService is a service to cancel a batch. +type BatchCancelService struct { + client *Client + batchID string +} + +// NewBatchCancelService creates a new BatchCancelService. +func NewBatchCancelService(client *Client) *BatchCancelService { + return &BatchCancelService{client: client} +} + +// SetBatchID sets the batch id for the batch cancel service. +func (s *BatchCancelService) SetBatchID(batchID string) *BatchCancelService { + s.batchID = batchID + return s +} + +// Do executes the batch cancel service. +func (s *BatchCancelService) Do(ctx context.Context) (err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + if resp, err = s.client.request(ctx). + SetPathParam("batch_id", s.batchID). + SetBody(M{}). + SetError(&apiError). + Post("batches/{batch_id}/cancel"); err != nil { + return + } + + if resp.IsError() { + err = apiError + } + + return +} + +// BatchListService is a service to list batches. +type BatchListService struct { + client *Client + + after *string + limit *int +} + +// BatchListResponse represents the response of the batch list service. +type BatchListResponse struct { + Object string `json:"object"` + Data []BatchItem `json:"data"` + FirstID string `json:"first_id"` + LastID string `json:"last_id"` + HasMore bool `json:"has_more"` +} + +// NewBatchListService creates a new BatchListService. +func NewBatchListService(client *Client) *BatchListService { + return &BatchListService{client: client} +} + +// SetAfter sets the after cursor for the batch list service. +func (s *BatchListService) SetAfter(after string) *BatchListService { + s.after = &after + return s +} + +// SetLimit sets the limit for the batch list service. +func (s *BatchListService) SetLimit(limit int) *BatchListService { + s.limit = &limit + return s +} + +// Do executes the batch list service. +func (s *BatchListService) Do(ctx context.Context) (res BatchListResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + req := s.client.request(ctx) + if s.after != nil { + req.SetQueryParam("after", *s.after) + } + if s.limit != nil { + req.SetQueryParam("limit", strconv.Itoa(*s.limit)) + } + + if resp, err = req. + SetResult(&res). + SetError(&apiError). + Get("batches"); err != nil { + return + } + + if resp.IsError() { + err = apiError + } + + return +} diff --git a/zhipu/batch_support.go b/zhipu/batch_support.go new file mode 100644 index 0000000..9427ef4 --- /dev/null +++ b/zhipu/batch_support.go @@ -0,0 +1,63 @@ +package zhipu + +import ( + "encoding/json" + "io" +) + +// BatchSupport is the interface for services with batch support. +type BatchSupport interface { + BatchMethod() string + BatchURL() string + BatchBody() any +} + +// BatchFileWriter is a writer for batch files. +type BatchFileWriter struct { + w io.Writer + je *json.Encoder +} + +// NewBatchFileWriter creates a new BatchFileWriter. +func NewBatchFileWriter(w io.Writer) *BatchFileWriter { + return &BatchFileWriter{w: w, je: json.NewEncoder(w)} +} + +// Write writes a batch file. +func (b *BatchFileWriter) Write(customID string, s BatchSupport) error { + return b.je.Encode(M{ + "custom_id": customID, + "method": s.BatchMethod(), + "url": s.BatchURL(), + "body": s.BatchBody(), + }) +} + +// BatchResultResponse is the response of a batch result. +type BatchResultResponse[T any] struct { + StatusCode int `json:"status_code"` + Body T `json:"body"` +} + +// BatchResult is the result of a batch. +type BatchResult[T any] struct { + ID string `json:"id"` + CustomID string `json:"custom_id"` + Response BatchResultResponse[T] `json:"response"` +} + +// BatchResultReader reads batch results. +type BatchResultReader[T any] struct { + r io.Reader + jd *json.Decoder +} + +// NewBatchResultReader creates a new BatchResultReader. +func NewBatchResultReader[T any](r io.Reader) *BatchResultReader[T] { + return &BatchResultReader[T]{r: r, jd: json.NewDecoder(r)} +} + +// Read reads a batch result. +func (r *BatchResultReader[T]) Read(out *BatchResult[T]) error { + return r.jd.Decode(out) +} diff --git a/zhipu/batch_support_test.go b/zhipu/batch_support_test.go new file mode 100644 index 0000000..59a19c6 --- /dev/null +++ b/zhipu/batch_support_test.go @@ -0,0 +1,73 @@ +package zhipu + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBatchFileWriter(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + buf := &bytes.Buffer{} + + w := NewBatchFileWriter(buf) + err = w.Write("batch-1", client.ChatCompletion("a").AddMessage(ChatCompletionMessage{ + Role: "user", Content: "hello", + })) + require.NoError(t, err) + err = w.Write("batch-2", client.Embedding("c").SetInput("whoa")) + require.NoError(t, err) + err = w.Write("batch-3", client.ImageGeneration("d").SetPrompt("whoa")) + require.NoError(t, err) + + require.Equal(t, `{"body":{"messages":[{"role":"user","content":"hello"}],"model":"a"},"custom_id":"batch-1","method":"POST","url":"/v4/chat/completions"} +{"body":{"input":"whoa","model":"c"},"custom_id":"batch-2","method":"POST","url":"/v4/embeddings"} +{"body":{"model":"d","prompt":"whoa"},"custom_id":"batch-3","method":"POST","url":"/v4/images/generations"} +`, buf.String()) +} + +func TestBatchResultReader(t *testing.T) { + result := ` + {"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":26,"prompt_tokens":89,"total_tokens":115},"model":"glm-4","id":"8668357533850320547","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"订单处理慢\"\n}\n'''"}}],"request_id":"615-request-1"}},"custom_id":"request-1","id":"batch_1791490810192076800"} +{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":22,"prompt_tokens":94,"total_tokens":116},"model":"glm-4","id":"8668368425887509080","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"产品缺陷\"\n}\n'''"}}],"request_id":"616-request-2"}},"custom_id":"request-2","id":"batch_1791490810192076800"} +{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":25,"prompt_tokens":86,"total_tokens":111},"model":"glm-4","id":"8668355815863214980","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"正面\",\n \"特定问题标注\": \"性价比\"\n}\n'''"}}],"request_id":"617-request-3"}},"custom_id":"request-3","id":"batch_1791490810192076800"} +{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":28,"prompt_tokens":89,"total_tokens":117},"model":"glm-4","id":"8668355815863214981","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"说明文档不清晰\"\n}\n'''"}}],"request_id":"618-request-4"}},"custom_id":"request-4","id":"batch_1791490810192076800"} + +{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":26,"prompt_tokens":88,"total_tokens":114},"model":"glm-4","id":"8668357533850320546","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"中性\",\n \"特定问题标注\": \"价格问题\"\n}\n'''"}}],"request_id":"619-request-5"}},"custom_id":"request-5","id":"batch_1791490810192076800"} + +{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":26,"prompt_tokens":90,"total_tokens":116},"model":"glm-4","id":"8668356159460662846","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"配送延迟\"\n}\n'''"}}],"request_id":"620-request-6"}},"custom_id":"request-6","id":"batch_1791490810192076800"} + + +{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":27,"prompt_tokens":88,"total_tokens":115},"model":"glm-4","id":"8668357671289274638","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"产品描述不符\"\n}\n'''"}}],"request_id":"621-request-7"}},"custom_id":"request-7","id":"batch_1791490810192076800"} +{"response":{"status_code":200,"body":{"created":1715959702,"usage":{"completion_tokens":26,"prompt_tokens":87,"total_tokens":113},"model":"glm-4","id":"8668355644064514872","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"正面\",\n \"特定问题标注\": \"客服态度\"\n}\n'''"}}],"request_id":"622-request-8"}},"custom_id":"request-8","id":"batch_1791490810192076800"} + {"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":29,"prompt_tokens":90,"total_tokens":119},"model":"glm-4","id":"8668357671289274639","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"包装问题, 产品损坏\"\n}\n'''"}}],"request_id":"623-request-9"}},"custom_id":"request-9","id":"batch_1791490810192076800"} +{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":27,"prompt_tokens":87,"total_tokens":114},"model":"glm-4","id":"8668355644064514871","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"正面\",\n \"特定问题标注\": \"产品描述不符\"\n}\n'''"}}],"request_id":"624-request-10"}},"custom_id":"request-10","id":"batch_1791490810192076800"} +` + + brr := NewBatchResultReader[ChatCompletionResponse](bytes.NewReader([]byte(result))) + + var count int + + for { + var res BatchResult[ChatCompletionResponse] + + err := brr.Read(&res) + + if err != nil { + if err == io.EOF { + err = nil + } + require.Equal(t, 10, count) + require.NoError(t, err) + break + } + + require.Equal(t, 200, res.Response.StatusCode) + + count++ + } +} diff --git a/zhipu/batch_test.go b/zhipu/batch_test.go new file mode 100644 index 0000000..8ce5aa5 --- /dev/null +++ b/zhipu/batch_test.go @@ -0,0 +1,59 @@ +package zhipu + +import ( + "bytes" + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBatchServiceAll(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + buf := &bytes.Buffer{} + + bfw := NewBatchFileWriter(buf) + err = bfw.Write("batch_1", client.ChatCompletion("glm-4-flash").AddMessage(ChatCompletionMessage{ + Role: RoleUser, Content: "你好呀", + })) + require.NoError(t, err) + err = bfw.Write("batch_2", client.ChatCompletion("glm-4-flash").AddMessage(ChatCompletionMessage{ + Role: RoleUser, Content: "你叫什么名字", + })) + require.NoError(t, err) + + res, err := client.FileCreate(FilePurposeBatch).SetFile(bytes.NewReader(buf.Bytes()), "batch.jsonl").Do(context.Background()) + require.NoError(t, err) + + fileID := res.FileCreateFineTuneResponse.ID + require.NotEmpty(t, fileID) + + res1, err := client.BatchCreate(). + SetInputFileID(fileID). + SetCompletionWindow(BatchCompletionWindow24h). + SetEndpoint(BatchEndpointV4ChatCompletions).Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res1.ID) + + res2, err := client.BatchGet(res1.ID).Do(context.Background()) + require.NoError(t, err) + require.Equal(t, res2.ID, res1.ID) + + res3, err := client.BatchList().Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res3.Data) + + err = client.BatchCancel(res1.ID).Do(context.Background()) + require.NoError(t, err) +} + +func TestBatchListService(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + res, err := client.BatchList().Do(context.Background()) + require.NoError(t, err) + t.Log(res) +} diff --git a/zhipu/chat_completion.go b/zhipu/chat_completion.go new file mode 100644 index 0000000..b00db85 --- /dev/null +++ b/zhipu/chat_completion.go @@ -0,0 +1,577 @@ +package zhipu + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "io" + + "github.com/go-resty/resty/v2" +) + +const ( + RoleSystem = "system" + RoleUser = "user" + RoleAssistant = "assistant" + RoleTool = "tool" + + ToolChoiceAuto = "auto" + + FinishReasonStop = "stop" + FinishReasonStopSequence = "stop_sequence" + FinishReasonToolCalls = "tool_calls" + FinishReasonLength = "length" + FinishReasonSensitive = "sensitive" + FinishReasonNetworkError = "network_error" + + ToolTypeFunction = "function" + ToolTypeWebSearch = "web_search" + ToolTypeRetrieval = "retrieval" + + MultiContentTypeText = "text" + MultiContentTypeImageURL = "image_url" + MultiContentTypeVideoURL = "video_url" + + // New in GLM-4-AllTools + ToolTypeCodeInterpreter = "code_interpreter" + ToolTypeDrawingTool = "drawing_tool" + ToolTypeWebBrowser = "web_browser" + + CodeInterpreterSandboxNone = "none" + CodeInterpreterSandboxAuto = "auto" + + ChatCompletionStatusFailed = "failed" + ChatCompletionStatusCompleted = "completed" + ChatCompletionStatusRequiresAction = "requires_action" +) + +// ChatCompletionTool is the interface for chat completion tool +type ChatCompletionTool interface { + isChatCompletionTool() +} + +// ChatCompletionToolFunction is the function for chat completion +type ChatCompletionToolFunction struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters any `json:"parameters"` +} + +func (ChatCompletionToolFunction) isChatCompletionTool() {} + +// ChatCompletionToolRetrieval is the retrieval for chat completion +type ChatCompletionToolRetrieval struct { + KnowledgeID string `json:"knowledge_id"` + PromptTemplate string `json:"prompt_template,omitempty"` +} + +func (ChatCompletionToolRetrieval) isChatCompletionTool() {} + +// ChatCompletionToolWebSearch is the web search for chat completion +type ChatCompletionToolWebSearch struct { + Enable *bool `json:"enable,omitempty"` + SearchQuery string `json:"search_query,omitempty"` + SearchResult bool `json:"search_result,omitempty"` +} + +func (ChatCompletionToolWebSearch) isChatCompletionTool() {} + +// ChatCompletionToolCodeInterpreter is the code interpreter for chat completion +// only in GLM-4-AllTools +type ChatCompletionToolCodeInterpreter struct { + Sandbox *string `json:"sandbox,omitempty"` +} + +func (ChatCompletionToolCodeInterpreter) isChatCompletionTool() {} + +// ChatCompletionToolDrawingTool is the drawing tool for chat completion +// only in GLM-4-AllTools +type ChatCompletionToolDrawingTool struct { + // no fields +} + +func (ChatCompletionToolDrawingTool) isChatCompletionTool() {} + +// ChatCompletionToolWebBrowser is the web browser for chat completion +type ChatCompletionToolWebBrowser struct { + // no fields +} + +func (ChatCompletionToolWebBrowser) isChatCompletionTool() {} + +// ChatCompletionUsage is the usage for chat completion +type ChatCompletionUsage struct { + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + TotalTokens int64 `json:"total_tokens"` +} + +// ChatCompletionWebSearch is the web search result for chat completion +type ChatCompletionWebSearch struct { + Icon string `json:"icon"` + Title string `json:"title"` + Link string `json:"link"` + Media string `json:"media"` + Content string `json:"content"` +} + +// ChatCompletionToolCallFunction is the function for chat completion tool call +type ChatCompletionToolCallFunction struct { + Name string `json:"name"` + Arguments json.RawMessage `json:"arguments"` +} + +// ChatCompletionToolCallCodeInterpreterOutput is the output for chat completion tool call code interpreter +type ChatCompletionToolCallCodeInterpreterOutput struct { + Type string `json:"type"` + Logs string `json:"logs"` + File string `json:"file"` +} + +// ChatCompletionToolCallCodeInterpreter is the code interpreter for chat completion tool call +type ChatCompletionToolCallCodeInterpreter struct { + Input string `json:"input"` + Outputs []ChatCompletionToolCallCodeInterpreterOutput `json:"outputs"` +} + +// ChatCompletionToolCallDrawingToolOutput is the output for chat completion tool call drawing tool +type ChatCompletionToolCallDrawingToolOutput struct { + Image string `json:"image"` +} + +// ChatCompletionToolCallDrawingTool is the drawing tool for chat completion tool call +type ChatCompletionToolCallDrawingTool struct { + Input string `json:"input"` + Outputs []ChatCompletionToolCallDrawingToolOutput `json:"outputs"` +} + +// ChatCompletionToolCallWebBrowserOutput is the output for chat completion tool call web browser +type ChatCompletionToolCallWebBrowserOutput struct { + Title string `json:"title"` + Link string `json:"link"` + Content string `json:"content"` +} + +// ChatCompletionToolCallWebBrowser is the web browser for chat completion tool call +type ChatCompletionToolCallWebBrowser struct { + Input string `json:"input"` + Outputs []ChatCompletionToolCallWebBrowserOutput `json:"outputs"` +} + +// ChatCompletionToolCall is the tool call for chat completion +type ChatCompletionToolCall struct { + ID string `json:"id"` + Type string `json:"type"` + Function *ChatCompletionToolCallFunction `json:"function,omitempty"` + CodeInterpreter *ChatCompletionToolCallCodeInterpreter `json:"code_interpreter,omitempty"` + DrawingTool *ChatCompletionToolCallDrawingTool `json:"drawing_tool,omitempty"` + WebBrowser *ChatCompletionToolCallWebBrowser `json:"web_browser,omitempty"` +} + +type ChatCompletionMessageType interface { + isChatCompletionMessageType() +} + +// ChatCompletionMessage is the message for chat completion +type ChatCompletionMessage struct { + Role string `json:"role"` + Content string `json:"content,omitempty"` + ToolCalls []ChatCompletionToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` +} + +func (ChatCompletionMessage) isChatCompletionMessageType() {} + +type ChatCompletionMultiContent struct { + Type string `json:"type"` + Text string `json:"text"` + ImageURL *URLItem `json:"image_url,omitempty"` + VideoURL *URLItem `json:"video_url,omitempty"` +} + +// ChatCompletionMultiMessage is the multi message for chat completion +type ChatCompletionMultiMessage struct { + Role string `json:"role"` + Content []ChatCompletionMultiContent `json:"content"` +} + +func (ChatCompletionMultiMessage) isChatCompletionMessageType() {} + +// ChatCompletionMeta is the meta for chat completion +type ChatCompletionMeta struct { + UserInfo string `json:"user_info"` + BotInfo string `json:"bot_info"` + UserName string `json:"user_name"` + BotName string `json:"bot_name"` +} + +// ChatCompletionChoice is the choice for chat completion +type ChatCompletionChoice struct { + Index int `json:"index"` + FinishReason string `json:"finish_reason"` + Delta ChatCompletionMessage `json:"delta"` // stream mode + Message ChatCompletionMessage `json:"message"` // non-stream mode +} + +// ChatCompletionResponse is the response for chat completion +type ChatCompletionResponse struct { + ID string `json:"id"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionChoice `json:"choices"` + Usage ChatCompletionUsage `json:"usage"` + WebSearch []ChatCompletionWebSearch `json:"web_search"` + // Status is the status of the chat completion, only in GLM-4-AllTools + Status string `json:"status"` +} + +// ChatCompletionStreamHandler is the handler for chat completion stream +type ChatCompletionStreamHandler func(chunk ChatCompletionResponse) error + +var ( + chatCompletionStreamPrefix = []byte("data:") + chatCompletionStreamDone = []byte("[DONE]") +) + +// chatCompletionReduceResponse reduce the chunk to the response +func chatCompletionReduceResponse(out *ChatCompletionResponse, chunk ChatCompletionResponse) { + if len(out.Choices) == 0 { + out.Choices = append(out.Choices, ChatCompletionChoice{}) + } + + // basic + out.ID = chunk.ID + out.Created = chunk.Created + out.Model = chunk.Model + + // choices + if len(chunk.Choices) != 0 { + oc := &out.Choices[0] + cc := chunk.Choices[0] + + oc.Index = cc.Index + if cc.Delta.Role != "" { + oc.Message.Role = cc.Delta.Role + } + oc.Message.Content += cc.Delta.Content + oc.Message.ToolCalls = append(oc.Message.ToolCalls, cc.Delta.ToolCalls...) + if cc.FinishReason != "" { + oc.FinishReason = cc.FinishReason + } + } + + // usage + if chunk.Usage.CompletionTokens != 0 { + out.Usage.CompletionTokens = chunk.Usage.CompletionTokens + } + if chunk.Usage.PromptTokens != 0 { + out.Usage.PromptTokens = chunk.Usage.PromptTokens + } + if chunk.Usage.TotalTokens != 0 { + out.Usage.TotalTokens = chunk.Usage.TotalTokens + } + + // web search + out.WebSearch = append(out.WebSearch, chunk.WebSearch...) +} + +// chatCompletionDecodeStream decode the sse stream of chat completion +func chatCompletionDecodeStream(r io.Reader, fn func(chunk ChatCompletionResponse) error) (err error) { + br := bufio.NewReader(r) + + for { + var line []byte + + if line, err = br.ReadBytes('\n'); err != nil { + if errors.Is(err, io.EOF) { + err = nil + } + break + } + + line = bytes.TrimSpace(line) + + if len(line) == 0 { + continue + } + + if !bytes.HasPrefix(line, chatCompletionStreamPrefix) { + continue + } + + data := bytes.TrimSpace(line[len(chatCompletionStreamPrefix):]) + + if bytes.Equal(data, chatCompletionStreamDone) { + break + } + + if len(data) == 0 { + continue + } + + var chunk ChatCompletionResponse + if err = json.Unmarshal(data, &chunk); err != nil { + return + } + if err = fn(chunk); err != nil { + return + } + } + + return +} + +// ChatCompletionStreamService is the service for chat completion stream +type ChatCompletionService struct { + client *Client + + model string + requestID *string + doSample *bool + temperature *float64 + topP *float64 + maxTokens *int + stop []string + toolChoice *string + userID *string + meta *ChatCompletionMeta + + messages []any + tools []any + + streamHandler ChatCompletionStreamHandler +} + +var ( + _ BatchSupport = &ChatCompletionService{} +) + +// NewChatCompletionService creates a new ChatCompletionService. +func NewChatCompletionService(client *Client) *ChatCompletionService { + return &ChatCompletionService{ + client: client, + } +} + +func (s *ChatCompletionService) BatchMethod() string { + return "POST" +} + +func (s *ChatCompletionService) BatchURL() string { + return BatchEndpointV4ChatCompletions +} + +func (s *ChatCompletionService) BatchBody() any { + return s.buildBody() +} + +// SetModel set the model of the chat completion +func (s *ChatCompletionService) SetModel(model string) *ChatCompletionService { + s.model = model + return s +} + +// SetMeta set the meta of the chat completion, optional +func (s *ChatCompletionService) SetMeta(meta ChatCompletionMeta) *ChatCompletionService { + s.meta = &meta + return s +} + +// SetRequestID set the request id of the chat completion, optional +func (s *ChatCompletionService) SetRequestID(requestID string) *ChatCompletionService { + s.requestID = &requestID + return s +} + +// SetTemperature set the temperature of the chat completion, optional +func (s *ChatCompletionService) SetDoSample(doSample bool) *ChatCompletionService { + s.doSample = &doSample + return s +} + +// SetTemperature set the temperature of the chat completion, optional +func (s *ChatCompletionService) SetTemperature(temperature float64) *ChatCompletionService { + s.temperature = &temperature + return s +} + +// SetTopP set the top p of the chat completion, optional +func (s *ChatCompletionService) SetTopP(topP float64) *ChatCompletionService { + s.topP = &topP + return s +} + +// SetMaxTokens set the max tokens of the chat completion, optional +func (s *ChatCompletionService) SetMaxTokens(maxTokens int) *ChatCompletionService { + s.maxTokens = &maxTokens + return s +} + +// SetStop set the stop of the chat completion, optional +func (s *ChatCompletionService) SetStop(stop ...string) *ChatCompletionService { + s.stop = stop + return s +} + +// SetToolChoice set the tool choice of the chat completion, optional +func (s *ChatCompletionService) SetToolChoice(toolChoice string) *ChatCompletionService { + s.toolChoice = &toolChoice + return s +} + +// SetUserID set the user id of the chat completion, optional +func (s *ChatCompletionService) SetUserID(userID string) *ChatCompletionService { + s.userID = &userID + return s +} + +// SetStreamHandler set the stream handler of the chat completion, optional +// this will enable the stream mode +func (s *ChatCompletionService) SetStreamHandler(handler ChatCompletionStreamHandler) *ChatCompletionService { + s.streamHandler = handler + return s +} + +// AddMessage add the message to the chat completion +func (s *ChatCompletionService) AddMessage(messages ...ChatCompletionMessageType) *ChatCompletionService { + for _, message := range messages { + s.messages = append(s.messages, message) + } + return s +} + +// AddFunction add the function to the chat completion +func (s *ChatCompletionService) AddTool(tools ...ChatCompletionTool) *ChatCompletionService { + for _, tool := range tools { + switch tool := tool.(type) { + case ChatCompletionToolFunction: + s.tools = append(s.tools, map[string]any{ + "type": ToolTypeFunction, + ToolTypeFunction: tool, + }) + case ChatCompletionToolRetrieval: + s.tools = append(s.tools, map[string]any{ + "type": ToolTypeRetrieval, + ToolTypeRetrieval: tool, + }) + case ChatCompletionToolWebSearch: + s.tools = append(s.tools, map[string]any{ + "type": ToolTypeWebSearch, + ToolTypeWebSearch: tool, + }) + case ChatCompletionToolCodeInterpreter: + s.tools = append(s.tools, map[string]any{ + "type": ToolTypeCodeInterpreter, + ToolTypeCodeInterpreter: tool, + }) + case ChatCompletionToolDrawingTool: + s.tools = append(s.tools, map[string]any{ + "type": ToolTypeDrawingTool, + ToolTypeDrawingTool: tool, + }) + case ChatCompletionToolWebBrowser: + s.tools = append(s.tools, map[string]any{ + "type": ToolTypeWebBrowser, + ToolTypeWebBrowser: tool, + }) + } + } + return s +} + +func (s *ChatCompletionService) buildBody() M { + body := map[string]any{ + "model": s.model, + "messages": s.messages, + } + if s.requestID != nil { + body["request_id"] = *s.requestID + } + if s.doSample != nil { + body["do_sample"] = *s.doSample + } + if s.temperature != nil { + body["temperature"] = *s.temperature + } + if s.topP != nil { + body["top_p"] = *s.topP + } + if s.maxTokens != nil { + body["max_tokens"] = *s.maxTokens + } + if len(s.stop) != 0 { + body["stop"] = s.stop + } + if len(s.tools) != 0 { + body["tools"] = s.tools + } + if s.toolChoice != nil { + body["tool_choice"] = *s.toolChoice + } + if s.userID != nil { + body["user_id"] = *s.userID + } + if s.meta != nil { + body["meta"] = s.meta + } + return body +} + +// Do send the request of the chat completion and return the response +func (s *ChatCompletionService) Do(ctx context.Context) (res ChatCompletionResponse, err error) { + body := s.buildBody() + + streamHandler := s.streamHandler + + if streamHandler == nil { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + //fmt.Println(u.BMagenta(u.JsonP(body)), 111) + if resp, err = s.client.request(ctx).SetBody(body).SetResult(&res).SetError(&apiError).Post("chat/completions"); err != nil { + //fmt.Println(u.BRed(err.Error()), 2221) + return + } + if resp.IsError() { + err = apiError + //fmt.Println(u.BRed(err.Error()), 2222) + return + } + //fmt.Println(u.BGreen(u.JsonP(resp.Result())), resp.Status(), resp.Status(), 333) + return + } + + // stream mode + + body["stream"] = true + + var resp *resty.Response + + if resp, err = s.client.request(ctx).SetBody(body).SetDoNotParseResponse(true).Post("chat/completions"); err != nil { + return + } + defer resp.RawBody().Close() + + if resp.IsError() { + err = errors.New(resp.Status()) + return + } + + var choice ChatCompletionChoice + + if err = chatCompletionDecodeStream(resp.RawBody(), func(chunk ChatCompletionResponse) error { + // reduce the chunk to the response + chatCompletionReduceResponse(&res, chunk) + // invoke the stream handler + return streamHandler(chunk) + }); err != nil { + return + } + + res.Choices = append(res.Choices, choice) + + return +} diff --git a/zhipu/chat_completion_test.go b/zhipu/chat_completion_test.go new file mode 100644 index 0000000..8839850 --- /dev/null +++ b/zhipu/chat_completion_test.go @@ -0,0 +1,251 @@ +package zhipu + +import ( + "context" + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestChatCompletionService(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.ChatCompletion("glm-4-flash") + s.AddMessage(ChatCompletionMessage{ + Role: RoleUser, + Content: "你好呀", + }) + res, err := s.Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res.Choices) + choice := res.Choices[0] + require.Equal(t, FinishReasonStop, choice.FinishReason) + require.NotEmpty(t, choice.Message.Content) +} + +func TestChatCompletionServiceCharGLM(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.ChatCompletion("charglm-3") + s.SetMeta( + ChatCompletionMeta{ + UserName: "啵酱", + UserInfo: "啵酱是小少爷", + BotName: "塞巴斯酱", + BotInfo: "塞巴斯酱是一个冷酷的恶魔管家", + }, + ).AddMessage(ChatCompletionMessage{ + Role: RoleUser, + Content: "早上好", + }) + res, err := s.Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res.Choices) + choice := res.Choices[0] + require.Contains(t, []string{FinishReasonStop, FinishReasonStopSequence}, choice.FinishReason) + require.NotEmpty(t, choice.Message.Content) +} + +func TestChatCompletionServiceAllToolsCodeInterpreter(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.ChatCompletion("GLM-4-AllTools") + s.AddMessage(ChatCompletionMultiMessage{ + Role: "user", + Content: []ChatCompletionMultiContent{ + { + Type: "text", + Text: "计算[5,10,20,700,99,310,978,100]的平均值和方差。", + }, + }, + }) + s.AddTool(ChatCompletionToolCodeInterpreter{ + Sandbox: Ptr(CodeInterpreterSandboxAuto), + }) + + foundInterpreterInput := false + foundInterpreterOutput := false + + s.SetStreamHandler(func(chunk ChatCompletionResponse) error { + for _, c := range chunk.Choices { + for _, tc := range c.Delta.ToolCalls { + if tc.Type == ToolTypeCodeInterpreter && tc.CodeInterpreter != nil { + if tc.CodeInterpreter.Input != "" { + foundInterpreterInput = true + } + if len(tc.CodeInterpreter.Outputs) > 0 { + foundInterpreterOutput = true + } + } + } + } + buf, _ := json.MarshalIndent(chunk, "", " ") + t.Log(string(buf)) + return nil + }) + + res, err := s.Do(context.Background()) + require.True(t, foundInterpreterInput) + require.True(t, foundInterpreterOutput) + require.NotNil(t, res) + require.NoError(t, err) +} + +func TestChatCompletionServiceAllToolsDrawingTool(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.ChatCompletion("GLM-4-AllTools") + s.AddMessage(ChatCompletionMultiMessage{ + Role: "user", + Content: []ChatCompletionMultiContent{ + { + Type: "text", + Text: "画一个正弦函数图像", + }, + }, + }) + s.AddTool(ChatCompletionToolDrawingTool{}) + + foundInput := false + foundOutput := false + outputImage := "" + + s.SetStreamHandler(func(chunk ChatCompletionResponse) error { + for _, c := range chunk.Choices { + for _, tc := range c.Delta.ToolCalls { + if tc.Type == ToolTypeDrawingTool && tc.DrawingTool != nil { + if tc.DrawingTool.Input != "" { + foundInput = true + } + if len(tc.DrawingTool.Outputs) > 0 { + foundOutput = true + } + for _, output := range tc.DrawingTool.Outputs { + if output.Image != "" { + outputImage = output.Image + } + } + } + } + } + buf, _ := json.MarshalIndent(chunk, "", " ") + t.Log(string(buf)) + return nil + }) + + res, err := s.Do(context.Background()) + require.True(t, foundInput) + require.True(t, foundOutput) + require.NotEmpty(t, outputImage) + t.Log(outputImage) + require.NotNil(t, res) + require.NoError(t, err) +} + +func TestChatCompletionServiceAllToolsWebBrowser(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.ChatCompletion("GLM-4-AllTools") + s.AddMessage(ChatCompletionMultiMessage{ + Role: "user", + Content: []ChatCompletionMultiContent{ + { + Type: "text", + Text: "搜索下本周深圳天气如何", + }, + }, + }) + s.AddTool(ChatCompletionToolWebBrowser{}) + + foundInput := false + foundOutput := false + outputContent := "" + + s.SetStreamHandler(func(chunk ChatCompletionResponse) error { + for _, c := range chunk.Choices { + for _, tc := range c.Delta.ToolCalls { + if tc.Type == ToolTypeWebBrowser && tc.WebBrowser != nil { + if tc.WebBrowser.Input != "" { + foundInput = true + } + if len(tc.WebBrowser.Outputs) > 0 { + foundOutput = true + } + for _, output := range tc.WebBrowser.Outputs { + if output.Content != "" { + outputContent = output.Content + } + } + } + } + } + buf, _ := json.MarshalIndent(chunk, "", " ") + t.Log(string(buf)) + return nil + }) + + res, err := s.Do(context.Background()) + require.True(t, foundInput) + require.True(t, foundOutput) + require.NotEmpty(t, outputContent) + t.Log(outputContent) + require.NotNil(t, res) + require.NoError(t, err) +} + +func TestChatCompletionServiceStream(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + var content string + + s := client.ChatCompletion("glm-4-flash").AddMessage(ChatCompletionMessage{ + Role: RoleUser, + Content: "你好呀", + }).SetStreamHandler(func(chunk ChatCompletionResponse) error { + content += chunk.Choices[0].Delta.Content + return nil + }) + res, err := s.Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res.Choices) + choice := res.Choices[0] + require.Equal(t, FinishReasonStop, choice.FinishReason) + require.NotEmpty(t, choice.Message.Content) + require.Equal(t, content, choice.Message.Content) +} + +func TestChatCompletionServiceVision(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.ChatCompletion("glm-4v") + s.AddMessage(ChatCompletionMultiMessage{ + Role: RoleUser, + Content: []ChatCompletionMultiContent{ + { + Type: MultiContentTypeText, + Text: "图里有什么", + }, + { + Type: MultiContentTypeImageURL, + ImageURL: &URLItem{ + URL: "https://img1.baidu.com/it/u=1369931113,3388870256&fm=253&app=138&size=w931&n=0&f=JPEG&fmt=auto?sec=1703696400&t=f3028c7a1dca43a080aeb8239f09cc2f", + }, + }, + }, + }) + res, err := s.Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res.Choices) + require.NotZero(t, res.Usage.CompletionTokens) + choice := res.Choices[0] + require.Equal(t, FinishReasonStop, choice.FinishReason) + require.NotEmpty(t, choice.Message.Content) +} diff --git a/zhipu/client.go b/zhipu/client.go new file mode 100644 index 0000000..5f2aa9b --- /dev/null +++ b/zhipu/client.go @@ -0,0 +1,291 @@ +package zhipu + +import ( + "context" + "errors" + "net/http" + "os" + "strconv" + "strings" + "time" + + "github.com/go-resty/resty/v2" + "github.com/golang-jwt/jwt/v5" +) + +const ( + envAPIKey = "ZHIPUAI_API_KEY" + envBaseURL = "ZHIPUAI_BASE_URL" + envDebug = "ZHIPUAI_DEBUG" + + defaultBaseURL = "https://open.bigmodel.cn/api/paas/v4" +) + +var ( + // ErrAPIKeyMissing is the error when the api key is missing + ErrAPIKeyMissing = errors.New("zhipu: api key is missing") + // ErrAPIKeyMalformed is the error when the api key is malformed + ErrAPIKeyMalformed = errors.New("zhipu: api key is malformed") +) + +type clientOptions struct { + baseURL string + apiKey string + client *http.Client + debug *bool +} + +// ClientOption is a function that configures the client +type ClientOption func(opts *clientOptions) + +// WithAPIKey set the api key of the client +func WithAPIKey(apiKey string) ClientOption { + return func(opts *clientOptions) { + opts.apiKey = apiKey + } +} + +// WithBaseURL set the base url of the client +func WithBaseURL(baseURL string) ClientOption { + return func(opts *clientOptions) { + opts.baseURL = baseURL + } +} + +// WithHTTPClient set the http client of the client +func WithHTTPClient(client *http.Client) ClientOption { + return func(opts *clientOptions) { + opts.client = client + } +} + +// WithDebug set the debug mode of the client +func WithDebug(debug bool) ClientOption { + return func(opts *clientOptions) { + opts.debug = new(bool) + *opts.debug = debug + } +} + +// Client is the client for zhipu ai platform +type Client struct { + client *resty.Client + debug bool + keyID string + keySecret []byte +} + +func (c *Client) createJWT() string { + timestamp := time.Now().UnixMilli() + exp := timestamp + time.Hour.Milliseconds()*24*7 + + t := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "api_key": c.keyID, + "timestamp": timestamp, + "exp": exp, + }) + t.Header = map[string]interface{}{ + "alg": "HS256", + "sign_type": "SIGN", + } + + token, err := t.SignedString(c.keySecret) + if err != nil { + panic(err) + } + return token +} + +// request creates a new resty request with the jwt token and context +func (c *Client) request(ctx context.Context) *resty.Request { + return c.client.R().SetContext(ctx).SetHeader("Authorization", c.createJWT()) +} + +// NewClient creates a new client +// It will read the api key from the environment variable ZHIPUAI_API_KEY +// It will read the base url from the environment variable ZHIPUAI_BASE_URL +func NewClient(optFns ...ClientOption) (client *Client, err error) { + var opts clientOptions + for _, optFn := range optFns { + optFn(&opts) + } + // base url + if opts.baseURL == "" { + opts.baseURL = strings.TrimSpace(os.Getenv(envBaseURL)) + } + if opts.baseURL == "" { + opts.baseURL = defaultBaseURL + } + // api key + if opts.apiKey == "" { + opts.apiKey = strings.TrimSpace(os.Getenv(envAPIKey)) + } + if opts.apiKey == "" { + err = ErrAPIKeyMissing + return + } + // debug + if opts.debug == nil { + if debugStr := strings.TrimSpace(os.Getenv(envDebug)); debugStr != "" { + if debug, err1 := strconv.ParseBool(debugStr); err1 == nil { + opts.debug = &debug + } + } + } + + keyComponents := strings.SplitN(opts.apiKey, ".", 2) + + if len(keyComponents) != 2 { + err = ErrAPIKeyMalformed + return + } + + client = &Client{ + keyID: keyComponents[0], + keySecret: []byte(keyComponents[1]), + } + + if opts.client == nil { + client.client = resty.New() + } else { + client.client = resty.NewWithClient(opts.client) + } + + client.client = client.client.SetBaseURL(opts.baseURL) + + if opts.debug != nil { + client.client.SetDebug(*opts.debug) + client.debug = *opts.debug + } + return +} + +// BatchCreate creates a new BatchCreateService. +func (c *Client) BatchCreate() *BatchCreateService { + return NewBatchCreateService(c) +} + +// BatchGet creates a new BatchGetService. +func (c *Client) BatchGet(batchID string) *BatchGetService { + return NewBatchGetService(c).SetBatchID(batchID) +} + +// BatchCancel creates a new BatchCancelService. +func (c *Client) BatchCancel(batchID string) *BatchCancelService { + return NewBatchCancelService(c).SetBatchID(batchID) +} + +// BatchList creates a new BatchListService. +func (c *Client) BatchList() *BatchListService { + return NewBatchListService(c) +} + +// ChatCompletion creates a new ChatCompletionService. +func (c *Client) ChatCompletion(model string) *ChatCompletionService { + return NewChatCompletionService(c).SetModel(model) +} + +// Embedding embeds a list of text into a vector space. +func (c *Client) Embedding(model string) *EmbeddingService { + return NewEmbeddingService(c).SetModel(model) +} + +// FileCreate creates a new FileCreateService. +func (c *Client) FileCreate(purpose string) *FileCreateService { + return NewFileCreateService(c).SetPurpose(purpose) +} + +// FileEditService creates a new FileEditService. +func (c *Client) FileEdit(documentID string) *FileEditService { + return NewFileEditService(c).SetDocumentID(documentID) +} + +// FileList creates a new FileListService. +func (c *Client) FileList(purpose string) *FileListService { + return NewFileListService(c).SetPurpose(purpose) +} + +// FileDeleteService creates a new FileDeleteService. +func (c *Client) FileDelete(documentID string) *FileDeleteService { + return NewFileDeleteService(c).SetDocumentID(documentID) +} + +// FileGetService creates a new FileGetService. +func (c *Client) FileGet(documentID string) *FileGetService { + return NewFileGetService(c).SetDocumentID(documentID) +} + +// FileDownload creates a new FileDownloadService. +func (c *Client) FileDownload(fileID string) *FileDownloadService { + return NewFileDownloadService(c).SetFileID(fileID) +} + +// FineTuneCreate creates a new fine tune create service +func (c *Client) FineTuneCreate(model string) *FineTuneCreateService { + return NewFineTuneCreateService(c).SetModel(model) +} + +// FineTuneEventList creates a new fine tune event list service +func (c *Client) FineTuneEventList(jobID string) *FineTuneEventListService { + return NewFineTuneEventListService(c).SetJobID(jobID) +} + +// FineTuneGet creates a new fine tune get service +func (c *Client) FineTuneGet(jobID string) *FineTuneGetService { + return NewFineTuneGetService(c).SetJobID(jobID) +} + +// FineTuneList creates a new fine tune list service +func (c *Client) FineTuneList() *FineTuneListService { + return NewFineTuneListService(c) +} + +// FineTuneDelete creates a new fine tune delete service +func (c *Client) FineTuneDelete(jobID string) *FineTuneDeleteService { + return NewFineTuneDeleteService(c).SetJobID(jobID) +} + +// FineTuneCancel creates a new fine tune cancel service +func (c *Client) FineTuneCancel(jobID string) *FineTuneCancelService { + return NewFineTuneCancelService(c).SetJobID(jobID) +} + +// ImageGeneration creates a new image generation service +func (c *Client) ImageGeneration(model string) *ImageGenerationService { + return NewImageGenerationService(c).SetModel(model) +} + +// KnowledgeCreate creates a new knowledge create service +func (c *Client) KnowledgeCreate() *KnowledgeCreateService { + return NewKnowledgeCreateService(c) +} + +// KnowledgeEdit creates a new knowledge edit service +func (c *Client) KnowledgeEdit(knowledgeID string) *KnowledgeEditService { + return NewKnowledgeEditService(c).SetKnowledgeID(knowledgeID) +} + +// KnowledgeList list all the knowledge +func (c *Client) KnowledgeList() *KnowledgeListService { + return NewKnowledgeListService(c) +} + +// KnowledgeDelete creates a new knowledge delete service +func (c *Client) KnowledgeDelete(knowledgeID string) *KnowledgeDeleteService { + return NewKnowledgeDeleteService(c).SetKnowledgeID(knowledgeID) +} + +// KnowledgeGet creates a new knowledge get service +func (c *Client) KnowledgeCapacity() *KnowledgeCapacityService { + return NewKnowledgeCapacityService(c) +} + +// VideoGeneration creates a new video generation service +func (c *Client) VideoGeneration(model string) *VideoGenerationService { + return NewVideoGenerationService(c).SetModel(model) +} + +// AsyncResult creates a new async result get service +func (c *Client) AsyncResult(id string) *AsyncResultService { + return NewAsyncResultService(c).SetID(id) +} diff --git a/zhipu/client_test.go b/zhipu/client_test.go new file mode 100644 index 0000000..dd4800d --- /dev/null +++ b/zhipu/client_test.go @@ -0,0 +1,17 @@ +package zhipu + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestClientR(t *testing.T) { + c, err := NewClient() + require.NoError(t, err) + // the only free api is to list fine-tuning jobs + res, err := c.request(context.Background()).Get("fine_tuning/jobs") + require.NoError(t, err) + require.True(t, res.IsSuccess()) +} diff --git a/zhipu/cog.toml b/zhipu/cog.toml new file mode 100644 index 0000000..ee92ed8 --- /dev/null +++ b/zhipu/cog.toml @@ -0,0 +1,25 @@ +from_latest_tag = false +ignore_merge_commits = false +disable_changelog = false +disable_bump_commit = false +generate_mono_repository_global_tag = true +branch_whitelist = [] +skip_ci = "[skip ci]" +skip_untracked = false +pre_bump_hooks = [] +post_bump_hooks = [] +pre_package_bump_hooks = [] +post_package_bump_hooks = [] +tag_prefix = "v" + +[git_hooks] + +[commit_types] + +[changelog] +path = "CHANGELOG.md" +authors = [] + +[bump_profiles] + +[packages] diff --git a/zhipu/embedding.go b/zhipu/embedding.go new file mode 100644 index 0000000..45c672b --- /dev/null +++ b/zhipu/embedding.go @@ -0,0 +1,87 @@ +package zhipu + +import ( + "context" + + "github.com/go-resty/resty/v2" +) + +// EmbeddingData is the data for each embedding. +type EmbeddingData struct { + Embedding []float64 `json:"embedding"` + Index int `json:"index"` + Object string `json:"object"` +} + +// EmbeddingResponse is the response from the embedding service. +type EmbeddingResponse struct { + Model string `json:"model"` + Data []EmbeddingData `json:"data"` + Object string `json:"object"` + Usage ChatCompletionUsage `json:"usage"` +} + +// EmbeddingService embeds a list of text into a vector space. +type EmbeddingService struct { + client *Client + + model string + input string +} + +var ( + _ BatchSupport = &EmbeddingService{} +) + +// NewEmbeddingService creates a new EmbeddingService. +func NewEmbeddingService(client *Client) *EmbeddingService { + return &EmbeddingService{client: client} +} + +func (s *EmbeddingService) BatchMethod() string { + return "POST" +} + +func (s *EmbeddingService) BatchURL() string { + return BatchEndpointV4Embeddings +} + +func (s *EmbeddingService) BatchBody() any { + return s.buildBody() +} + +// SetModel sets the model to use for the embedding. +func (s *EmbeddingService) SetModel(model string) *EmbeddingService { + s.model = model + return s +} + +// SetInput sets the input text to embed. +func (s *EmbeddingService) SetInput(input string) *EmbeddingService { + s.input = input + return s +} + +func (s *EmbeddingService) buildBody() M { + return M{"model": s.model, "input": s.input} +} + +func (s *EmbeddingService) Do(ctx context.Context) (res EmbeddingResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + if resp, err = s.client.request(ctx). + SetBody(s.buildBody()). + SetResult(&res). + SetError(&apiError). + Post("embeddings"); err != nil { + return + } + if resp.IsError() { + err = apiError + return + } + return +} diff --git a/zhipu/embedding_test.go b/zhipu/embedding_test.go new file mode 100644 index 0000000..46f4aeb --- /dev/null +++ b/zhipu/embedding_test.go @@ -0,0 +1,21 @@ +package zhipu + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestEmbeddingService(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + service := client.Embedding("embedding-2") + + resp, err := service.SetInput("你好").Do(context.Background()) + require.NoError(t, err) + require.NotZero(t, resp.Usage.TotalTokens) + require.NotEmpty(t, resp.Data) + require.NotEmpty(t, resp.Data[0].Embedding) +} diff --git a/zhipu/error.go b/zhipu/error.go new file mode 100644 index 0000000..ea07ad7 --- /dev/null +++ b/zhipu/error.go @@ -0,0 +1,58 @@ +package zhipu + +type APIError struct { + Code string `json:"code"` + Message string `json:"message"` +} + +func (e APIError) Error() string { + return e.Message +} + +type APIErrorResponse struct { + APIError `json:"error"` +} + +func (e APIErrorResponse) Error() string { + return e.APIError.Error() +} + +// GetAPIErrorCode returns the error code of an API error. +func GetAPIErrorCode(err error) string { + if err == nil { + return "" + } + if e, ok := err.(APIError); ok { + return e.Code + } + if e, ok := err.(APIErrorResponse); ok { + return e.Code + } + if e, ok := err.(*APIError); ok && e != nil { + return e.Code + } + if e, ok := err.(*APIErrorResponse); ok && e != nil { + return e.Code + } + return "" +} + +// GetAPIErrorMessage returns the error message of an API error. +func GetAPIErrorMessage(err error) string { + if err == nil { + return "" + } + if e, ok := err.(APIError); ok { + return e.Message + } + if e, ok := err.(APIErrorResponse); ok { + return e.Message + } + if e, ok := err.(*APIError); ok && e != nil { + return e.Message + } + if e, ok := err.(*APIErrorResponse); ok && e != nil { + return e.Message + } + return err.Error() +} diff --git a/zhipu/error_test.go b/zhipu/error_test.go new file mode 100644 index 0000000..3f2fa08 --- /dev/null +++ b/zhipu/error_test.go @@ -0,0 +1,38 @@ +package zhipu + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAPIError(t *testing.T) { + err := APIError{ + Code: "code", + Message: "message", + } + require.Equal(t, "message", err.Error()) + require.Equal(t, "code", GetAPIErrorCode(err)) + require.Equal(t, "message", GetAPIErrorMessage(err)) +} + +func TestAPIErrorResponse(t *testing.T) { + err := APIErrorResponse{ + APIError: APIError{ + Code: "code", + Message: "message", + }, + } + require.Equal(t, "message", err.Error()) + require.Equal(t, "code", GetAPIErrorCode(err)) + require.Equal(t, "message", GetAPIErrorMessage(err)) +} + +func TestAPIErrorResponseFromDoc(t *testing.T) { + var res APIErrorResponse + err := json.Unmarshal([]byte(`{"error":{"code":"1002","message":"Authorization Token非法,请确认Authorization Token正确传递。"}}`), &res) + require.NoError(t, err) + require.Equal(t, "1002", res.Code) + require.Equal(t, "1002", GetAPIErrorCode(res)) +} diff --git a/zhipu/file.go b/zhipu/file.go new file mode 100644 index 0000000..e8b66af --- /dev/null +++ b/zhipu/file.go @@ -0,0 +1,541 @@ +package zhipu + +import ( + "context" + "errors" + "io" + "os" + "path/filepath" + "strconv" + + "github.com/go-resty/resty/v2" +) + +const ( + FilePurposeFineTune = "fine-tune" + FilePurposeRetrieval = "retrieval" + FilePurposeBatch = "batch" + + KnowledgeTypeArticle = 1 + KnowledgeTypeQADocument = 2 + KnowledgeTypeQASpreadsheet = 3 + KnowledgeTypeProductDatabaseSpreadsheet = 4 + KnowledgeTypeCustom = 5 +) + +// FileCreateService is a service to create a file. +type FileCreateService struct { + client *Client + + purpose string + + localFile string + file io.Reader + filename string + + customSeparator *string + sentenceSize *int + knowledgeID *string +} + +// FileCreateKnowledgeSuccessInfo is the success info of the FileCreateKnowledgeResponse. +type FileCreateKnowledgeSuccessInfo struct { + Filename string `json:"fileName"` + DocumentID string `json:"documentId"` +} + +// FileCreateKnowledgeFailedInfo is the failed info of the FileCreateKnowledgeResponse. +type FileCreateKnowledgeFailedInfo struct { + Filename string `json:"fileName"` + FailReason string `json:"failReason"` +} + +// FileCreateKnowledgeResponse is the response of the FileCreateService. +type FileCreateKnowledgeResponse struct { + SuccessInfos []FileCreateKnowledgeSuccessInfo `json:"successInfos"` + FailedInfos []FileCreateKnowledgeFailedInfo `json:"failedInfos"` +} + +// FileCreateFineTuneResponse is the response of the FileCreateService. +type FileCreateFineTuneResponse struct { + Bytes int64 `json:"bytes"` + CreatedAt int64 `json:"created_at"` + Filename string `json:"filename"` + Object string `json:"object"` + Purpose string `json:"purpose"` + ID string `json:"id"` +} + +// FileCreateResponse is the response of the FileCreateService. +type FileCreateResponse struct { + FileCreateFineTuneResponse + FileCreateKnowledgeResponse +} + +// NewFileCreateService creates a new FileCreateService. +func NewFileCreateService(client *Client) *FileCreateService { + return &FileCreateService{client: client} +} + +// SetLocalFile sets the local_file parameter of the FileCreateService. +func (s *FileCreateService) SetLocalFile(localFile string) *FileCreateService { + s.localFile = localFile + return s +} + +// SetFile sets the file parameter of the FileCreateService. +func (s *FileCreateService) SetFile(file io.Reader, filename string) *FileCreateService { + s.file = file + s.filename = filename + return s +} + +// SetPurpose sets the purpose parameter of the FileCreateService. +func (s *FileCreateService) SetPurpose(purpose string) *FileCreateService { + s.purpose = purpose + return s +} + +// SetCustomSeparator sets the custom_separator parameter of the FileCreateService. +func (s *FileCreateService) SetCustomSeparator(customSeparator string) *FileCreateService { + s.customSeparator = &customSeparator + return s +} + +// SetSentenceSize sets the sentence_size parameter of the FileCreateService. +func (s *FileCreateService) SetSentenceSize(sentenceSize int) *FileCreateService { + s.sentenceSize = &sentenceSize + return s +} + +// SetKnowledgeID sets the knowledge_id parameter of the FileCreateService. +func (s *FileCreateService) SetKnowledgeID(knowledgeID string) *FileCreateService { + s.knowledgeID = &knowledgeID + return s +} + +// Do makes the request. +func (s *FileCreateService) Do(ctx context.Context) (res FileCreateResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + body := map[string]string{"purpose": s.purpose} + + if s.customSeparator != nil { + body["custom_separator"] = *s.customSeparator + } + if s.sentenceSize != nil { + body["sentence_size"] = strconv.Itoa(*s.sentenceSize) + } + if s.knowledgeID != nil { + body["knowledge_id"] = *s.knowledgeID + } + + file, filename := s.file, s.filename + + if file == nil && s.localFile != "" { + var f *os.File + if f, err = os.Open(s.localFile); err != nil { + return + } + defer f.Close() + + file = f + filename = filepath.Base(s.localFile) + } + + if file == nil { + err = errors.New("no file specified") + return + } + + if resp, err = s.client.request(ctx). + SetFileReader("file", filename, file). + SetMultipartFormData(body). + SetResult(&res). + SetError(&apiError). + Post("files"); err != nil { + return + } + + if resp.IsError() { + err = apiError + return + } + + return +} + +// FileEditService is a service to edit a file. +type FileEditService struct { + client *Client + + documentID string + + knowledgeType *int + customSeparator []string + sentenceSize *int +} + +// NewFileEditService creates a new FileEditService. +func NewFileEditService(client *Client) *FileEditService { + return &FileEditService{client: client} +} + +// SetDocumentID sets the document_id parameter of the FileEditService. +func (s *FileEditService) SetDocumentID(documentID string) *FileEditService { + s.documentID = documentID + return s +} + +// SetKnowledgeType sets the knowledge_type parameter of the FileEditService. +func (s *FileEditService) SetKnowledgeType(knowledgeType int) *FileEditService { + s.knowledgeType = &knowledgeType + return s +} + +// SetSentenceSize sets the sentence_size parameter of the FileEditService. +func (s *FileEditService) SetCustomSeparator(customSeparator ...string) *FileEditService { + s.customSeparator = customSeparator + return s +} + +// SetSentenceSize sets the sentence_size parameter of the FileEditService. +func (s *FileEditService) SetSentenceSize(sentenceSize int) *FileEditService { + s.sentenceSize = &sentenceSize + return s +} + +// Do makes the request. +func (s *FileEditService) Do(ctx context.Context) (err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + body := M{} + + if s.knowledgeType != nil { + body["knowledge_type"] = strconv.Itoa(*s.knowledgeType) + } + if len(s.customSeparator) > 0 { + body["custom_separator"] = s.customSeparator + } + if s.sentenceSize != nil { + body["sentence_size"] = strconv.Itoa(*s.sentenceSize) + } + + if resp, err = s.client.request(ctx). + SetPathParam("document_id", s.documentID). + SetBody(body). + SetError(&apiError). + Put("document/{document_id}"); err != nil { + return + } + + if resp.IsError() { + err = apiError + return + } + + return +} + +// FileListService is a service to list files. +type FileListService struct { + client *Client + + purpose string + + knowledgeID *string + page *int + limit *int + after *string + orderAsc *bool +} + +// FileFailInfo is the failed info of the FileListKnowledgeItem. +type FileFailInfo struct { + EmbeddingCode int `json:"embedding_code"` + EmbeddingMsg string `json:"embedding_msg"` +} + +// FileListKnowledgeItem is the item of the FileListKnowledgeResponse. +type FileListKnowledgeItem struct { + ID string `json:"id"` + Name string `json:"name"` + URL string `json:"url"` + Length int64 `json:"length"` + SentenceSize int64 `json:"sentence_size"` + CustomSeparator []string `json:"custom_separator"` + EmbeddingStat int `json:"embedding_stat"` + FailInfo *FileFailInfo `json:"failInfo"` + WordNum int64 `json:"word_num"` + ParseImage int `json:"parse_image"` +} + +// FileListKnowledgeResponse is the response of the FileListService. +type FileListKnowledgeResponse struct { + Total int `json:"total"` + List []FileListKnowledgeItem `json:"list"` +} + +// FileListFineTuneItem is the item of the FileListFineTuneResponse. +type FileListFineTuneItem struct { + Bytes int64 `json:"bytes"` + CreatedAt int64 `json:"created_at"` + Filename string `json:"filename"` + ID string `json:"id"` + Object string `json:"object"` + Purpose string `json:"purpose"` +} + +// FileListFineTuneResponse is the response of the FileListService. +type FileListFineTuneResponse struct { + Object string `json:"object"` + Data []FileListFineTuneItem `json:"data"` +} + +// FileListResponse is the response of the FileListService. +type FileListResponse struct { + FileListKnowledgeResponse + FileListFineTuneResponse +} + +// NewFileListService creates a new FileListService. +func NewFileListService(client *Client) *FileListService { + return &FileListService{client: client} +} + +// SetPurpose sets the purpose parameter of the FileListService. +func (s *FileListService) SetPurpose(purpose string) *FileListService { + s.purpose = purpose + return s +} + +// SetKnowledgeID sets the knowledge_id parameter of the FileListService. +func (s *FileListService) SetKnowledgeID(knowledgeID string) *FileListService { + s.knowledgeID = &knowledgeID + return s +} + +// SetPage sets the page parameter of the FileListService. +func (s *FileListService) SetPage(page int) *FileListService { + s.page = &page + return s +} + +// SetLimit sets the limit parameter of the FileListService. +func (s *FileListService) SetLimit(limit int) *FileListService { + s.limit = &limit + return s +} + +// SetAfter sets the after parameter of the FileListService. +func (s *FileListService) SetAfter(after string) *FileListService { + s.after = &after + return s +} + +// SetOrder sets the order parameter of the FileListService. +func (s *FileListService) SetOrder(asc bool) *FileListService { + s.orderAsc = &asc + return s +} + +// Do makes the request. +func (s *FileListService) Do(ctx context.Context) (res FileListResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + m := map[string]string{ + "purpose": s.purpose, + } + + if s.knowledgeID != nil { + m["knowledge_id"] = *s.knowledgeID + } + if s.page != nil { + m["page"] = strconv.Itoa(*s.page) + } + if s.limit != nil { + m["limit"] = strconv.Itoa(*s.limit) + } + if s.after != nil { + m["after"] = *s.after + } + if s.orderAsc != nil { + if *s.orderAsc { + m["order"] = "asc" + } else { + m["order"] = "desc" + } + } + + if resp, err = s.client.request(ctx). + SetQueryParams(m). + SetResult(&res). + SetError(&apiError). + Get("files"); err != nil { + return + } + + if resp.IsError() { + err = apiError + return + } + + return +} + +// FileDeleteService is a service to delete a file. +type FileDeleteService struct { + client *Client + documentID string +} + +// NewFileDeleteService creates a new FileDeleteService. +func NewFileDeleteService(client *Client) *FileDeleteService { + return &FileDeleteService{client: client} +} + +// SetDocumentID sets the document_id parameter of the FileDeleteService. +func (s *FileDeleteService) SetDocumentID(documentID string) *FileDeleteService { + s.documentID = documentID + return s +} + +// Do makes the request. +func (s *FileDeleteService) Do(ctx context.Context) (err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + if resp, err = s.client.request(ctx). + SetPathParam("document_id", s.documentID). + SetError(&apiError). + Delete("document/{document_id}"); err != nil { + return + } + + if resp.IsError() { + err = apiError + return + } + + return +} + +// FileGetService is a service to get a file. +type FileGetService struct { + client *Client + documentID string +} + +// FileGetResponse is the response of the FileGetService. +type FileGetResponse = FileListKnowledgeItem + +// NewFileGetService creates a new FileGetService. +func NewFileGetService(client *Client) *FileGetService { + return &FileGetService{client: client} +} + +// SetDocumentID sets the document_id parameter of the FileGetService. +func (s *FileGetService) SetDocumentID(documentID string) *FileGetService { + s.documentID = documentID + return s +} + +// Do makes the request. +func (s *FileGetService) Do(ctx context.Context) (res FileGetResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + if resp, err = s.client.request(ctx). + SetPathParam("document_id", s.documentID). + SetResult(&res). + SetError(&apiError). + Get("document/{document_id}"); err != nil { + return + } + + if resp.IsError() { + err = apiError + return + } + + return +} + +// FileDownloadService is a service to download a file. +type FileDownloadService struct { + client *Client + + fileID string + + writer io.Writer + filename string +} + +// NewFileDownloadService creates a new FileDownloadService. +func NewFileDownloadService(client *Client) *FileDownloadService { + return &FileDownloadService{client: client} +} + +// SetFileID sets the file_id parameter of the FileDownloadService. +func (s *FileDownloadService) SetFileID(fileID string) *FileDownloadService { + s.fileID = fileID + return s +} + +// SetOutput sets the output parameter of the FileDownloadService. +func (s *FileDownloadService) SetOutput(w io.Writer) *FileDownloadService { + s.writer = w + return s +} + +// SetOutputFile sets the output_file parameter of the FileDownloadService. +func (s *FileDownloadService) SetOutputFile(filename string) *FileDownloadService { + s.filename = filename + return s +} + +// Do makes the request. +func (s *FileDownloadService) Do(ctx context.Context) (err error) { + var resp *resty.Response + + writer := s.writer + + if writer == nil && s.filename != "" { + var f *os.File + if f, err = os.Create(s.filename); err != nil { + return + } + defer f.Close() + + writer = f + } + + if writer == nil { + return errors.New("no output specified") + } + + if resp, err = s.client.request(ctx). + SetDoNotParseResponse(true). + SetPathParam("file_id", s.fileID). + Get("files/{file_id}/content"); err != nil { + return + } + defer resp.RawBody().Close() + + _, err = io.Copy(writer, resp.RawBody()) + + return +} diff --git a/zhipu/file_test.go b/zhipu/file_test.go new file mode 100644 index 0000000..3d035ae --- /dev/null +++ b/zhipu/file_test.go @@ -0,0 +1,71 @@ +package zhipu + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFileServiceFineTune(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.FileCreate(FilePurposeFineTune) + s.SetLocalFile(filepath.Join("testdata", "test-file.jsonl")) + + res, err := s.Do(context.Background()) + require.NoError(t, err) + require.NotZero(t, res.Bytes) + require.NotZero(t, res.CreatedAt) + require.NotEmpty(t, res.ID) +} + +func TestFileServiceKnowledge(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.FileCreate(FilePurposeRetrieval) + s.SetKnowledgeID(os.Getenv("TEST_KNOWLEDGE_ID")) + s.SetLocalFile(filepath.Join("testdata", "test-file.txt")) + + res, err := s.Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res.SuccessInfos) + require.NotEmpty(t, res.SuccessInfos[0].DocumentID) + require.NotEmpty(t, res.SuccessInfos[0].Filename) + + documentID := res.SuccessInfos[0].DocumentID + + res2, err := client.FileGet(documentID).Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res2.ID) + + err = client.FileEdit(documentID).SetKnowledgeType(KnowledgeTypeCustom).Do(context.Background()) + require.True(t, err == nil || GetAPIErrorCode(err) == "10019") + + err = client.FileDelete(res.SuccessInfos[0].DocumentID).Do(context.Background()) + require.True(t, err == nil || GetAPIErrorCode(err) == "10019") +} + +func TestFileListServiceKnowledge(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.FileList(FilePurposeRetrieval).SetKnowledgeID(os.Getenv("TEST_KNOWLEDGE_ID")) + res, err := s.Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res.List) +} + +func TestFileListServiceFineTune(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.FileList(FilePurposeFineTune) + res, err := s.Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res.Data) +} diff --git a/zhipu/fine_tune.go b/zhipu/fine_tune.go new file mode 100644 index 0000000..f2121ea --- /dev/null +++ b/zhipu/fine_tune.go @@ -0,0 +1,456 @@ +package zhipu + +import ( + "context" + "strconv" + + "github.com/go-resty/resty/v2" +) + +const ( + HyperParameterAuto = "auto" + + FineTuneStatusCreate = "create" + FineTuneStatusValidatingFiles = "validating_files" + FineTuneStatusQueued = "queued" + FineTuneStatusRunning = "running" + FineTuneStatusSucceeded = "succeeded" + FineTuneStatusFailed = "failed" + FineTuneStatusCancelled = "cancelled" +) + +// FineTuneItem is the item of the FineTune +type FineTuneItem struct { + ID string `json:"id"` + RequestID string `json:"request_id"` + FineTunedModel string `json:"fine_tuned_model"` + Status string `json:"status"` + Object string `json:"object"` + TrainingFile string `json:"training_file"` + ValidationFile string `json:"validation_file"` + Error APIError `json:"error"` +} + +// FineTuneCreateService creates a new fine tune +type FineTuneCreateService struct { + client *Client + + model string + trainingFile string + validationFile *string + + learningRateMultiplier *StringOr[float64] + batchSize *StringOr[int] + nEpochs *StringOr[int] + + suffix *string + requestID *string +} + +// FineTuneCreateResponse is the response of the FineTuneCreateService +type FineTuneCreateResponse = FineTuneItem + +// NewFineTuneCreateService creates a new FineTuneCreateService +func NewFineTuneCreateService(client *Client) *FineTuneCreateService { + return &FineTuneCreateService{ + client: client, + } +} + +// SetModel sets the model parameter +func (s *FineTuneCreateService) SetModel(model string) *FineTuneCreateService { + s.model = model + return s +} + +// SetTrainingFile sets the trainingFile parameter +func (s *FineTuneCreateService) SetTrainingFile(trainingFile string) *FineTuneCreateService { + s.trainingFile = trainingFile + return s +} + +// SetValidationFile sets the validationFile parameter +func (s *FineTuneCreateService) SetValidationFile(validationFile string) *FineTuneCreateService { + s.validationFile = &validationFile + return s +} + +// SetLearningRateMultiplier sets the learningRateMultiplier parameter +func (s *FineTuneCreateService) SetLearningRateMultiplier(learningRateMultiplier float64) *FineTuneCreateService { + s.learningRateMultiplier = &StringOr[float64]{} + s.learningRateMultiplier.SetValue(learningRateMultiplier) + return s +} + +// SetLearningRateMultiplierAuto sets the learningRateMultiplier parameter to auto +func (s *FineTuneCreateService) SetLearningRateMultiplierAuto() *FineTuneCreateService { + s.learningRateMultiplier = &StringOr[float64]{} + s.learningRateMultiplier.SetString(HyperParameterAuto) + return s +} + +// SetBatchSize sets the batchSize parameter +func (s *FineTuneCreateService) SetBatchSize(batchSize int) *FineTuneCreateService { + s.batchSize = &StringOr[int]{} + s.batchSize.SetValue(batchSize) + return s +} + +// SetBatchSizeAuto sets the batchSize parameter to auto +func (s *FineTuneCreateService) SetBatchSizeAuto() *FineTuneCreateService { + s.batchSize = &StringOr[int]{} + s.batchSize.SetString(HyperParameterAuto) + return s +} + +// SetNEpochs sets the nEpochs parameter +func (s *FineTuneCreateService) SetNEpochs(nEpochs int) *FineTuneCreateService { + s.nEpochs = &StringOr[int]{} + s.nEpochs.SetValue(nEpochs) + return s +} + +// SetNEpochsAuto sets the nEpochs parameter to auto +func (s *FineTuneCreateService) SetNEpochsAuto() *FineTuneCreateService { + s.nEpochs = &StringOr[int]{} + s.nEpochs.SetString(HyperParameterAuto) + return s +} + +// SetSuffix sets the suffix parameter +func (s *FineTuneCreateService) SetSuffix(suffix string) *FineTuneCreateService { + s.suffix = &suffix + return s +} + +// SetRequestID sets the requestID parameter +func (s *FineTuneCreateService) SetRequestID(requestID string) *FineTuneCreateService { + s.requestID = &requestID + return s +} + +// Do makes the request +func (s *FineTuneCreateService) Do(ctx context.Context) (res FineTuneCreateResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + body := M{ + "model": s.model, + "training_file": s.trainingFile, + } + + if s.validationFile != nil { + body["validation_file"] = *s.validationFile + } + if s.suffix != nil { + body["suffix"] = *s.suffix + } + if s.requestID != nil { + body["request_id"] = *s.requestID + } + if s.learningRateMultiplier != nil || s.batchSize != nil || s.nEpochs != nil { + hp := M{} + if s.learningRateMultiplier != nil { + hp["learning_rate_multiplier"] = s.learningRateMultiplier + } + if s.batchSize != nil { + hp["batch_size"] = s.batchSize + } + if s.nEpochs != nil { + hp["n_epochs"] = s.nEpochs + } + body["hyperparameters"] = hp + } + + if resp, err = s.client.request(ctx). + SetBody(body). + SetResult(&res). + SetError(&apiError). + Post("fine_tuning/jobs"); err != nil { + return + } + if resp.IsError() { + err = apiError + return + } + return +} + +// FineTuneEventListService creates a new fine tune event list +type FineTuneEventListService struct { + client *Client + + jobID string + + limit *int + after *string +} + +// FineTuneEventData is the data of the FineTuneEventItem +type FineTuneEventData struct { + Acc float64 `json:"acc"` + Loss float64 `json:"loss"` + CurrentSteps int64 `json:"current_steps"` + RemainingTime string `json:"remaining_time"` + ElapsedTime string `json:"elapsed_time"` + TotalSteps int64 `json:"total_steps"` + Epoch int64 `json:"epoch"` + TrainedTokens int64 `json:"trained_tokens"` + LearningRate float64 `json:"learning_rate"` +} + +// FineTuneEventItem is the item of the FineTuneEventListResponse +type FineTuneEventItem struct { + ID string `json:"id"` + Type string `json:"type"` + Level string `json:"level"` + Message string `json:"message"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Data FineTuneEventData `json:"data"` +} + +// FineTuneEventListResponse is the response of the FineTuneEventListService +type FineTuneEventListResponse struct { + Data []FineTuneEventItem `json:"data"` + HasMore bool `json:"has_more"` + Object string `json:"object"` +} + +// NewFineTuneEventListService creates a new FineTuneEventListService +func NewFineTuneEventListService(client *Client) *FineTuneEventListService { + return &FineTuneEventListService{ + client: client, + } +} + +// SetJobID sets the jobID parameter +func (s *FineTuneEventListService) SetJobID(jobID string) *FineTuneEventListService { + s.jobID = jobID + return s +} + +// SetLimit sets the limit parameter +func (s *FineTuneEventListService) SetLimit(limit int) *FineTuneEventListService { + s.limit = &limit + return s +} + +// SetAfter sets the after parameter +func (s *FineTuneEventListService) SetAfter(after string) *FineTuneEventListService { + s.after = &after + return s +} + +// Do makes the request +func (s *FineTuneEventListService) Do(ctx context.Context) (res FineTuneEventListResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + req := s.client.request(ctx) + + if s.limit != nil { + req.SetQueryParam("limit", strconv.Itoa(*s.limit)) + } + if s.after != nil { + req.SetQueryParam("after", *s.after) + } + + if resp, err = req. + SetPathParam("job_id", s.jobID). + SetResult(&res). + SetError(&apiError). + Get("fine_tuning/jobs/{job_id}/events"); err != nil { + return + } + if resp.IsError() { + err = apiError + return + } + return +} + +// FineTuneGetService creates a new fine tune get +type FineTuneGetService struct { + client *Client + jobID string +} + +// NewFineTuneGetService creates a new FineTuneGetService +func NewFineTuneGetService(client *Client) *FineTuneGetService { + return &FineTuneGetService{ + client: client, + } +} + +// SetJobID sets the jobID parameter +func (s *FineTuneGetService) SetJobID(jobID string) *FineTuneGetService { + s.jobID = jobID + return s +} + +// Do makes the request +func (s *FineTuneGetService) Do(ctx context.Context) (res FineTuneItem, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + if resp, err = s.client.request(ctx). + SetPathParam("job_id", s.jobID). + SetResult(&res). + SetError(&apiError). + Get("fine_tuning/jobs/{job_id}"); err != nil { + return + } + if resp.IsError() { + err = apiError + return + } + return +} + +// FineTuneListService creates a new fine tune list +type FineTuneListService struct { + client *Client + + limit *int + after *string +} + +// FineTuneListResponse is the response of the FineTuneListService +type FineTuneListResponse struct { + Data []FineTuneItem `json:"data"` + Object string `json:"object"` +} + +// NewFineTuneListService creates a new FineTuneListService +func NewFineTuneListService(client *Client) *FineTuneListService { + return &FineTuneListService{ + client: client, + } +} + +// SetLimit sets the limit parameter +func (s *FineTuneListService) SetLimit(limit int) *FineTuneListService { + s.limit = &limit + return s +} + +// SetAfter sets the after parameter +func (s *FineTuneListService) SetAfter(after string) *FineTuneListService { + s.after = &after + return s +} + +// Do makes the request +func (s *FineTuneListService) Do(ctx context.Context) (res FineTuneListResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + req := s.client.request(ctx) + if s.limit != nil { + req.SetQueryParam("limit", strconv.Itoa(*s.limit)) + } + if s.after != nil { + req.SetQueryParam("after", *s.after) + } + + if resp, err = req. + SetResult(&res). + SetError(&apiError). + Get("fine_tuning/jobs"); err != nil { + return + } + if resp.IsError() { + err = apiError + return + } + return +} + +// FineTuneDeleteService creates a new fine tune delete +type FineTuneDeleteService struct { + client *Client + jobID string +} + +// NewFineTuneDeleteService creates a new FineTuneDeleteService +func NewFineTuneDeleteService(client *Client) *FineTuneDeleteService { + return &FineTuneDeleteService{ + client: client, + } +} + +// SetJobID sets the jobID parameter +func (s *FineTuneDeleteService) SetJobID(jobID string) *FineTuneDeleteService { + s.jobID = jobID + return s +} + +// Do makes the request +func (s *FineTuneDeleteService) Do(ctx context.Context) (res FineTuneItem, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + if resp, err = s.client.request(ctx). + SetPathParam("job_id", s.jobID). + SetResult(&res). + SetError(&apiError). + Delete("fine_tuning/jobs/{job_id}"); err != nil { + return + } + if resp.IsError() { + err = apiError + return + } + return +} + +// FineTuneCancelService creates a new fine tune cancel +type FineTuneCancelService struct { + client *Client + jobID string +} + +// NewFineTuneCancelService creates a new FineTuneCancelService +func NewFineTuneCancelService(client *Client) *FineTuneCancelService { + return &FineTuneCancelService{ + client: client, + } +} + +// SetJobID sets the jobID parameter +func (s *FineTuneCancelService) SetJobID(jobID string) *FineTuneCancelService { + s.jobID = jobID + return s +} + +// Do makes the request +func (s *FineTuneCancelService) Do(ctx context.Context) (res FineTuneItem, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + if resp, err = s.client.request(ctx). + SetPathParam("job_id", s.jobID). + SetResult(&res). + SetError(&apiError). + Post("fine_tuning/jobs/{job_id}/cancel"); err != nil { + return + } + if resp.IsError() { + err = apiError + return + } + return +} diff --git a/zhipu/fine_tune_test.go b/zhipu/fine_tune_test.go new file mode 100644 index 0000000..2adb1fa --- /dev/null +++ b/zhipu/fine_tune_test.go @@ -0,0 +1,3 @@ +package zhipu + +// tests not available since lack of budget to test it diff --git a/zhipu/image_generation.go b/zhipu/image_generation.go new file mode 100644 index 0000000..32486a9 --- /dev/null +++ b/zhipu/image_generation.go @@ -0,0 +1,110 @@ +package zhipu + +import ( + "context" + + "github.com/go-resty/resty/v2" +) + +// ImageGenerationService creates a new image generation +type ImageGenerationService struct { + client *Client + + model string + prompt string + size string + userID string +} + +var ( + _ BatchSupport = &ImageGenerationService{} +) + +// ImageGenerationResponse is the response of the ImageGenerationService +type ImageGenerationResponse struct { + Created int64 `json:"created"` + Data []URLItem `json:"data"` +} + +// NewImageGenerationService creates a new ImageGenerationService +func NewImageGenerationService(client *Client) *ImageGenerationService { + return &ImageGenerationService{ + client: client, + } +} + +func (s *ImageGenerationService) BatchMethod() string { + return "POST" +} + +func (s *ImageGenerationService) BatchURL() string { + return BatchEndpointV4ImagesGenerations +} + +func (s *ImageGenerationService) BatchBody() any { + return s.buildBody() +} + +// SetModel sets the model parameter +func (s *ImageGenerationService) SetModel(model string) *ImageGenerationService { + s.model = model + return s +} + +// SetPrompt sets the prompt parameter +func (s *ImageGenerationService) SetPrompt(prompt string) *ImageGenerationService { + s.prompt = prompt + return s +} + +func (s *ImageGenerationService) SetSize(size string) *ImageGenerationService { + s.size = size + return s +} + +// SetUserID sets the userID parameter +func (s *ImageGenerationService) SetUserID(userID string) *ImageGenerationService { + s.userID = userID + return s +} + +func (s *ImageGenerationService) buildBody() M { + body := M{ + "model": s.model, + "prompt": s.prompt, + } + + if s.userID != "" { + body["user_id"] = s.userID + } + + if s.size != "" { + body["size"] = s.size + } + + return body +} + +func (s *ImageGenerationService) Do(ctx context.Context) (res ImageGenerationResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + body := s.buildBody() + + if resp, err = s.client.request(ctx). + SetBody(body). + SetResult(&res). + SetError(&apiError). + Post("images/generations"); err != nil { + return + } + + if resp.IsError() { + err = apiError + return + } + + return +} diff --git a/zhipu/image_generation_test.go b/zhipu/image_generation_test.go new file mode 100644 index 0000000..7ab7807 --- /dev/null +++ b/zhipu/image_generation_test.go @@ -0,0 +1,21 @@ +package zhipu + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestImageGenerationService(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.ImageGeneration("cogview-3") + s.SetPrompt("一只可爱的小猫") + + res, err := s.Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res.Data) + t.Log(res.Data[0].URL) +} diff --git a/zhipu/knowledge.go b/zhipu/knowledge.go new file mode 100644 index 0000000..7f43808 --- /dev/null +++ b/zhipu/knowledge.go @@ -0,0 +1,299 @@ +package zhipu + +import ( + "context" + "strconv" + + "github.com/go-resty/resty/v2" +) + +const ( + KnowledgeEmbeddingIDEmbedding2 = 3 +) + +// KnowledgeCreateService creates a new knowledge +type KnowledgeCreateService struct { + client *Client + + embeddingID int + name string + description *string +} + +// KnowledgeCreateResponse is the response of the KnowledgeCreateService +type KnowledgeCreateResponse = IDItem + +// NewKnowledgeCreateService creates a new KnowledgeCreateService +func NewKnowledgeCreateService(client *Client) *KnowledgeCreateService { + return &KnowledgeCreateService{ + client: client, + } +} + +// SetEmbeddingID sets the embedding id of the knowledge +func (s *KnowledgeCreateService) SetEmbeddingID(embeddingID int) *KnowledgeCreateService { + s.embeddingID = embeddingID + return s +} + +// SetName sets the name of the knowledge +func (s *KnowledgeCreateService) SetName(name string) *KnowledgeCreateService { + s.name = name + return s +} + +// SetDescription sets the description of the knowledge +func (s *KnowledgeCreateService) SetDescription(description string) *KnowledgeCreateService { + s.description = &description + return s +} + +// Do creates the knowledge +func (s *KnowledgeCreateService) Do(ctx context.Context) (res KnowledgeCreateResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + body := M{ + "name": s.name, + "embedding_id": s.embeddingID, + } + if s.description != nil { + body["description"] = *s.description + } + if resp, err = s.client.request(ctx). + SetBody(body). + SetResult(&res). + SetError(&apiError). + Post("knowledge"); err != nil { + return + } + if resp.IsError() { + err = apiError + return + } + return +} + +// KnowledgeEditService edits a knowledge +type KnowledgeEditService struct { + client *Client + + knowledgeID string + + embeddingID *int + name *string + description *string +} + +// NewKnowledgeEditService creates a new KnowledgeEditService +func NewKnowledgeEditService(client *Client) *KnowledgeEditService { + return &KnowledgeEditService{ + client: client, + } +} + +// SetKnowledgeID sets the knowledge id +func (s *KnowledgeEditService) SetKnowledgeID(knowledgeID string) *KnowledgeEditService { + s.knowledgeID = knowledgeID + return s +} + +// SetName sets the name of the knowledge +func (s *KnowledgeEditService) SetName(name string) *KnowledgeEditService { + s.name = &name + return s +} + +// SetEmbeddingID sets the embedding id of the knowledge +func (s *KnowledgeEditService) SetEmbeddingID(embeddingID int) *KnowledgeEditService { + s.embeddingID = &embeddingID + return s +} + +// SetDescription sets the description of the knowledge +func (s *KnowledgeEditService) SetDescription(description string) *KnowledgeEditService { + s.description = &description + return s +} + +// Do edits the knowledge +func (s *KnowledgeEditService) Do(ctx context.Context) (err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + body := M{} + if s.name != nil { + body["name"] = *s.name + } + if s.description != nil { + body["description"] = *s.description + } + if s.embeddingID != nil { + body["embedding_id"] = *s.embeddingID + } + if resp, err = s.client.request(ctx). + SetPathParam("knowledge_id", s.knowledgeID). + SetBody(body). + SetError(&apiError). + Put("knowledge/{knowledge_id}"); err != nil { + return + } + if resp.IsError() { + err = apiError + return + } + return +} + +// KnowledgeListService lists the knowledge +type KnowledgeListService struct { + client *Client + + page *int + size *int +} + +// KnowledgeItem is an item in the knowledge list +type KnowledgeItem struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Icon string `json:"icon"` + Background string `json:"background"` + EmbeddingID int `json:"embedding_id"` + CustomIdentifier string `json:"custom_identifier"` + WordNum int64 `json:"word_num"` + Length int64 `json:"length"` + DocumentSize int64 `json:"document_size"` +} + +// KnowledgeListResponse is the response of the KnowledgeListService +type KnowledgeListResponse struct { + List []KnowledgeItem `json:"list"` + Total int `json:"total"` +} + +// NewKnowledgeListService creates a new KnowledgeListService +func NewKnowledgeListService(client *Client) *KnowledgeListService { + return &KnowledgeListService{client: client} +} + +// SetPage sets the page of the knowledge list +func (s *KnowledgeListService) SetPage(page int) *KnowledgeListService { + s.page = &page + return s +} + +// SetSize sets the size of the knowledge list +func (s *KnowledgeListService) SetSize(size int) *KnowledgeListService { + s.size = &size + return s +} + +// Do lists the knowledge +func (s *KnowledgeListService) Do(ctx context.Context) (res KnowledgeListResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + req := s.client.request(ctx) + if s.page != nil { + req.SetQueryParam("page", strconv.Itoa(*s.page)) + } + if s.size != nil { + req.SetQueryParam("size", strconv.Itoa(*s.size)) + } + if resp, err = req. + SetResult(&res). + SetError(&apiError). + Get("knowledge"); err != nil { + return + } + if resp.IsError() { + err = apiError + return + } + return +} + +// KnowledgeDeleteService deletes a knowledge +type KnowledgeDeleteService struct { + client *Client + + knowledgeID string +} + +// NewKnowledgeDeleteService creates a new KnowledgeDeleteService +func NewKnowledgeDeleteService(client *Client) *KnowledgeDeleteService { + return &KnowledgeDeleteService{ + client: client, + } +} + +// SetKnowledgeID sets the knowledge id +func (s *KnowledgeDeleteService) SetKnowledgeID(knowledgeID string) *KnowledgeDeleteService { + s.knowledgeID = knowledgeID + return s +} + +// Do deletes the knowledge +func (s *KnowledgeDeleteService) Do(ctx context.Context) (err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + if resp, err = s.client.request(ctx). + SetPathParam("knowledge_id", s.knowledgeID). + SetError(&apiError). + Delete("knowledge/{knowledge_id}"); err != nil { + return + } + if resp.IsError() { + err = apiError + return + } + return +} + +// KnowledgeCapacityService query the capacity of the knowledge +type KnowledgeCapacityService struct { + client *Client +} + +// KnowledgeCapacityItem is an item in the knowledge capacity +type KnowledgeCapacityItem struct { + WordNum int64 `json:"word_num"` + Length int64 `json:"length"` +} + +// KnowledgeCapacityResponse is the response of the KnowledgeCapacityService +type KnowledgeCapacityResponse struct { + Used KnowledgeCapacityItem `json:"used"` + Total KnowledgeCapacityItem `json:"total"` +} + +// SetKnowledgeID sets the knowledge id +func NewKnowledgeCapacityService(client *Client) *KnowledgeCapacityService { + return &KnowledgeCapacityService{client: client} +} + +// Do query the capacity of the knowledge +func (s *KnowledgeCapacityService) Do(ctx context.Context) (res KnowledgeCapacityResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + if resp, err = s.client.request(ctx). + SetResult(&res). + SetError(&apiError). + Get("knowledge/capacity"); err != nil { + return + } + if resp.IsError() { + err = apiError + return + } + return +} diff --git a/zhipu/knowledge_test.go b/zhipu/knowledge_test.go new file mode 100644 index 0000000..a330e74 --- /dev/null +++ b/zhipu/knowledge_test.go @@ -0,0 +1,50 @@ +package zhipu + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestKnowledgeCapacity(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.KnowledgeCapacity() + res, err := s.Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res.Total.Length) + require.NotEmpty(t, res.Total.WordNum) +} + +func TestKnowledgeServiceAll(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.KnowledgeCreate() + s.SetName("test") + s.SetDescription("test description") + s.SetEmbeddingID(KnowledgeEmbeddingIDEmbedding2) + + res, err := s.Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res.ID) + + s2 := client.KnowledgeList() + res2, err := s2.Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res2.List) + require.Equal(t, res.ID, res2.List[0].ID) + + s3 := client.KnowledgeEdit(res.ID) + s3.SetDescription("test description 2") + s3.SetName("test 2") + s3.SetEmbeddingID(KnowledgeEmbeddingIDEmbedding2) + err = s3.Do(context.Background()) + require.NoError(t, err) + + s4 := client.KnowledgeDelete(res.ID) + err = s4.Do(context.Background()) + require.NoError(t, err) +} diff --git a/zhipu/string_or.go b/zhipu/string_or.go new file mode 100644 index 0000000..f0a72e6 --- /dev/null +++ b/zhipu/string_or.go @@ -0,0 +1,54 @@ +package zhipu + +import ( + "bytes" + "encoding/json" +) + +// StringOr is a struct that can be either a string or a value of type T. +type StringOr[T any] struct { + String *string + Value *T +} + +var ( + _ json.Marshaler = StringOr[float64]{} + _ json.Unmarshaler = &StringOr[float64]{} +) + +// SetString sets the string value of the struct. +func (f *StringOr[T]) SetString(v string) { + f.String = &v + f.Value = nil +} + +// SetValue sets the value of the struct. +func (f *StringOr[T]) SetValue(v T) { + f.String = nil + f.Value = &v +} + +func (f StringOr[T]) MarshalJSON() ([]byte, error) { + if f.Value != nil { + return json.Marshal(f.Value) + } + return json.Marshal(f.String) +} + +func (f *StringOr[T]) UnmarshalJSON(data []byte) error { + if len(data) == 0 { + return nil + } + if bytes.Equal(data, []byte("null")) { + return nil + } + if data[0] == '"' { + f.String = new(string) + f.Value = nil + return json.Unmarshal(data, f.String) + } else { + f.Value = new(T) + f.String = nil + return json.Unmarshal(data, f.Value) + } +} diff --git a/zhipu/string_or_test.go b/zhipu/string_or_test.go new file mode 100644 index 0000000..cdc27da --- /dev/null +++ b/zhipu/string_or_test.go @@ -0,0 +1,37 @@ +package zhipu + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestStringOr(t *testing.T) { + data := struct { + Item *StringOr[int] `json:"item,omitempty"` + }{} + data.Item = &StringOr[int]{} + data.Item.SetString("test") + + b, err := json.Marshal(data) + require.NoError(t, err) + require.Equal(t, `{"item":"test"}`, string(b)) + + data.Item.SetValue(1) + b, err = json.Marshal(data) + require.NoError(t, err) + require.Equal(t, `{"item":1}`, string(b)) + + err = json.Unmarshal([]byte(`{"item":"test2"}`), &data) + require.NoError(t, err) + require.NotNil(t, data.Item.String) + require.Nil(t, data.Item.Value) + require.Equal(t, "test2", *data.Item.String) + + err = json.Unmarshal([]byte(`{"item":2}`), &data) + require.NoError(t, err) + require.Nil(t, data.Item.String) + require.NotNil(t, data.Item.Value) + require.Equal(t, 2, *data.Item.Value) +} diff --git a/zhipu/testdata/test-file.jsonl b/zhipu/testdata/test-file.jsonl new file mode 100644 index 0000000..4737cfa --- /dev/null +++ b/zhipu/testdata/test-file.jsonl @@ -0,0 +1,3 @@ +{"messages": [{"role": "system", "content": "你是一位乐于助人,知识渊博的全能AI助手。"},{"role": "user", "content": "请介绍一下AlphaGo的基本原理。"}, {"role": "assistant", "content": "AlphaGo利用深度神经网络与树搜索算法相结合,通过对弈数据进行训练,逐步提升游戏水平,能够对围棋局面进行判断并选择最佳落子位置。"}]} +{"messages": [{"role": "system", "content": "你是一位乐于助人,知识渊博的全能AI助手。"},{"role": "user", "content": "请列举一些健康饮食的重要原则。"}, {"role": "assistant", "content": "健康饮食的重要原则包括:膳食应营养均衡,合理摄入各种营养素;多吃新鲜蔬果,适量饮水;少吃糖、油、盐,忌吃过量快餐;养成规律生活作息习惯,适量运动。"}]} +{"messages": [{"role": "system", "content": "你是一位乐于助人,知识渊博的全能AI助手。"},{"role": "user", "content": "请给出一些改善睡眠质量的建议。"}, {"role": "assistant", "content": "改善睡眠质量的建议包括:保证睡眠时间,避免睡眠不足;睡前放松身心,远离刺激;养成规律作息时间,不要频繁熬夜;适量运动,但不要运动过度;睡前可以喝一杯热牛奶等温和饮料。"}]} \ No newline at end of file diff --git a/zhipu/testdata/test-file.txt b/zhipu/testdata/test-file.txt new file mode 100644 index 0000000..ff3bb63 --- /dev/null +++ b/zhipu/testdata/test-file.txt @@ -0,0 +1 @@ +The quick brown fox jumps over the lazy dog \ No newline at end of file diff --git a/zhipu/util.go b/zhipu/util.go new file mode 100644 index 0000000..4912b68 --- /dev/null +++ b/zhipu/util.go @@ -0,0 +1,22 @@ +package zhipu + +// URLItem is a struct that contains a URL. +type URLItem struct { + URL string `json:"url,omitempty"` +} + +// IDItem is a struct that contains an ID. +type IDItem struct { + ID string `json:"id,omitempty"` +} + +// Ptr returns a pointer to the value passed in. +// Example: +// +// web_search_enable = zhipu.Ptr(false) +func Ptr[T any](v T) *T { + return &v +} + +// M is a shorthand for map[string]any. +type M = map[string]any diff --git a/zhipu/util_test.go b/zhipu/util_test.go new file mode 100644 index 0000000..52b5052 --- /dev/null +++ b/zhipu/util_test.go @@ -0,0 +1,3 @@ +package zhipu + +// nothing to test diff --git a/zhipu/video_generation.go b/zhipu/video_generation.go new file mode 100644 index 0000000..3ae3279 --- /dev/null +++ b/zhipu/video_generation.go @@ -0,0 +1,125 @@ +package zhipu + +import ( + "context" + + "github.com/go-resty/resty/v2" +) + +const ( + VideoGenerationTaskStatusProcessing = "PROCESSING" + VideoGenerationTaskStatusSuccess = "SUCCESS" + VideoGenerationTaskStatusFail = "FAIL" +) + +// VideoGenerationService creates a new video generation +type VideoGenerationService struct { + client *Client + + model string + prompt string + userID string + imageURL string + requestID string +} + +var ( + _ BatchSupport = &VideoGenerationService{} +) + +// VideoGenerationResponse is the response of the VideoGenerationService +type VideoGenerationResponse struct { + RequestID string `json:"request_id"` + ID string `json:"id"` + Model string `json:"model"` + TaskStatus string `json:"task_status"` +} + +func NewVideoGenerationService(client *Client) *VideoGenerationService { + return &VideoGenerationService{ + client: client, + } +} + +func (s *VideoGenerationService) BatchMethod() string { + return "POST" +} + +func (s *VideoGenerationService) BatchURL() string { + return BatchEndpointV4VideosGenerations +} + +func (s *VideoGenerationService) BatchBody() any { + return s.buildBody() +} + +// SetModel sets the model parameter +func (s *VideoGenerationService) SetModel(model string) *VideoGenerationService { + s.model = model + return s +} + +// SetPrompt sets the prompt parameter +func (s *VideoGenerationService) SetPrompt(prompt string) *VideoGenerationService { + s.prompt = prompt + return s +} + +// SetUserID sets the userID parameter +func (s *VideoGenerationService) SetUserID(userID string) *VideoGenerationService { + s.userID = userID + return s +} + +// SetImageURL sets the imageURL parameter +func (s *VideoGenerationService) SetImageURL(imageURL string) *VideoGenerationService { + s.imageURL = imageURL + return s +} + +// SetRequestID sets the requestID parameter +func (s *VideoGenerationService) SetRequestID(requestID string) *VideoGenerationService { + s.requestID = requestID + return s +} + +func (s *VideoGenerationService) buildBody() M { + body := M{ + "model": s.model, + "prompt": s.prompt, + } + if s.userID != "" { + body["user_id"] = s.userID + } + if s.imageURL != "" { + body["image_url"] = s.imageURL + } + if s.requestID != "" { + body["request_id"] = s.requestID + } + return body +} + +func (s *VideoGenerationService) Do(ctx context.Context) (res VideoGenerationResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + body := s.buildBody() + + if resp, err = s.client.request(ctx). + SetBody(body). + SetResult(&res). + SetError(&apiError). + Post("videos/generations"); err != nil { + return + } + + if resp.IsError() { + err = apiError + return + } + + return +} diff --git a/zhipu/video_generation_test.go b/zhipu/video_generation_test.go new file mode 100644 index 0000000..8dc5ab4 --- /dev/null +++ b/zhipu/video_generation_test.go @@ -0,0 +1,38 @@ +package zhipu + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestVideoGeneration(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.VideoGeneration("cogvideox") + s.SetPrompt("一只可爱的小猫") + + res, err := s.Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res.TaskStatus) + require.NotEmpty(t, res.ID) + t.Log(res.ID) + + for { + res, err := client.AsyncResult(res.ID).Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res.TaskStatus) + if res.TaskStatus == VideoGenerationTaskStatusSuccess { + require.NotEmpty(t, res.VideoResult) + t.Log(res.VideoResult[0].URL) + t.Log(res.VideoResult[0].CoverImageURL) + } + if res.TaskStatus != VideoGenerationTaskStatusProcessing { + break + } + time.Sleep(time.Second * 5) + } +} diff --git a/zhipu/wechat-donation.png b/zhipu/wechat-donation.png new file mode 100644 index 0000000..f6519f6 Binary files /dev/null and b/zhipu/wechat-donation.png differ