This commit is contained in:
STARAI\Star 2024-09-07 23:14:12 +08:00
commit 29b0faf61b
41 changed files with 4712 additions and 0 deletions

9
LICENSE Normal file
View File

@ -0,0 +1,9 @@
MIT License
Copyright (c) 2024 apigo
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

83
aigc.go Normal file
View File

@ -0,0 +1,83 @@
package zhipu
import (
"apigo.cc/ai/zhipu/zhipu"
"context"
"errors"
"time"
)
func (ag *Agent) FastMakeImage(prompt, size, refImage string) ([]string, error) {
return ag.MakeImage(ModelCogView3Plus, prompt, size, refImage)
}
func (ag *Agent) BestMakeImage(prompt, size, refImage string) ([]string, error) {
return ag.MakeImage(ModelCogView3, prompt, size, refImage)
}
func (ag *Agent) MakeImage(model, prompt, size, refImage string) ([]string, error) {
c, err := zhipu.NewClient(zhipu.WithAPIKey(ag.config.ApiKey), zhipu.WithBaseURL(ag.config.Endpoint))
if err != nil {
return nil, err
}
cc := c.ImageGeneration(model).SetPrompt(prompt)
if size != "" {
cc.SetSize(size)
}
if r, err := cc.Do(context.Background()); err == nil {
results := make([]string, 0)
for _, item := range r.Data {
results = append(results, item.URL)
}
return results, nil
} else {
return nil, err
}
}
func (ag *Agent) FastMakeVideo(prompt, size, refImage string) ([]string, []string, error) {
return ag.MakeVideo(ModelCogVideoX, prompt, size, refImage)
}
func (ag *Agent) BestMakeVideo(prompt, size, refImage string) ([]string, []string, error) {
return ag.MakeVideo(ModelCogVideoX, prompt, size, refImage)
}
func (ag *Agent) MakeVideo(model, prompt, size, refImage string) ([]string, []string, error) {
c, err := zhipu.NewClient(zhipu.WithAPIKey(ag.config.ApiKey), zhipu.WithBaseURL(ag.config.Endpoint))
if err != nil {
return nil, nil, err
}
cc := c.VideoGeneration(model).SetPrompt(prompt)
if refImage != "" {
cc.SetImageURL(refImage)
}
if resp, err := cc.Do(context.Background()); err == nil {
for i := 0; i < 1200; i++ {
r, err := c.AsyncResult(resp.ID).Do(context.Background())
if err != nil {
return nil, nil, err
}
if r.TaskStatus == zhipu.VideoGenerationTaskStatusSuccess {
covers := make([]string, 0)
results := make([]string, 0)
for _, item := range r.VideoResult {
results = append(results, item.URL)
covers = append(covers, item.CoverImageURL)
}
return results, covers, nil
}
if r.TaskStatus == zhipu.VideoGenerationTaskStatusFail {
return nil, nil, errors.New("fail on task " + resp.ID)
}
time.Sleep(3 * time.Second)
}
return nil, nil, errors.New("timeout on task " + resp.ID)
} else {
return nil, nil, err
}
}

140
chat.go Normal file
View File

@ -0,0 +1,140 @@
package zhipu
import (
"apigo.cc/ai/agent"
"apigo.cc/ai/zhipu/zhipu"
"context"
"strings"
)
func (ag *Agent) FastAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) {
return ag.Ask(messages, &agent.ChatModelConfig{
Model: ModelGLM4Flash,
}, callback)
}
func (ag *Agent) LongAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) {
return ag.Ask(messages, &agent.ChatModelConfig{
Model: ModelGLM4Long,
}, callback)
}
func (ag *Agent) BatterAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) {
return ag.Ask(messages, &agent.ChatModelConfig{
Model: ModelGLM4Plus,
}, callback)
}
func (ag *Agent) BestAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) {
return ag.Ask(messages, &agent.ChatModelConfig{
Model: ModelGLM40520,
}, callback)
}
func (ag *Agent) MultiAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) {
return ag.Ask(messages, &agent.ChatModelConfig{
Model: ModelGLM4VPlus,
}, callback)
}
func (ag *Agent) BestMultiAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) {
return ag.Ask(messages, &agent.ChatModelConfig{
Model: ModelGLM4V,
}, callback)
}
func (ag *Agent) CodeInterpreterAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) {
return ag.Ask(messages, &agent.ChatModelConfig{
Model: ModelGLM4AllTools,
Tools: map[string]any{agent.ToolCodeInterpreter: nil},
}, callback)
}
func (ag *Agent) WebSearchAsk(messages []agent.ChatMessage, callback func(answer string)) (string, agent.TokenUsage, error) {
return ag.Ask(messages, &agent.ChatModelConfig{
Model: ModelGLM4AllTools,
Tools: map[string]any{agent.ToolWebSearch: nil},
}, callback)
}
func (ag *Agent) Ask(messages []agent.ChatMessage, config *agent.ChatModelConfig, callback func(answer string)) (string, agent.TokenUsage, error) {
if config == nil {
config = &agent.ChatModelConfig{}
}
config.SetDefault(&ag.config.DefaultChatModelConfig)
c, err := zhipu.NewClient(zhipu.WithAPIKey(ag.config.ApiKey), zhipu.WithBaseURL(ag.config.Endpoint))
if err != nil {
return "", agent.TokenUsage{}, err
}
cc := c.ChatCompletion(config.GetModel())
for _, msg := range messages {
var contents []zhipu.ChatCompletionMultiContent
if msg.Contents != nil {
contents = make([]zhipu.ChatCompletionMultiContent, len(msg.Contents))
for j, inPart := range msg.Contents {
part := zhipu.ChatCompletionMultiContent{}
part.Type = NameMap[inPart.Type]
switch inPart.Type {
case agent.TypeText:
part.Text = inPart.Content
case agent.TypeImage:
part.ImageURL = &zhipu.URLItem{URL: inPart.Content}
case agent.TypeVideo:
part.VideoURL = &zhipu.URLItem{URL: inPart.Content}
}
contents[j] = part
}
}
cc.AddMessage(zhipu.ChatCompletionMultiMessage{
Role: NameMap[msg.Role],
Content: contents,
})
}
for name := range config.GetTools() {
switch name {
case agent.ToolCodeInterpreter:
cc.AddTool(zhipu.ChatCompletionToolCodeInterpreter{})
case agent.ToolWebSearch:
cc.AddTool(zhipu.ChatCompletionToolWebBrowser{})
}
}
if config.GetMaxTokens() != 0 {
cc.SetMaxTokens(config.GetMaxTokens())
}
if config.GetTemperature() != 0 {
cc.SetTemperature(config.GetTemperature())
}
if config.GetTopP() != 0 {
cc.SetTopP(config.GetTopP())
}
if callback != nil {
cc.SetStreamHandler(func(r2 zhipu.ChatCompletionResponse) error {
if r2.Choices != nil {
for _, ch := range r2.Choices {
text := ch.Delta.Content
callback(text)
}
}
return nil
})
}
if r, err := cc.Do(context.Background()); err == nil {
results := make([]string, 0)
if r.Choices != nil {
for _, ch := range r.Choices {
results = append(results, ch.Message.Content)
}
}
return strings.Join(results, ""), agent.TokenUsage{
AskTokens: r.Usage.PromptTokens,
AnswerTokens: r.Usage.CompletionTokens,
TotalTokens: r.Usage.TotalTokens,
}, nil
} else {
return "", agent.TokenUsage{}, err
}
}

60
config.go Normal file
View File

@ -0,0 +1,60 @@
package zhipu
import (
"apigo.cc/ai/agent"
"apigo.cc/ai/zhipu/zhipu"
)
type Agent struct {
config agent.APIConfig
}
var NameMap = map[string]string{
agent.TypeText: zhipu.MultiContentTypeText,
agent.TypeImage: zhipu.MultiContentTypeImageURL,
agent.TypeVideo: zhipu.MultiContentTypeVideoURL,
agent.RoleSystem: zhipu.RoleSystem,
agent.RoleUser: zhipu.RoleUser,
agent.RoleAssistant: zhipu.RoleAssistant,
agent.RoleTool: zhipu.RoleTool,
}
const (
ModelGLM4Plus = "GLM-4-Plus"
ModelGLM40520 = "GLM-4-0520"
ModelGLM4Long = "GLM-4-Long"
ModelGLM4AirX = "GLM-4-AirX"
ModelGLM4Air = "GLM-4-Air"
ModelGLM4Flash = "GLM-4-Flash"
ModelGLM4AllTools = "GLM-4-AllTools"
ModelGLM4 = "GLM-4"
ModelGLM4VPlus = "GLM-4V-Plus"
ModelGLM4V = "GLM-4V"
ModelCogVideoX = "CogVideoX"
ModelCogView3Plus = "CogView-3-Plus"
ModelCogView3 = "CogView-3"
ModelEmbedding3 = "Embedding-3"
ModelEmbedding2 = "Embedding-2"
ModelCharGLM3 = "CharGLM-3"
ModelEmohaa = "Emohaa"
ModelCodeGeeX4 = "CodeGeeX-4"
)
func (ag *Agent) Support() agent.Support {
return agent.Support{
Ask: true,
AskWithImage: true,
AskWithVideo: true,
AskWithCodeInterpreter: true,
AskWithWebSearch: true,
MakeImage: true,
MakeVideo: true,
Models: []string{ModelGLM4Plus, ModelGLM40520, ModelGLM4Long, ModelGLM4AirX, ModelGLM4Air, ModelGLM4Flash, ModelGLM4AllTools, ModelGLM4, ModelGLM4VPlus, ModelGLM4V, ModelCogVideoX, ModelCogView3Plus, ModelCogView3, ModelEmbedding3, ModelEmbedding2, ModelCharGLM3, ModelEmohaa, ModelCodeGeeX4},
}
}
func init() {
agent.RegisterAgentMaker("zhipu", func(config agent.APIConfig) agent.Agent {
return &Agent{config: config}
})
}

17
go.mod Normal file
View File

@ -0,0 +1,17 @@
module apigo.cc/ai/zhipu
go 1.22
require (
apigo.cc/ai/agent v0.0.1
github.com/go-resty/resty/v2 v2.14.0
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/stretchr/testify v1.9.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/net v0.29.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

108
zhipu/CHANGELOG.md Normal file
View File

@ -0,0 +1,108 @@
# Changelog
All notable changes to this project will be documented in this file. See [conventional commits](https://www.conventionalcommits.org/) for commit guidelines.
- - -
## v0.1.2 - 2024-08-15
#### Bug Fixes
- add FinishReasonStopSequence - (01b4201) - GUO YANKE
#### Documentation
- update README.md [skip ci] - (e48a88b) - GUO YANKE
#### Features
- add videos/generations - (7261999) - GUO YANKE
#### Miscellaneous Chores
- relaxing go version to 1.18 - (6acc17c) - GUO YANKE
- - -
## v0.1.1 - 2024-07-17
#### Documentation
- update README.md [skip ci] - (695432a) - GUO YANKE
#### Features
- add support for GLM-4-AllTools - (9627a36) - GUO YANKE
- - -
## v0.1.0 - 2024-06-28
#### Bug Fixes
- rename client function for batch list - (40ac05f) - GUO YANKE
#### Documentation
- update README.md [skip ci] - (6ce5754) - GUO YANKE
#### Features
- add knowledge capacity service - (4ce62b3) - GUO YANKE
#### Refactoring
- update batch service - (b92d438) - GUO YANKE
- update chat completion service - (19dd77f) - GUO YANKE
- update embedding service - (c1bbc2d) - GUO YANKE
- update file services - (7ef4d87) - GUO YANKE
- update fine tune services, using APIError - (15aed88) - GUO YANKE
- update fine tune services - (664523b) - GUO YANKE
- update image generation service - (a18e028) - GUO YANKE
- update knowledge services - (c7bfb73) - GUO YANKE
- - -
## v0.0.6 - 2024-06-28
#### Features
- add batch support for result reader - (c062095) - GUO YANKE
- add fine tune services - (f172f51) - GUO YANKE
- add knowledge service - (09792b5) - GUO YANKE
- - -
## v0.0.5 - 2024-06-28
#### Bug Fixes
- api error parsing - (60a17f4) - GUO YANKE
#### Features
- add batch service - (389aec3) - GUO YANKE
- add batch support for chat completions, image generations and embeddings - (c017ffd) - GUO YANKE
- add file edit/get/delete service - (8a4d309) - GUO YANKE
- add file create serivce - (6d2140b) - GUO YANKE
- - -
## v0.0.4 - 2024-06-26
#### Bug Fixes
- remove Client.R(), hide resty for future removal - (dc2a4ca) - GUO YANKE
#### Features
- add meta support for charglm - (fdd20e7) - GUO YANKE
- add client option to custom http client - (c62d6a9) - GUO YANKE
- - -
## v0.0.3 - 2024-06-26
#### Features
- add image generation service - (9f3f54f) - GUO YANKE
- add support for vision models - (2dcd82a) - GUO YANKE
- add embedding service - (f57806a) - GUO YANKE
- - -
## v0.0.2 - 2024-06-26
#### Bug Fixes
- **(deps)** update golang-jwt/jwt to v5 - (2f76a57) - GUO YANKE
#### Features
- add constants for roles - (3d08a72) - GUO YANKE
- - -
## v0.0.1 - 2024-06-26
#### Bug Fixes
- add json tag "omitempty" to various types - (bf81097) - GUO YANKE
#### Continuous Integration
- add github action workflows for testing - (5a64987) - GUO YANKE
#### Documentation
- update README.md [skip ci] - (d504f57) - GUO YANKE
#### Features
- add chat completion in stream mode - (130fe1d) - GUO YANKE
- add chat completion in non-stream mode - (2326e37) - GUO YANKE
- support debug option while creating client - (0f104d8) - GUO YANKE
- add APIError and APIErrorResponse - (1886d85) - GUO YANKE
- add client struct - (710d8e8) - GUO YANKE
#### Refactoring
- change signature of Client#createJWT since there is no reason to fail - (f0d7887) - GUO YANKE
#### Tests
- add client_test.go - (a3fc217) - GUO YANKE
- - -
Changelog generated by [cocogitto](https://github.com/cocogitto/cocogitto).

21
zhipu/LICENSE Normal file
View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 Yanke G.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

280
zhipu/README.md Normal file
View File

@ -0,0 +1,280 @@
# zhipu
[![Go Reference](https://pkg.go.dev/badge/github.com/yankeguo/zhipu.svg)](https://pkg.go.dev/github.com/yankeguo/zhipu)
[![go](https://github.com/yankeguo/zhipu/actions/workflows/go.yml/badge.svg)](https://github.com/yankeguo/zhipu/actions/workflows/go.yml)
[中文文档](README.zh.md)
A 3rd-Party Golang Client Library for Zhipu AI Platform
## Usage
### Install the package
```bash
go get -u github.com/yankeguo/zhipu
```
### Create a client
```go
// this will use environment variables ZHIPUAI_API_KEY
client, err := zhipu.NewClient()
// or you can specify the API key
client, err = zhipu.NewClient(zhipu.WithAPIKey("your api key"))
```
### Use the client
**ChatCompletion**
```go
service := client.ChatCompletion("glm-4-flash").
AddMessage(zhipu.ChatCompletionMessage{
Role: "user",
Content: "你好",
})
res, err := service.Do(context.Background())
if err != nil {
zhipu.GetAPIErrorCode(err) // get the API error code
} else {
println(res.Choices[0].Message.Content)
}
```
**ChatCompletion (Stream)**
```go
service := client.ChatCompletion("glm-4-flash").
AddMessage(zhipu.ChatCompletionMessage{
Role: "user",
Content: "你好",
}).SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
println(chunk.Choices[0].Delta.Content)
return nil
})
res, err := service.Do(context.Background())
if err != nil {
zhipu.GetAPIErrorCode(err) // get the API error code
} else {
// this package will combine the stream chunks and build a final result mimicking the non-streaming API
println(res.Choices[0].Message.Content)
}
```
**ChatCompletion (Stream with GLM-4-AllTools)**
```go
// CodeInterpreter
s := client.ChatCompletion("GLM-4-AllTools")
s.AddMessage(zhipu.ChatCompletionMultiMessage{
Role: "user",
Content: []zhipu.ChatCompletionMultiContent{
{
Type: "text",
Text: "计算[5,10,20,700,99,310,978,100]的平均值和方差。",
},
},
})
s.AddTool(zhipu.ChatCompletionToolCodeInterpreter{
Sandbox: zhipu.Ptr(CodeInterpreterSandboxAuto),
})
s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
for _, c := range chunk.Choices {
for _, tc := range c.Delta.ToolCalls {
if tc.Type == ToolTypeCodeInterpreter && tc.CodeInterpreter != nil {
if tc.CodeInterpreter.Input != "" {
// DO SOMETHING
}
if len(tc.CodeInterpreter.Outputs) > 0 {
// DO SOMETHING
}
}
}
}
return nil
})
// WebBrowser
// CAUTION: NOT 'WebSearch'
s := client.ChatCompletion("GLM-4-AllTools")
s.AddMessage(zhipu.ChatCompletionMultiMessage{
Role: "user",
Content: []zhipu.ChatCompletionMultiContent{
{
Type: "text",
Text: "搜索下本周深圳天气如何",
},
},
})
s.AddTool(zhipu.ChatCompletionToolWebBrowser{})
s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
for _, c := range chunk.Choices {
for _, tc := range c.Delta.ToolCalls {
if tc.Type == ToolTypeWebBrowser && tc.WebBrowser != nil {
if tc.WebBrowser.Input != "" {
// DO SOMETHING
}
if len(tc.WebBrowser.Outputs) > 0 {
// DO SOMETHING
}
}
}
}
return nil
})
s.Do(context.Background())
// DrawingTool
s := client.ChatCompletion("GLM-4-AllTools")
s.AddMessage(zhipu.ChatCompletionMultiMessage{
Role: "user",
Content: []zhipu.ChatCompletionMultiContent{
{
Type: "text",
Text: "画一个正弦函数图像",
},
},
})
s.AddTool(zhipu.ChatCompletionToolDrawingTool{})
s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
for _, c := range chunk.Choices {
for _, tc := range c.Delta.ToolCalls {
if tc.Type == ToolTypeDrawingTool && tc.DrawingTool != nil {
if tc.DrawingTool.Input != "" {
// DO SOMETHING
}
if len(tc.DrawingTool.Outputs) > 0 {
// DO SOMETHING
}
}
}
}
return nil
})
s.Do(context.Background())
```
**Embedding**
```go
service := client.Embedding("embedding-v2").SetInput("你好呀")
service.Do(context.Background())
```
**Image Generation**
```go
service := client.ImageGeneration("cogview-3").SetPrompt("一只可爱的小猫咪")
service.Do(context.Background())
```
**Video Generation**
```go
service := client.VideoGeneration("cogvideox").SetPrompt("一只可爱的小猫咪")
resp, err := service.Do(context.Background())
for {
result, err := client.AsyncResult(resp.ID).Do(context.Background())
if result.TaskStatus == zhipu.VideoGenerationTaskStatusSuccess {
_ = result.VideoResult[0].URL
_ = result.VideoResult[0].CoverImageURL
break
}
if result.TaskStatus != zhipu.VideoGenerationTaskStatusProcessing {
break
}
time.Sleep(5 * time.Second)
}
```
**Upload File (Retrieval)**
```go
service := client.FileCreate(zhipu.FilePurposeRetrieval)
service.SetLocalFile(filepath.Join("testdata", "test-file.txt"))
service.SetKnowledgeID("your-knowledge-id")
service.Do(context.Background())
```
**Upload File (Fine-Tune)**
```go
service := client.FileCreate(zhipu.FilePurposeFineTune)
service.SetLocalFile(filepath.Join("testdata", "test-file.jsonl"))
service.Do(context.Background())
```
**Batch Create**
```go
service := client.BatchCreate().
SetInputFileID("fileid").
SetCompletionWindow(zhipu.BatchCompletionWindow24h).
SetEndpoint(BatchEndpointV4ChatCompletions)
service.Do(context.Background())
```
**Knowledge Base**
```go
client.KnowledgeCreate("")
client.KnowledgeEdit("")
```
**Fine Tune**
```go
client.FineTuneCreate("")
```
### Batch Support
**Batch File Writer**
```go
f, err := os.OpenFile("batch.jsonl", os.O_CREATE|os.O_WRONLY, 0644)
bw := zhipu.NewBatchFileWriter(f)
bw.Add("action_1", client.ChatCompletion("glm-4-flash").
AddMessage(zhipu.ChatCompletionMessage{
Role: "user",
Content: "你好",
}))
bw.Add("action_2", client.Embedding("embedding-v2").SetInput("你好呀"))
bw.Add("action_3", client.ImageGeneration("cogview-3").SetPrompt("一只可爱的小猫咪"))
```
**Batch Result Reader**
```go
br := zhipu.NewBatchResultReader[zhipu.ChatCompletionResponse](r)
for {
var res zhipu.BatchResult[zhipu.ChatCompletionResponse]
err := br.Read(&res)
if err != nil {
break
}
}
```
## Donation
Executing unit tests will actually call the ChatGLM API and consume my quota. Please donate and thank you for your support!
<img src="./wechat-donation.png" width="180"/>
## Credits
GUO YANKE, MIT License

278
zhipu/README.zh.md Normal file
View File

@ -0,0 +1,278 @@
# zhipu
[![Go Reference](https://pkg.go.dev/badge/github.com/yankeguo/zhipu.svg)](https://pkg.go.dev/github.com/yankeguo/zhipu)
[![go](https://github.com/yankeguo/zhipu/actions/workflows/go.yml/badge.svg)](https://github.com/yankeguo/zhipu/actions/workflows/go.yml)
Zhipu AI 平台第三方 Golang 客户端库
## 用法
### 安装库
```bash
go get -u github.com/yankeguo/zhipu
```
### 创建客户端
```go
// 默认使用环境变量 ZHIPUAI_API_KEY
client, err := zhipu.NewClient()
// 或者手动指定密钥
client, err = zhipu.NewClient(zhipu.WithAPIKey("your api key"))
```
### 使用客户端
**ChatCompletion(大语言模型)**
```go
service := client.ChatCompletion("glm-4-flash").
AddMessage(zhipu.ChatCompletionMessage{
Role: "user",
Content: "你好",
})
res, err := service.Do(context.Background())
if err != nil {
zhipu.GetAPIErrorCode(err) // get the API error code
} else {
println(res.Choices[0].Message.Content)
}
```
**ChatCompletion(流式调用大语言模型)**
```go
service := client.ChatCompletion("glm-4-flash").
AddMessage(zhipu.ChatCompletionMessage{
Role: "user",
Content: "你好",
}).SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
println(chunk.Choices[0].Delta.Content)
return nil
})
res, err := service.Do(context.Background())
if err != nil {
zhipu.GetAPIErrorCode(err) // get the API error code
} else {
// this package will combine the stream chunks and build a final result mimicking the non-streaming API
println(res.Choices[0].Message.Content)
}
```
**ChatCompletion(流式调用大语言工具模型GLM-4-AllTools)**
```go
// CodeInterpreter
s := client.ChatCompletion("GLM-4-AllTools")
s.AddMessage(zhipu.ChatCompletionMultiMessage{
Role: "user",
Content: []zhipu.ChatCompletionMultiContent{
{
Type: "text",
Text: "计算[5,10,20,700,99,310,978,100]的平均值和方差。",
},
},
})
s.AddTool(zhipu.ChatCompletionToolCodeInterpreter{
Sandbox: zhipu.Ptr(CodeInterpreterSandboxAuto),
})
s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
for _, c := range chunk.Choices {
for _, tc := range c.Delta.ToolCalls {
if tc.Type == ToolTypeCodeInterpreter && tc.CodeInterpreter != nil {
if tc.CodeInterpreter.Input != "" {
// DO SOMETHING
}
if len(tc.CodeInterpreter.Outputs) > 0 {
// DO SOMETHING
}
}
}
}
return nil
})
// WebBrowser
// CAUTION: NOT 'WebSearch'
s := client.ChatCompletion("GLM-4-AllTools")
s.AddMessage(zhipu.ChatCompletionMultiMessage{
Role: "user",
Content: []zhipu.ChatCompletionMultiContent{
{
Type: "text",
Text: "搜索下本周深圳天气如何",
},
},
})
s.AddTool(zhipu.ChatCompletionToolWebBrowser{})
s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
for _, c := range chunk.Choices {
for _, tc := range c.Delta.ToolCalls {
if tc.Type == ToolTypeWebBrowser && tc.WebBrowser != nil {
if tc.WebBrowser.Input != "" {
// DO SOMETHING
}
if len(tc.WebBrowser.Outputs) > 0 {
// DO SOMETHING
}
}
}
}
return nil
})
s.Do(context.Background())
// DrawingTool
s := client.ChatCompletion("GLM-4-AllTools")
s.AddMessage(zhipu.ChatCompletionMultiMessage{
Role: "user",
Content: []zhipu.ChatCompletionMultiContent{
{
Type: "text",
Text: "画一个正弦函数图像",
},
},
})
s.AddTool(zhipu.ChatCompletionToolDrawingTool{})
s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
for _, c := range chunk.Choices {
for _, tc := range c.Delta.ToolCalls {
if tc.Type == ToolTypeDrawingTool && tc.DrawingTool != nil {
if tc.DrawingTool.Input != "" {
// DO SOMETHING
}
if len(tc.DrawingTool.Outputs) > 0 {
// DO SOMETHING
}
}
}
}
return nil
})
s.Do(context.Background())
```
**Embedding**
```go
service := client.Embedding("embedding-v2").SetInput("你好呀")
service.Do(context.Background())
```
**ImageGeneration(图像生成)**
```go
service := client.ImageGeneration("cogview-3").SetPrompt("一只可爱的小猫咪")
service.Do(context.Background())
```
**VideoGeneration(视频生成)**
```go
service := client.VideoGeneration("cogvideox").SetPrompt("一只可爱的小猫咪")
resp, err := service.Do(context.Background())
for {
result, err := client.AsyncResult(resp.ID).Do(context.Background())
if result.TaskStatus == zhipu.VideoGenerationTaskStatusSuccess {
_ = result.VideoResult[0].URL
_ = result.VideoResult[0].CoverImageURL
break
}
if result.TaskStatus != zhipu.VideoGenerationTaskStatusProcessing {
break
}
time.Sleep(5 * time.Second)
}
```
**UploadFile(上传文件用于取回)**
```go
service := client.FileCreate(zhipu.FilePurposeRetrieval)
service.SetLocalFile(filepath.Join("testdata", "test-file.txt"))
service.SetKnowledgeID("your-knowledge-id")
service.Do(context.Background())
```
**UploadFile(上传文件用于微调)**
```go
service := client.FileCreate(zhipu.FilePurposeFineTune)
service.SetLocalFile(filepath.Join("testdata", "test-file.jsonl"))
service.Do(context.Background())
```
**BatchCreate(创建批量任务)**
```go
service := client.BatchCreate().
SetInputFileID("fileid").
SetCompletionWindow(zhipu.BatchCompletionWindow24h).
SetEndpoint(BatchEndpointV4ChatCompletions)
service.Do(context.Background())
```
**KnowledgeBase(知识库)**
```go
client.KnowledgeCreate("")
client.KnowledgeEdit("")
```
**FineTune(微调)**
```go
client.FineTuneCreate("")
```
### 批量任务辅助工具
**批量任务文件创建**
```go
f, err := os.OpenFile("batch.jsonl", os.O_CREATE|os.O_WRONLY, 0644)
bw := zhipu.NewBatchFileWriter(f)
bw.Add("action_1", client.ChatCompletion("glm-4-flash").
AddMessage(zhipu.ChatCompletionMessage{
Role: "user",
Content: "你好",
}))
bw.Add("action_2", client.Embedding("embedding-v2").SetInput("你好呀"))
bw.Add("action_3", client.ImageGeneration("cogview-3").SetPrompt("一只可爱的小猫咪"))
```
**批量任务结果解析**
```go
br := zhipu.NewBatchResultReader[zhipu.ChatCompletionResponse](r)
for {
var res zhipu.BatchResult[zhipu.ChatCompletionResponse]
err := br.Read(&res)
if err != nil {
break
}
}
```
## 赞助
执行单元测试会真实调用GLM接口消耗我充值的额度开发不易请微信扫码捐赠感谢您的支持
<img src="./wechat-donation.png" width="180"/>
## 许可证
GUO YANKE, MIT License

63
zhipu/async_result.go Normal file
View File

@ -0,0 +1,63 @@
package zhipu
import (
"context"
"github.com/go-resty/resty/v2"
)
// AsyncResultService creates a new async result get service
type AsyncResultService struct {
client *Client
id string
}
// AsyncResultVideo is the video result of the AsyncResultService
type AsyncResultVideo struct {
URL string `json:"url"`
CoverImageURL string `json:"cover_image_url"`
}
// AsyncResultResponse is the response of the AsyncResultService
type AsyncResultResponse struct {
Model string `json:"model"`
TaskStatus string `json:"task_status"`
RequestID string `json:"request_id"`
ID string `json:"id"`
VideoResult []AsyncResultVideo `json:"video_result"`
}
// NewAsyncResultService creates a new async result get service
func NewAsyncResultService(client *Client) *AsyncResultService {
return &AsyncResultService{
client: client,
}
}
// SetID sets the id parameter
func (s *AsyncResultService) SetID(id string) *AsyncResultService {
s.id = id
return s
}
func (s *AsyncResultService) Do(ctx context.Context) (res AsyncResultResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetResult(&res).
SetError(&apiError).
Get("async-result/" + s.id); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}

View File

@ -0,0 +1 @@
package zhipu

258
zhipu/batch.go Normal file
View File

@ -0,0 +1,258 @@
package zhipu
import (
"context"
"encoding/json"
"strconv"
"github.com/go-resty/resty/v2"
)
const (
BatchEndpointV4ChatCompletions = "/v4/chat/completions"
BatchEndpointV4ImagesGenerations = "/v4/images/generations"
BatchEndpointV4Embeddings = "/v4/embeddings"
BatchEndpointV4VideosGenerations = "/v4/videos/generations"
BatchCompletionWindow24h = "24h"
)
// BatchRequestCounts represents the counts of the batch requests.
type BatchRequestCounts struct {
Total int64 `json:"total"`
Completed int64 `json:"completed"`
Failed int64 `json:"failed"`
}
// BatchItem represents a batch item.
type BatchItem struct {
ID string `json:"id"`
Object any `json:"object"`
Endpoint string `json:"endpoint"`
InputFileID string `json:"input_file_id"`
CompletionWindow string `json:"completion_window"`
Status string `json:"status"`
OutputFileID string `json:"output_file_id"`
ErrorFileID string `json:"error_file_id"`
CreatedAt int64 `json:"created_at"`
InProgressAt int64 `json:"in_progress_at"`
ExpiresAt int64 `json:"expires_at"`
FinalizingAt int64 `json:"finalizing_at"`
CompletedAt int64 `json:"completed_at"`
FailedAt int64 `json:"failed_at"`
ExpiredAt int64 `json:"expired_at"`
CancellingAt int64 `json:"cancelling_at"`
CancelledAt int64 `json:"cancelled_at"`
RequestCounts BatchRequestCounts `json:"request_counts"`
Metadata json.RawMessage `json:"metadata"`
}
// BatchCreateService is a service to create a batch.
type BatchCreateService struct {
client *Client
inputFileID string
endpoint string
completionWindow string
metadata any
}
// NewBatchCreateService creates a new BatchCreateService.
func NewBatchCreateService(client *Client) *BatchCreateService {
return &BatchCreateService{client: client}
}
// SetInputFileID sets the input file id for the batch.
func (s *BatchCreateService) SetInputFileID(inputFileID string) *BatchCreateService {
s.inputFileID = inputFileID
return s
}
// SetEndpoint sets the endpoint for the batch.
func (s *BatchCreateService) SetEndpoint(endpoint string) *BatchCreateService {
s.endpoint = endpoint
return s
}
// SetCompletionWindow sets the completion window for the batch.
func (s *BatchCreateService) SetCompletionWindow(window string) *BatchCreateService {
s.completionWindow = window
return s
}
// SetMetadata sets the metadata for the batch.
func (s *BatchCreateService) SetMetadata(metadata any) *BatchCreateService {
s.metadata = metadata
return s
}
// Do executes the batch create service.
func (s *BatchCreateService) Do(ctx context.Context) (res BatchItem, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetBody(M{
"input_file_id": s.inputFileID,
"endpoint": s.endpoint,
"completion_window": s.completionWindow,
"metadata": s.metadata,
}).
SetResult(&res).
SetError(&apiError).
Post("batches"); err != nil {
return
}
if resp.IsError() {
err = apiError
}
return
}
// BatchGetService is a service to get a batch.
type BatchGetService struct {
client *Client
batchID string
}
// BatchGetResponse represents the response of the batch get service.
type BatchGetResponse = BatchItem
// NewBatchGetService creates a new BatchGetService.
func NewBatchGetService(client *Client) *BatchGetService {
return &BatchGetService{client: client}
}
// SetBatchID sets the batch id for the batch get service.
func (s *BatchGetService) SetBatchID(batchID string) *BatchGetService {
s.batchID = batchID
return s
}
// Do executes the batch get service.
func (s *BatchGetService) Do(ctx context.Context) (res BatchGetResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetPathParam("batch_id", s.batchID).
SetResult(&res).
SetError(&apiError).
Get("batches/{batch_id}"); err != nil {
return
}
if resp.IsError() {
err = apiError
}
return
}
// BatchCancelService is a service to cancel a batch.
type BatchCancelService struct {
client *Client
batchID string
}
// NewBatchCancelService creates a new BatchCancelService.
func NewBatchCancelService(client *Client) *BatchCancelService {
return &BatchCancelService{client: client}
}
// SetBatchID sets the batch id for the batch cancel service.
func (s *BatchCancelService) SetBatchID(batchID string) *BatchCancelService {
s.batchID = batchID
return s
}
// Do executes the batch cancel service.
func (s *BatchCancelService) Do(ctx context.Context) (err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetPathParam("batch_id", s.batchID).
SetBody(M{}).
SetError(&apiError).
Post("batches/{batch_id}/cancel"); err != nil {
return
}
if resp.IsError() {
err = apiError
}
return
}
// BatchListService is a service to list batches.
type BatchListService struct {
client *Client
after *string
limit *int
}
// BatchListResponse represents the response of the batch list service.
type BatchListResponse struct {
Object string `json:"object"`
Data []BatchItem `json:"data"`
FirstID string `json:"first_id"`
LastID string `json:"last_id"`
HasMore bool `json:"has_more"`
}
// NewBatchListService creates a new BatchListService.
func NewBatchListService(client *Client) *BatchListService {
return &BatchListService{client: client}
}
// SetAfter sets the after cursor for the batch list service.
func (s *BatchListService) SetAfter(after string) *BatchListService {
s.after = &after
return s
}
// SetLimit sets the limit for the batch list service.
func (s *BatchListService) SetLimit(limit int) *BatchListService {
s.limit = &limit
return s
}
// Do executes the batch list service.
func (s *BatchListService) Do(ctx context.Context) (res BatchListResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
req := s.client.request(ctx)
if s.after != nil {
req.SetQueryParam("after", *s.after)
}
if s.limit != nil {
req.SetQueryParam("limit", strconv.Itoa(*s.limit))
}
if resp, err = req.
SetResult(&res).
SetError(&apiError).
Get("batches"); err != nil {
return
}
if resp.IsError() {
err = apiError
}
return
}

63
zhipu/batch_support.go Normal file
View File

@ -0,0 +1,63 @@
package zhipu
import (
"encoding/json"
"io"
)
// BatchSupport is the interface for services with batch support.
type BatchSupport interface {
BatchMethod() string
BatchURL() string
BatchBody() any
}
// BatchFileWriter is a writer for batch files.
type BatchFileWriter struct {
w io.Writer
je *json.Encoder
}
// NewBatchFileWriter creates a new BatchFileWriter.
func NewBatchFileWriter(w io.Writer) *BatchFileWriter {
return &BatchFileWriter{w: w, je: json.NewEncoder(w)}
}
// Write writes a batch file.
func (b *BatchFileWriter) Write(customID string, s BatchSupport) error {
return b.je.Encode(M{
"custom_id": customID,
"method": s.BatchMethod(),
"url": s.BatchURL(),
"body": s.BatchBody(),
})
}
// BatchResultResponse is the response of a batch result.
type BatchResultResponse[T any] struct {
StatusCode int `json:"status_code"`
Body T `json:"body"`
}
// BatchResult is the result of a batch.
type BatchResult[T any] struct {
ID string `json:"id"`
CustomID string `json:"custom_id"`
Response BatchResultResponse[T] `json:"response"`
}
// BatchResultReader reads batch results.
type BatchResultReader[T any] struct {
r io.Reader
jd *json.Decoder
}
// NewBatchResultReader creates a new BatchResultReader.
func NewBatchResultReader[T any](r io.Reader) *BatchResultReader[T] {
return &BatchResultReader[T]{r: r, jd: json.NewDecoder(r)}
}
// Read reads a batch result.
func (r *BatchResultReader[T]) Read(out *BatchResult[T]) error {
return r.jd.Decode(out)
}

View File

@ -0,0 +1,73 @@
package zhipu
import (
"bytes"
"io"
"testing"
"github.com/stretchr/testify/require"
)
func TestBatchFileWriter(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
buf := &bytes.Buffer{}
w := NewBatchFileWriter(buf)
err = w.Write("batch-1", client.ChatCompletion("a").AddMessage(ChatCompletionMessage{
Role: "user", Content: "hello",
}))
require.NoError(t, err)
err = w.Write("batch-2", client.Embedding("c").SetInput("whoa"))
require.NoError(t, err)
err = w.Write("batch-3", client.ImageGeneration("d").SetPrompt("whoa"))
require.NoError(t, err)
require.Equal(t, `{"body":{"messages":[{"role":"user","content":"hello"}],"model":"a"},"custom_id":"batch-1","method":"POST","url":"/v4/chat/completions"}
{"body":{"input":"whoa","model":"c"},"custom_id":"batch-2","method":"POST","url":"/v4/embeddings"}
{"body":{"model":"d","prompt":"whoa"},"custom_id":"batch-3","method":"POST","url":"/v4/images/generations"}
`, buf.String())
}
func TestBatchResultReader(t *testing.T) {
result := `
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":26,"prompt_tokens":89,"total_tokens":115},"model":"glm-4","id":"8668357533850320547","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"订单处理慢\"\n}\n'''"}}],"request_id":"615-request-1"}},"custom_id":"request-1","id":"batch_1791490810192076800"}
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":22,"prompt_tokens":94,"total_tokens":116},"model":"glm-4","id":"8668368425887509080","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"产品缺陷\"\n}\n'''"}}],"request_id":"616-request-2"}},"custom_id":"request-2","id":"batch_1791490810192076800"}
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":25,"prompt_tokens":86,"total_tokens":111},"model":"glm-4","id":"8668355815863214980","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"正面\",\n \"特定问题标注\": \"性价比\"\n}\n'''"}}],"request_id":"617-request-3"}},"custom_id":"request-3","id":"batch_1791490810192076800"}
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":28,"prompt_tokens":89,"total_tokens":117},"model":"glm-4","id":"8668355815863214981","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"说明文档不清晰\"\n}\n'''"}}],"request_id":"618-request-4"}},"custom_id":"request-4","id":"batch_1791490810192076800"}
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":26,"prompt_tokens":88,"total_tokens":114},"model":"glm-4","id":"8668357533850320546","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"中性\",\n \"特定问题标注\": \"价格问题\"\n}\n'''"}}],"request_id":"619-request-5"}},"custom_id":"request-5","id":"batch_1791490810192076800"}
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":26,"prompt_tokens":90,"total_tokens":116},"model":"glm-4","id":"8668356159460662846","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"配送延迟\"\n}\n'''"}}],"request_id":"620-request-6"}},"custom_id":"request-6","id":"batch_1791490810192076800"}
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":27,"prompt_tokens":88,"total_tokens":115},"model":"glm-4","id":"8668357671289274638","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"产品描述不符\"\n}\n'''"}}],"request_id":"621-request-7"}},"custom_id":"request-7","id":"batch_1791490810192076800"}
{"response":{"status_code":200,"body":{"created":1715959702,"usage":{"completion_tokens":26,"prompt_tokens":87,"total_tokens":113},"model":"glm-4","id":"8668355644064514872","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"正面\",\n \"特定问题标注\": \"客服态度\"\n}\n'''"}}],"request_id":"622-request-8"}},"custom_id":"request-8","id":"batch_1791490810192076800"}
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":29,"prompt_tokens":90,"total_tokens":119},"model":"glm-4","id":"8668357671289274639","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"包装问题, 产品损坏\"\n}\n'''"}}],"request_id":"623-request-9"}},"custom_id":"request-9","id":"batch_1791490810192076800"}
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":27,"prompt_tokens":87,"total_tokens":114},"model":"glm-4","id":"8668355644064514871","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"正面\",\n \"特定问题标注\": \"产品描述不符\"\n}\n'''"}}],"request_id":"624-request-10"}},"custom_id":"request-10","id":"batch_1791490810192076800"}
`
brr := NewBatchResultReader[ChatCompletionResponse](bytes.NewReader([]byte(result)))
var count int
for {
var res BatchResult[ChatCompletionResponse]
err := brr.Read(&res)
if err != nil {
if err == io.EOF {
err = nil
}
require.Equal(t, 10, count)
require.NoError(t, err)
break
}
require.Equal(t, 200, res.Response.StatusCode)
count++
}
}

59
zhipu/batch_test.go Normal file
View File

@ -0,0 +1,59 @@
package zhipu
import (
"bytes"
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestBatchServiceAll(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
buf := &bytes.Buffer{}
bfw := NewBatchFileWriter(buf)
err = bfw.Write("batch_1", client.ChatCompletion("glm-4-flash").AddMessage(ChatCompletionMessage{
Role: RoleUser, Content: "你好呀",
}))
require.NoError(t, err)
err = bfw.Write("batch_2", client.ChatCompletion("glm-4-flash").AddMessage(ChatCompletionMessage{
Role: RoleUser, Content: "你叫什么名字",
}))
require.NoError(t, err)
res, err := client.FileCreate(FilePurposeBatch).SetFile(bytes.NewReader(buf.Bytes()), "batch.jsonl").Do(context.Background())
require.NoError(t, err)
fileID := res.FileCreateFineTuneResponse.ID
require.NotEmpty(t, fileID)
res1, err := client.BatchCreate().
SetInputFileID(fileID).
SetCompletionWindow(BatchCompletionWindow24h).
SetEndpoint(BatchEndpointV4ChatCompletions).Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res1.ID)
res2, err := client.BatchGet(res1.ID).Do(context.Background())
require.NoError(t, err)
require.Equal(t, res2.ID, res1.ID)
res3, err := client.BatchList().Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res3.Data)
err = client.BatchCancel(res1.ID).Do(context.Background())
require.NoError(t, err)
}
func TestBatchListService(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
res, err := client.BatchList().Do(context.Background())
require.NoError(t, err)
t.Log(res)
}

577
zhipu/chat_completion.go Normal file
View File

@ -0,0 +1,577 @@
package zhipu
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"io"
"github.com/go-resty/resty/v2"
)
const (
RoleSystem = "system"
RoleUser = "user"
RoleAssistant = "assistant"
RoleTool = "tool"
ToolChoiceAuto = "auto"
FinishReasonStop = "stop"
FinishReasonStopSequence = "stop_sequence"
FinishReasonToolCalls = "tool_calls"
FinishReasonLength = "length"
FinishReasonSensitive = "sensitive"
FinishReasonNetworkError = "network_error"
ToolTypeFunction = "function"
ToolTypeWebSearch = "web_search"
ToolTypeRetrieval = "retrieval"
MultiContentTypeText = "text"
MultiContentTypeImageURL = "image_url"
MultiContentTypeVideoURL = "video_url"
// New in GLM-4-AllTools
ToolTypeCodeInterpreter = "code_interpreter"
ToolTypeDrawingTool = "drawing_tool"
ToolTypeWebBrowser = "web_browser"
CodeInterpreterSandboxNone = "none"
CodeInterpreterSandboxAuto = "auto"
ChatCompletionStatusFailed = "failed"
ChatCompletionStatusCompleted = "completed"
ChatCompletionStatusRequiresAction = "requires_action"
)
// ChatCompletionTool is the interface for chat completion tool
type ChatCompletionTool interface {
isChatCompletionTool()
}
// ChatCompletionToolFunction is the function for chat completion
type ChatCompletionToolFunction struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters any `json:"parameters"`
}
func (ChatCompletionToolFunction) isChatCompletionTool() {}
// ChatCompletionToolRetrieval is the retrieval for chat completion
type ChatCompletionToolRetrieval struct {
KnowledgeID string `json:"knowledge_id"`
PromptTemplate string `json:"prompt_template,omitempty"`
}
func (ChatCompletionToolRetrieval) isChatCompletionTool() {}
// ChatCompletionToolWebSearch is the web search for chat completion
type ChatCompletionToolWebSearch struct {
Enable *bool `json:"enable,omitempty"`
SearchQuery string `json:"search_query,omitempty"`
SearchResult bool `json:"search_result,omitempty"`
}
func (ChatCompletionToolWebSearch) isChatCompletionTool() {}
// ChatCompletionToolCodeInterpreter is the code interpreter for chat completion
// only in GLM-4-AllTools
type ChatCompletionToolCodeInterpreter struct {
Sandbox *string `json:"sandbox,omitempty"`
}
func (ChatCompletionToolCodeInterpreter) isChatCompletionTool() {}
// ChatCompletionToolDrawingTool is the drawing tool for chat completion
// only in GLM-4-AllTools
type ChatCompletionToolDrawingTool struct {
// no fields
}
func (ChatCompletionToolDrawingTool) isChatCompletionTool() {}
// ChatCompletionToolWebBrowser is the web browser for chat completion
type ChatCompletionToolWebBrowser struct {
// no fields
}
func (ChatCompletionToolWebBrowser) isChatCompletionTool() {}
// ChatCompletionUsage is the usage for chat completion
type ChatCompletionUsage struct {
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
TotalTokens int64 `json:"total_tokens"`
}
// ChatCompletionWebSearch is the web search result for chat completion
type ChatCompletionWebSearch struct {
Icon string `json:"icon"`
Title string `json:"title"`
Link string `json:"link"`
Media string `json:"media"`
Content string `json:"content"`
}
// ChatCompletionToolCallFunction is the function for chat completion tool call
type ChatCompletionToolCallFunction struct {
Name string `json:"name"`
Arguments json.RawMessage `json:"arguments"`
}
// ChatCompletionToolCallCodeInterpreterOutput is the output for chat completion tool call code interpreter
type ChatCompletionToolCallCodeInterpreterOutput struct {
Type string `json:"type"`
Logs string `json:"logs"`
File string `json:"file"`
}
// ChatCompletionToolCallCodeInterpreter is the code interpreter for chat completion tool call
type ChatCompletionToolCallCodeInterpreter struct {
Input string `json:"input"`
Outputs []ChatCompletionToolCallCodeInterpreterOutput `json:"outputs"`
}
// ChatCompletionToolCallDrawingToolOutput is the output for chat completion tool call drawing tool
type ChatCompletionToolCallDrawingToolOutput struct {
Image string `json:"image"`
}
// ChatCompletionToolCallDrawingTool is the drawing tool for chat completion tool call
type ChatCompletionToolCallDrawingTool struct {
Input string `json:"input"`
Outputs []ChatCompletionToolCallDrawingToolOutput `json:"outputs"`
}
// ChatCompletionToolCallWebBrowserOutput is the output for chat completion tool call web browser
type ChatCompletionToolCallWebBrowserOutput struct {
Title string `json:"title"`
Link string `json:"link"`
Content string `json:"content"`
}
// ChatCompletionToolCallWebBrowser is the web browser for chat completion tool call
type ChatCompletionToolCallWebBrowser struct {
Input string `json:"input"`
Outputs []ChatCompletionToolCallWebBrowserOutput `json:"outputs"`
}
// ChatCompletionToolCall is the tool call for chat completion
type ChatCompletionToolCall struct {
ID string `json:"id"`
Type string `json:"type"`
Function *ChatCompletionToolCallFunction `json:"function,omitempty"`
CodeInterpreter *ChatCompletionToolCallCodeInterpreter `json:"code_interpreter,omitempty"`
DrawingTool *ChatCompletionToolCallDrawingTool `json:"drawing_tool,omitempty"`
WebBrowser *ChatCompletionToolCallWebBrowser `json:"web_browser,omitempty"`
}
type ChatCompletionMessageType interface {
isChatCompletionMessageType()
}
// ChatCompletionMessage is the message for chat completion
type ChatCompletionMessage struct {
Role string `json:"role"`
Content string `json:"content,omitempty"`
ToolCalls []ChatCompletionToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
func (ChatCompletionMessage) isChatCompletionMessageType() {}
type ChatCompletionMultiContent struct {
Type string `json:"type"`
Text string `json:"text"`
ImageURL *URLItem `json:"image_url,omitempty"`
VideoURL *URLItem `json:"video_url,omitempty"`
}
// ChatCompletionMultiMessage is the multi message for chat completion
type ChatCompletionMultiMessage struct {
Role string `json:"role"`
Content []ChatCompletionMultiContent `json:"content"`
}
func (ChatCompletionMultiMessage) isChatCompletionMessageType() {}
// ChatCompletionMeta is the meta for chat completion
type ChatCompletionMeta struct {
UserInfo string `json:"user_info"`
BotInfo string `json:"bot_info"`
UserName string `json:"user_name"`
BotName string `json:"bot_name"`
}
// ChatCompletionChoice is the choice for chat completion
type ChatCompletionChoice struct {
Index int `json:"index"`
FinishReason string `json:"finish_reason"`
Delta ChatCompletionMessage `json:"delta"` // stream mode
Message ChatCompletionMessage `json:"message"` // non-stream mode
}
// ChatCompletionResponse is the response for chat completion
type ChatCompletionResponse struct {
ID string `json:"id"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatCompletionChoice `json:"choices"`
Usage ChatCompletionUsage `json:"usage"`
WebSearch []ChatCompletionWebSearch `json:"web_search"`
// Status is the status of the chat completion, only in GLM-4-AllTools
Status string `json:"status"`
}
// ChatCompletionStreamHandler is the handler for chat completion stream
type ChatCompletionStreamHandler func(chunk ChatCompletionResponse) error
var (
chatCompletionStreamPrefix = []byte("data:")
chatCompletionStreamDone = []byte("[DONE]")
)
// chatCompletionReduceResponse reduce the chunk to the response
func chatCompletionReduceResponse(out *ChatCompletionResponse, chunk ChatCompletionResponse) {
if len(out.Choices) == 0 {
out.Choices = append(out.Choices, ChatCompletionChoice{})
}
// basic
out.ID = chunk.ID
out.Created = chunk.Created
out.Model = chunk.Model
// choices
if len(chunk.Choices) != 0 {
oc := &out.Choices[0]
cc := chunk.Choices[0]
oc.Index = cc.Index
if cc.Delta.Role != "" {
oc.Message.Role = cc.Delta.Role
}
oc.Message.Content += cc.Delta.Content
oc.Message.ToolCalls = append(oc.Message.ToolCalls, cc.Delta.ToolCalls...)
if cc.FinishReason != "" {
oc.FinishReason = cc.FinishReason
}
}
// usage
if chunk.Usage.CompletionTokens != 0 {
out.Usage.CompletionTokens = chunk.Usage.CompletionTokens
}
if chunk.Usage.PromptTokens != 0 {
out.Usage.PromptTokens = chunk.Usage.PromptTokens
}
if chunk.Usage.TotalTokens != 0 {
out.Usage.TotalTokens = chunk.Usage.TotalTokens
}
// web search
out.WebSearch = append(out.WebSearch, chunk.WebSearch...)
}
// chatCompletionDecodeStream decode the sse stream of chat completion
func chatCompletionDecodeStream(r io.Reader, fn func(chunk ChatCompletionResponse) error) (err error) {
br := bufio.NewReader(r)
for {
var line []byte
if line, err = br.ReadBytes('\n'); err != nil {
if errors.Is(err, io.EOF) {
err = nil
}
break
}
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
if !bytes.HasPrefix(line, chatCompletionStreamPrefix) {
continue
}
data := bytes.TrimSpace(line[len(chatCompletionStreamPrefix):])
if bytes.Equal(data, chatCompletionStreamDone) {
break
}
if len(data) == 0 {
continue
}
var chunk ChatCompletionResponse
if err = json.Unmarshal(data, &chunk); err != nil {
return
}
if err = fn(chunk); err != nil {
return
}
}
return
}
// ChatCompletionStreamService is the service for chat completion stream
type ChatCompletionService struct {
client *Client
model string
requestID *string
doSample *bool
temperature *float64
topP *float64
maxTokens *int
stop []string
toolChoice *string
userID *string
meta *ChatCompletionMeta
messages []any
tools []any
streamHandler ChatCompletionStreamHandler
}
var (
_ BatchSupport = &ChatCompletionService{}
)
// NewChatCompletionService creates a new ChatCompletionService.
func NewChatCompletionService(client *Client) *ChatCompletionService {
return &ChatCompletionService{
client: client,
}
}
func (s *ChatCompletionService) BatchMethod() string {
return "POST"
}
func (s *ChatCompletionService) BatchURL() string {
return BatchEndpointV4ChatCompletions
}
func (s *ChatCompletionService) BatchBody() any {
return s.buildBody()
}
// SetModel set the model of the chat completion
func (s *ChatCompletionService) SetModel(model string) *ChatCompletionService {
s.model = model
return s
}
// SetMeta set the meta of the chat completion, optional
func (s *ChatCompletionService) SetMeta(meta ChatCompletionMeta) *ChatCompletionService {
s.meta = &meta
return s
}
// SetRequestID set the request id of the chat completion, optional
func (s *ChatCompletionService) SetRequestID(requestID string) *ChatCompletionService {
s.requestID = &requestID
return s
}
// SetTemperature set the temperature of the chat completion, optional
func (s *ChatCompletionService) SetDoSample(doSample bool) *ChatCompletionService {
s.doSample = &doSample
return s
}
// SetTemperature set the temperature of the chat completion, optional
func (s *ChatCompletionService) SetTemperature(temperature float64) *ChatCompletionService {
s.temperature = &temperature
return s
}
// SetTopP set the top p of the chat completion, optional
func (s *ChatCompletionService) SetTopP(topP float64) *ChatCompletionService {
s.topP = &topP
return s
}
// SetMaxTokens set the max tokens of the chat completion, optional
func (s *ChatCompletionService) SetMaxTokens(maxTokens int) *ChatCompletionService {
s.maxTokens = &maxTokens
return s
}
// SetStop set the stop of the chat completion, optional
func (s *ChatCompletionService) SetStop(stop ...string) *ChatCompletionService {
s.stop = stop
return s
}
// SetToolChoice set the tool choice of the chat completion, optional
func (s *ChatCompletionService) SetToolChoice(toolChoice string) *ChatCompletionService {
s.toolChoice = &toolChoice
return s
}
// SetUserID set the user id of the chat completion, optional
func (s *ChatCompletionService) SetUserID(userID string) *ChatCompletionService {
s.userID = &userID
return s
}
// SetStreamHandler set the stream handler of the chat completion, optional
// this will enable the stream mode
func (s *ChatCompletionService) SetStreamHandler(handler ChatCompletionStreamHandler) *ChatCompletionService {
s.streamHandler = handler
return s
}
// AddMessage add the message to the chat completion
func (s *ChatCompletionService) AddMessage(messages ...ChatCompletionMessageType) *ChatCompletionService {
for _, message := range messages {
s.messages = append(s.messages, message)
}
return s
}
// AddFunction add the function to the chat completion
func (s *ChatCompletionService) AddTool(tools ...ChatCompletionTool) *ChatCompletionService {
for _, tool := range tools {
switch tool := tool.(type) {
case ChatCompletionToolFunction:
s.tools = append(s.tools, map[string]any{
"type": ToolTypeFunction,
ToolTypeFunction: tool,
})
case ChatCompletionToolRetrieval:
s.tools = append(s.tools, map[string]any{
"type": ToolTypeRetrieval,
ToolTypeRetrieval: tool,
})
case ChatCompletionToolWebSearch:
s.tools = append(s.tools, map[string]any{
"type": ToolTypeWebSearch,
ToolTypeWebSearch: tool,
})
case ChatCompletionToolCodeInterpreter:
s.tools = append(s.tools, map[string]any{
"type": ToolTypeCodeInterpreter,
ToolTypeCodeInterpreter: tool,
})
case ChatCompletionToolDrawingTool:
s.tools = append(s.tools, map[string]any{
"type": ToolTypeDrawingTool,
ToolTypeDrawingTool: tool,
})
case ChatCompletionToolWebBrowser:
s.tools = append(s.tools, map[string]any{
"type": ToolTypeWebBrowser,
ToolTypeWebBrowser: tool,
})
}
}
return s
}
func (s *ChatCompletionService) buildBody() M {
body := map[string]any{
"model": s.model,
"messages": s.messages,
}
if s.requestID != nil {
body["request_id"] = *s.requestID
}
if s.doSample != nil {
body["do_sample"] = *s.doSample
}
if s.temperature != nil {
body["temperature"] = *s.temperature
}
if s.topP != nil {
body["top_p"] = *s.topP
}
if s.maxTokens != nil {
body["max_tokens"] = *s.maxTokens
}
if len(s.stop) != 0 {
body["stop"] = s.stop
}
if len(s.tools) != 0 {
body["tools"] = s.tools
}
if s.toolChoice != nil {
body["tool_choice"] = *s.toolChoice
}
if s.userID != nil {
body["user_id"] = *s.userID
}
if s.meta != nil {
body["meta"] = s.meta
}
return body
}
// Do send the request of the chat completion and return the response
func (s *ChatCompletionService) Do(ctx context.Context) (res ChatCompletionResponse, err error) {
body := s.buildBody()
streamHandler := s.streamHandler
if streamHandler == nil {
var (
resp *resty.Response
apiError APIErrorResponse
)
//fmt.Println(u.BMagenta(u.JsonP(body)), 111)
if resp, err = s.client.request(ctx).SetBody(body).SetResult(&res).SetError(&apiError).Post("chat/completions"); err != nil {
//fmt.Println(u.BRed(err.Error()), 2221)
return
}
if resp.IsError() {
err = apiError
//fmt.Println(u.BRed(err.Error()), 2222)
return
}
//fmt.Println(u.BGreen(u.JsonP(resp.Result())), resp.Status(), resp.Status(), 333)
return
}
// stream mode
body["stream"] = true
var resp *resty.Response
if resp, err = s.client.request(ctx).SetBody(body).SetDoNotParseResponse(true).Post("chat/completions"); err != nil {
return
}
defer resp.RawBody().Close()
if resp.IsError() {
err = errors.New(resp.Status())
return
}
var choice ChatCompletionChoice
if err = chatCompletionDecodeStream(resp.RawBody(), func(chunk ChatCompletionResponse) error {
// reduce the chunk to the response
chatCompletionReduceResponse(&res, chunk)
// invoke the stream handler
return streamHandler(chunk)
}); err != nil {
return
}
res.Choices = append(res.Choices, choice)
return
}

View File

@ -0,0 +1,251 @@
package zhipu
import (
"context"
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestChatCompletionService(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.ChatCompletion("glm-4-flash")
s.AddMessage(ChatCompletionMessage{
Role: RoleUser,
Content: "你好呀",
})
res, err := s.Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res.Choices)
choice := res.Choices[0]
require.Equal(t, FinishReasonStop, choice.FinishReason)
require.NotEmpty(t, choice.Message.Content)
}
func TestChatCompletionServiceCharGLM(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.ChatCompletion("charglm-3")
s.SetMeta(
ChatCompletionMeta{
UserName: "啵酱",
UserInfo: "啵酱是小少爷",
BotName: "塞巴斯酱",
BotInfo: "塞巴斯酱是一个冷酷的恶魔管家",
},
).AddMessage(ChatCompletionMessage{
Role: RoleUser,
Content: "早上好",
})
res, err := s.Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res.Choices)
choice := res.Choices[0]
require.Contains(t, []string{FinishReasonStop, FinishReasonStopSequence}, choice.FinishReason)
require.NotEmpty(t, choice.Message.Content)
}
func TestChatCompletionServiceAllToolsCodeInterpreter(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.ChatCompletion("GLM-4-AllTools")
s.AddMessage(ChatCompletionMultiMessage{
Role: "user",
Content: []ChatCompletionMultiContent{
{
Type: "text",
Text: "计算[5,10,20,700,99,310,978,100]的平均值和方差。",
},
},
})
s.AddTool(ChatCompletionToolCodeInterpreter{
Sandbox: Ptr(CodeInterpreterSandboxAuto),
})
foundInterpreterInput := false
foundInterpreterOutput := false
s.SetStreamHandler(func(chunk ChatCompletionResponse) error {
for _, c := range chunk.Choices {
for _, tc := range c.Delta.ToolCalls {
if tc.Type == ToolTypeCodeInterpreter && tc.CodeInterpreter != nil {
if tc.CodeInterpreter.Input != "" {
foundInterpreterInput = true
}
if len(tc.CodeInterpreter.Outputs) > 0 {
foundInterpreterOutput = true
}
}
}
}
buf, _ := json.MarshalIndent(chunk, "", " ")
t.Log(string(buf))
return nil
})
res, err := s.Do(context.Background())
require.True(t, foundInterpreterInput)
require.True(t, foundInterpreterOutput)
require.NotNil(t, res)
require.NoError(t, err)
}
func TestChatCompletionServiceAllToolsDrawingTool(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.ChatCompletion("GLM-4-AllTools")
s.AddMessage(ChatCompletionMultiMessage{
Role: "user",
Content: []ChatCompletionMultiContent{
{
Type: "text",
Text: "画一个正弦函数图像",
},
},
})
s.AddTool(ChatCompletionToolDrawingTool{})
foundInput := false
foundOutput := false
outputImage := ""
s.SetStreamHandler(func(chunk ChatCompletionResponse) error {
for _, c := range chunk.Choices {
for _, tc := range c.Delta.ToolCalls {
if tc.Type == ToolTypeDrawingTool && tc.DrawingTool != nil {
if tc.DrawingTool.Input != "" {
foundInput = true
}
if len(tc.DrawingTool.Outputs) > 0 {
foundOutput = true
}
for _, output := range tc.DrawingTool.Outputs {
if output.Image != "" {
outputImage = output.Image
}
}
}
}
}
buf, _ := json.MarshalIndent(chunk, "", " ")
t.Log(string(buf))
return nil
})
res, err := s.Do(context.Background())
require.True(t, foundInput)
require.True(t, foundOutput)
require.NotEmpty(t, outputImage)
t.Log(outputImage)
require.NotNil(t, res)
require.NoError(t, err)
}
func TestChatCompletionServiceAllToolsWebBrowser(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.ChatCompletion("GLM-4-AllTools")
s.AddMessage(ChatCompletionMultiMessage{
Role: "user",
Content: []ChatCompletionMultiContent{
{
Type: "text",
Text: "搜索下本周深圳天气如何",
},
},
})
s.AddTool(ChatCompletionToolWebBrowser{})
foundInput := false
foundOutput := false
outputContent := ""
s.SetStreamHandler(func(chunk ChatCompletionResponse) error {
for _, c := range chunk.Choices {
for _, tc := range c.Delta.ToolCalls {
if tc.Type == ToolTypeWebBrowser && tc.WebBrowser != nil {
if tc.WebBrowser.Input != "" {
foundInput = true
}
if len(tc.WebBrowser.Outputs) > 0 {
foundOutput = true
}
for _, output := range tc.WebBrowser.Outputs {
if output.Content != "" {
outputContent = output.Content
}
}
}
}
}
buf, _ := json.MarshalIndent(chunk, "", " ")
t.Log(string(buf))
return nil
})
res, err := s.Do(context.Background())
require.True(t, foundInput)
require.True(t, foundOutput)
require.NotEmpty(t, outputContent)
t.Log(outputContent)
require.NotNil(t, res)
require.NoError(t, err)
}
func TestChatCompletionServiceStream(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
var content string
s := client.ChatCompletion("glm-4-flash").AddMessage(ChatCompletionMessage{
Role: RoleUser,
Content: "你好呀",
}).SetStreamHandler(func(chunk ChatCompletionResponse) error {
content += chunk.Choices[0].Delta.Content
return nil
})
res, err := s.Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res.Choices)
choice := res.Choices[0]
require.Equal(t, FinishReasonStop, choice.FinishReason)
require.NotEmpty(t, choice.Message.Content)
require.Equal(t, content, choice.Message.Content)
}
func TestChatCompletionServiceVision(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.ChatCompletion("glm-4v")
s.AddMessage(ChatCompletionMultiMessage{
Role: RoleUser,
Content: []ChatCompletionMultiContent{
{
Type: MultiContentTypeText,
Text: "图里有什么",
},
{
Type: MultiContentTypeImageURL,
ImageURL: &URLItem{
URL: "https://img1.baidu.com/it/u=1369931113,3388870256&fm=253&app=138&size=w931&n=0&f=JPEG&fmt=auto?sec=1703696400&t=f3028c7a1dca43a080aeb8239f09cc2f",
},
},
},
})
res, err := s.Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res.Choices)
require.NotZero(t, res.Usage.CompletionTokens)
choice := res.Choices[0]
require.Equal(t, FinishReasonStop, choice.FinishReason)
require.NotEmpty(t, choice.Message.Content)
}

291
zhipu/client.go Normal file
View File

@ -0,0 +1,291 @@
package zhipu
import (
"context"
"errors"
"net/http"
"os"
"strconv"
"strings"
"time"
"github.com/go-resty/resty/v2"
"github.com/golang-jwt/jwt/v5"
)
const (
envAPIKey = "ZHIPUAI_API_KEY"
envBaseURL = "ZHIPUAI_BASE_URL"
envDebug = "ZHIPUAI_DEBUG"
defaultBaseURL = "https://open.bigmodel.cn/api/paas/v4"
)
var (
// ErrAPIKeyMissing is the error when the api key is missing
ErrAPIKeyMissing = errors.New("zhipu: api key is missing")
// ErrAPIKeyMalformed is the error when the api key is malformed
ErrAPIKeyMalformed = errors.New("zhipu: api key is malformed")
)
type clientOptions struct {
baseURL string
apiKey string
client *http.Client
debug *bool
}
// ClientOption is a function that configures the client
type ClientOption func(opts *clientOptions)
// WithAPIKey set the api key of the client
func WithAPIKey(apiKey string) ClientOption {
return func(opts *clientOptions) {
opts.apiKey = apiKey
}
}
// WithBaseURL set the base url of the client
func WithBaseURL(baseURL string) ClientOption {
return func(opts *clientOptions) {
opts.baseURL = baseURL
}
}
// WithHTTPClient set the http client of the client
func WithHTTPClient(client *http.Client) ClientOption {
return func(opts *clientOptions) {
opts.client = client
}
}
// WithDebug set the debug mode of the client
func WithDebug(debug bool) ClientOption {
return func(opts *clientOptions) {
opts.debug = new(bool)
*opts.debug = debug
}
}
// Client is the client for zhipu ai platform
type Client struct {
client *resty.Client
debug bool
keyID string
keySecret []byte
}
func (c *Client) createJWT() string {
timestamp := time.Now().UnixMilli()
exp := timestamp + time.Hour.Milliseconds()*24*7
t := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"api_key": c.keyID,
"timestamp": timestamp,
"exp": exp,
})
t.Header = map[string]interface{}{
"alg": "HS256",
"sign_type": "SIGN",
}
token, err := t.SignedString(c.keySecret)
if err != nil {
panic(err)
}
return token
}
// request creates a new resty request with the jwt token and context
func (c *Client) request(ctx context.Context) *resty.Request {
return c.client.R().SetContext(ctx).SetHeader("Authorization", c.createJWT())
}
// NewClient creates a new client
// It will read the api key from the environment variable ZHIPUAI_API_KEY
// It will read the base url from the environment variable ZHIPUAI_BASE_URL
func NewClient(optFns ...ClientOption) (client *Client, err error) {
var opts clientOptions
for _, optFn := range optFns {
optFn(&opts)
}
// base url
if opts.baseURL == "" {
opts.baseURL = strings.TrimSpace(os.Getenv(envBaseURL))
}
if opts.baseURL == "" {
opts.baseURL = defaultBaseURL
}
// api key
if opts.apiKey == "" {
opts.apiKey = strings.TrimSpace(os.Getenv(envAPIKey))
}
if opts.apiKey == "" {
err = ErrAPIKeyMissing
return
}
// debug
if opts.debug == nil {
if debugStr := strings.TrimSpace(os.Getenv(envDebug)); debugStr != "" {
if debug, err1 := strconv.ParseBool(debugStr); err1 == nil {
opts.debug = &debug
}
}
}
keyComponents := strings.SplitN(opts.apiKey, ".", 2)
if len(keyComponents) != 2 {
err = ErrAPIKeyMalformed
return
}
client = &Client{
keyID: keyComponents[0],
keySecret: []byte(keyComponents[1]),
}
if opts.client == nil {
client.client = resty.New()
} else {
client.client = resty.NewWithClient(opts.client)
}
client.client = client.client.SetBaseURL(opts.baseURL)
if opts.debug != nil {
client.client.SetDebug(*opts.debug)
client.debug = *opts.debug
}
return
}
// BatchCreate creates a new BatchCreateService.
func (c *Client) BatchCreate() *BatchCreateService {
return NewBatchCreateService(c)
}
// BatchGet creates a new BatchGetService.
func (c *Client) BatchGet(batchID string) *BatchGetService {
return NewBatchGetService(c).SetBatchID(batchID)
}
// BatchCancel creates a new BatchCancelService.
func (c *Client) BatchCancel(batchID string) *BatchCancelService {
return NewBatchCancelService(c).SetBatchID(batchID)
}
// BatchList creates a new BatchListService.
func (c *Client) BatchList() *BatchListService {
return NewBatchListService(c)
}
// ChatCompletion creates a new ChatCompletionService.
func (c *Client) ChatCompletion(model string) *ChatCompletionService {
return NewChatCompletionService(c).SetModel(model)
}
// Embedding embeds a list of text into a vector space.
func (c *Client) Embedding(model string) *EmbeddingService {
return NewEmbeddingService(c).SetModel(model)
}
// FileCreate creates a new FileCreateService.
func (c *Client) FileCreate(purpose string) *FileCreateService {
return NewFileCreateService(c).SetPurpose(purpose)
}
// FileEditService creates a new FileEditService.
func (c *Client) FileEdit(documentID string) *FileEditService {
return NewFileEditService(c).SetDocumentID(documentID)
}
// FileList creates a new FileListService.
func (c *Client) FileList(purpose string) *FileListService {
return NewFileListService(c).SetPurpose(purpose)
}
// FileDeleteService creates a new FileDeleteService.
func (c *Client) FileDelete(documentID string) *FileDeleteService {
return NewFileDeleteService(c).SetDocumentID(documentID)
}
// FileGetService creates a new FileGetService.
func (c *Client) FileGet(documentID string) *FileGetService {
return NewFileGetService(c).SetDocumentID(documentID)
}
// FileDownload creates a new FileDownloadService.
func (c *Client) FileDownload(fileID string) *FileDownloadService {
return NewFileDownloadService(c).SetFileID(fileID)
}
// FineTuneCreate creates a new fine tune create service
func (c *Client) FineTuneCreate(model string) *FineTuneCreateService {
return NewFineTuneCreateService(c).SetModel(model)
}
// FineTuneEventList creates a new fine tune event list service
func (c *Client) FineTuneEventList(jobID string) *FineTuneEventListService {
return NewFineTuneEventListService(c).SetJobID(jobID)
}
// FineTuneGet creates a new fine tune get service
func (c *Client) FineTuneGet(jobID string) *FineTuneGetService {
return NewFineTuneGetService(c).SetJobID(jobID)
}
// FineTuneList creates a new fine tune list service
func (c *Client) FineTuneList() *FineTuneListService {
return NewFineTuneListService(c)
}
// FineTuneDelete creates a new fine tune delete service
func (c *Client) FineTuneDelete(jobID string) *FineTuneDeleteService {
return NewFineTuneDeleteService(c).SetJobID(jobID)
}
// FineTuneCancel creates a new fine tune cancel service
func (c *Client) FineTuneCancel(jobID string) *FineTuneCancelService {
return NewFineTuneCancelService(c).SetJobID(jobID)
}
// ImageGeneration creates a new image generation service
func (c *Client) ImageGeneration(model string) *ImageGenerationService {
return NewImageGenerationService(c).SetModel(model)
}
// KnowledgeCreate creates a new knowledge create service
func (c *Client) KnowledgeCreate() *KnowledgeCreateService {
return NewKnowledgeCreateService(c)
}
// KnowledgeEdit creates a new knowledge edit service
func (c *Client) KnowledgeEdit(knowledgeID string) *KnowledgeEditService {
return NewKnowledgeEditService(c).SetKnowledgeID(knowledgeID)
}
// KnowledgeList list all the knowledge
func (c *Client) KnowledgeList() *KnowledgeListService {
return NewKnowledgeListService(c)
}
// KnowledgeDelete creates a new knowledge delete service
func (c *Client) KnowledgeDelete(knowledgeID string) *KnowledgeDeleteService {
return NewKnowledgeDeleteService(c).SetKnowledgeID(knowledgeID)
}
// KnowledgeGet creates a new knowledge get service
func (c *Client) KnowledgeCapacity() *KnowledgeCapacityService {
return NewKnowledgeCapacityService(c)
}
// VideoGeneration creates a new video generation service
func (c *Client) VideoGeneration(model string) *VideoGenerationService {
return NewVideoGenerationService(c).SetModel(model)
}
// AsyncResult creates a new async result get service
func (c *Client) AsyncResult(id string) *AsyncResultService {
return NewAsyncResultService(c).SetID(id)
}

17
zhipu/client_test.go Normal file
View File

@ -0,0 +1,17 @@
package zhipu
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestClientR(t *testing.T) {
c, err := NewClient()
require.NoError(t, err)
// the only free api is to list fine-tuning jobs
res, err := c.request(context.Background()).Get("fine_tuning/jobs")
require.NoError(t, err)
require.True(t, res.IsSuccess())
}

25
zhipu/cog.toml Normal file
View File

@ -0,0 +1,25 @@
from_latest_tag = false
ignore_merge_commits = false
disable_changelog = false
disable_bump_commit = false
generate_mono_repository_global_tag = true
branch_whitelist = []
skip_ci = "[skip ci]"
skip_untracked = false
pre_bump_hooks = []
post_bump_hooks = []
pre_package_bump_hooks = []
post_package_bump_hooks = []
tag_prefix = "v"
[git_hooks]
[commit_types]
[changelog]
path = "CHANGELOG.md"
authors = []
[bump_profiles]
[packages]

87
zhipu/embedding.go Normal file
View File

@ -0,0 +1,87 @@
package zhipu
import (
"context"
"github.com/go-resty/resty/v2"
)
// EmbeddingData is the data for each embedding.
type EmbeddingData struct {
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
Object string `json:"object"`
}
// EmbeddingResponse is the response from the embedding service.
type EmbeddingResponse struct {
Model string `json:"model"`
Data []EmbeddingData `json:"data"`
Object string `json:"object"`
Usage ChatCompletionUsage `json:"usage"`
}
// EmbeddingService embeds a list of text into a vector space.
type EmbeddingService struct {
client *Client
model string
input string
}
var (
_ BatchSupport = &EmbeddingService{}
)
// NewEmbeddingService creates a new EmbeddingService.
func NewEmbeddingService(client *Client) *EmbeddingService {
return &EmbeddingService{client: client}
}
func (s *EmbeddingService) BatchMethod() string {
return "POST"
}
func (s *EmbeddingService) BatchURL() string {
return BatchEndpointV4Embeddings
}
func (s *EmbeddingService) BatchBody() any {
return s.buildBody()
}
// SetModel sets the model to use for the embedding.
func (s *EmbeddingService) SetModel(model string) *EmbeddingService {
s.model = model
return s
}
// SetInput sets the input text to embed.
func (s *EmbeddingService) SetInput(input string) *EmbeddingService {
s.input = input
return s
}
func (s *EmbeddingService) buildBody() M {
return M{"model": s.model, "input": s.input}
}
func (s *EmbeddingService) Do(ctx context.Context) (res EmbeddingResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetBody(s.buildBody()).
SetResult(&res).
SetError(&apiError).
Post("embeddings"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}

21
zhipu/embedding_test.go Normal file
View File

@ -0,0 +1,21 @@
package zhipu
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestEmbeddingService(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
service := client.Embedding("embedding-2")
resp, err := service.SetInput("你好").Do(context.Background())
require.NoError(t, err)
require.NotZero(t, resp.Usage.TotalTokens)
require.NotEmpty(t, resp.Data)
require.NotEmpty(t, resp.Data[0].Embedding)
}

58
zhipu/error.go Normal file
View File

@ -0,0 +1,58 @@
package zhipu
type APIError struct {
Code string `json:"code"`
Message string `json:"message"`
}
func (e APIError) Error() string {
return e.Message
}
type APIErrorResponse struct {
APIError `json:"error"`
}
func (e APIErrorResponse) Error() string {
return e.APIError.Error()
}
// GetAPIErrorCode returns the error code of an API error.
func GetAPIErrorCode(err error) string {
if err == nil {
return ""
}
if e, ok := err.(APIError); ok {
return e.Code
}
if e, ok := err.(APIErrorResponse); ok {
return e.Code
}
if e, ok := err.(*APIError); ok && e != nil {
return e.Code
}
if e, ok := err.(*APIErrorResponse); ok && e != nil {
return e.Code
}
return ""
}
// GetAPIErrorMessage returns the error message of an API error.
func GetAPIErrorMessage(err error) string {
if err == nil {
return ""
}
if e, ok := err.(APIError); ok {
return e.Message
}
if e, ok := err.(APIErrorResponse); ok {
return e.Message
}
if e, ok := err.(*APIError); ok && e != nil {
return e.Message
}
if e, ok := err.(*APIErrorResponse); ok && e != nil {
return e.Message
}
return err.Error()
}

38
zhipu/error_test.go Normal file
View File

@ -0,0 +1,38 @@
package zhipu
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestAPIError(t *testing.T) {
err := APIError{
Code: "code",
Message: "message",
}
require.Equal(t, "message", err.Error())
require.Equal(t, "code", GetAPIErrorCode(err))
require.Equal(t, "message", GetAPIErrorMessage(err))
}
func TestAPIErrorResponse(t *testing.T) {
err := APIErrorResponse{
APIError: APIError{
Code: "code",
Message: "message",
},
}
require.Equal(t, "message", err.Error())
require.Equal(t, "code", GetAPIErrorCode(err))
require.Equal(t, "message", GetAPIErrorMessage(err))
}
func TestAPIErrorResponseFromDoc(t *testing.T) {
var res APIErrorResponse
err := json.Unmarshal([]byte(`{"error":{"code":"1002","message":"Authorization Token非法请确认Authorization Token正确传递。"}}`), &res)
require.NoError(t, err)
require.Equal(t, "1002", res.Code)
require.Equal(t, "1002", GetAPIErrorCode(res))
}

541
zhipu/file.go Normal file
View File

@ -0,0 +1,541 @@
package zhipu
import (
"context"
"errors"
"io"
"os"
"path/filepath"
"strconv"
"github.com/go-resty/resty/v2"
)
const (
FilePurposeFineTune = "fine-tune"
FilePurposeRetrieval = "retrieval"
FilePurposeBatch = "batch"
KnowledgeTypeArticle = 1
KnowledgeTypeQADocument = 2
KnowledgeTypeQASpreadsheet = 3
KnowledgeTypeProductDatabaseSpreadsheet = 4
KnowledgeTypeCustom = 5
)
// FileCreateService is a service to create a file.
type FileCreateService struct {
client *Client
purpose string
localFile string
file io.Reader
filename string
customSeparator *string
sentenceSize *int
knowledgeID *string
}
// FileCreateKnowledgeSuccessInfo is the success info of the FileCreateKnowledgeResponse.
type FileCreateKnowledgeSuccessInfo struct {
Filename string `json:"fileName"`
DocumentID string `json:"documentId"`
}
// FileCreateKnowledgeFailedInfo is the failed info of the FileCreateKnowledgeResponse.
type FileCreateKnowledgeFailedInfo struct {
Filename string `json:"fileName"`
FailReason string `json:"failReason"`
}
// FileCreateKnowledgeResponse is the response of the FileCreateService.
type FileCreateKnowledgeResponse struct {
SuccessInfos []FileCreateKnowledgeSuccessInfo `json:"successInfos"`
FailedInfos []FileCreateKnowledgeFailedInfo `json:"failedInfos"`
}
// FileCreateFineTuneResponse is the response of the FileCreateService.
type FileCreateFineTuneResponse struct {
Bytes int64 `json:"bytes"`
CreatedAt int64 `json:"created_at"`
Filename string `json:"filename"`
Object string `json:"object"`
Purpose string `json:"purpose"`
ID string `json:"id"`
}
// FileCreateResponse is the response of the FileCreateService.
type FileCreateResponse struct {
FileCreateFineTuneResponse
FileCreateKnowledgeResponse
}
// NewFileCreateService creates a new FileCreateService.
func NewFileCreateService(client *Client) *FileCreateService {
return &FileCreateService{client: client}
}
// SetLocalFile sets the local_file parameter of the FileCreateService.
func (s *FileCreateService) SetLocalFile(localFile string) *FileCreateService {
s.localFile = localFile
return s
}
// SetFile sets the file parameter of the FileCreateService.
func (s *FileCreateService) SetFile(file io.Reader, filename string) *FileCreateService {
s.file = file
s.filename = filename
return s
}
// SetPurpose sets the purpose parameter of the FileCreateService.
func (s *FileCreateService) SetPurpose(purpose string) *FileCreateService {
s.purpose = purpose
return s
}
// SetCustomSeparator sets the custom_separator parameter of the FileCreateService.
func (s *FileCreateService) SetCustomSeparator(customSeparator string) *FileCreateService {
s.customSeparator = &customSeparator
return s
}
// SetSentenceSize sets the sentence_size parameter of the FileCreateService.
func (s *FileCreateService) SetSentenceSize(sentenceSize int) *FileCreateService {
s.sentenceSize = &sentenceSize
return s
}
// SetKnowledgeID sets the knowledge_id parameter of the FileCreateService.
func (s *FileCreateService) SetKnowledgeID(knowledgeID string) *FileCreateService {
s.knowledgeID = &knowledgeID
return s
}
// Do makes the request.
func (s *FileCreateService) Do(ctx context.Context) (res FileCreateResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
body := map[string]string{"purpose": s.purpose}
if s.customSeparator != nil {
body["custom_separator"] = *s.customSeparator
}
if s.sentenceSize != nil {
body["sentence_size"] = strconv.Itoa(*s.sentenceSize)
}
if s.knowledgeID != nil {
body["knowledge_id"] = *s.knowledgeID
}
file, filename := s.file, s.filename
if file == nil && s.localFile != "" {
var f *os.File
if f, err = os.Open(s.localFile); err != nil {
return
}
defer f.Close()
file = f
filename = filepath.Base(s.localFile)
}
if file == nil {
err = errors.New("no file specified")
return
}
if resp, err = s.client.request(ctx).
SetFileReader("file", filename, file).
SetMultipartFormData(body).
SetResult(&res).
SetError(&apiError).
Post("files"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// FileEditService is a service to edit a file.
type FileEditService struct {
client *Client
documentID string
knowledgeType *int
customSeparator []string
sentenceSize *int
}
// NewFileEditService creates a new FileEditService.
func NewFileEditService(client *Client) *FileEditService {
return &FileEditService{client: client}
}
// SetDocumentID sets the document_id parameter of the FileEditService.
func (s *FileEditService) SetDocumentID(documentID string) *FileEditService {
s.documentID = documentID
return s
}
// SetKnowledgeType sets the knowledge_type parameter of the FileEditService.
func (s *FileEditService) SetKnowledgeType(knowledgeType int) *FileEditService {
s.knowledgeType = &knowledgeType
return s
}
// SetSentenceSize sets the sentence_size parameter of the FileEditService.
func (s *FileEditService) SetCustomSeparator(customSeparator ...string) *FileEditService {
s.customSeparator = customSeparator
return s
}
// SetSentenceSize sets the sentence_size parameter of the FileEditService.
func (s *FileEditService) SetSentenceSize(sentenceSize int) *FileEditService {
s.sentenceSize = &sentenceSize
return s
}
// Do makes the request.
func (s *FileEditService) Do(ctx context.Context) (err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
body := M{}
if s.knowledgeType != nil {
body["knowledge_type"] = strconv.Itoa(*s.knowledgeType)
}
if len(s.customSeparator) > 0 {
body["custom_separator"] = s.customSeparator
}
if s.sentenceSize != nil {
body["sentence_size"] = strconv.Itoa(*s.sentenceSize)
}
if resp, err = s.client.request(ctx).
SetPathParam("document_id", s.documentID).
SetBody(body).
SetError(&apiError).
Put("document/{document_id}"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// FileListService is a service to list files.
type FileListService struct {
client *Client
purpose string
knowledgeID *string
page *int
limit *int
after *string
orderAsc *bool
}
// FileFailInfo is the failed info of the FileListKnowledgeItem.
type FileFailInfo struct {
EmbeddingCode int `json:"embedding_code"`
EmbeddingMsg string `json:"embedding_msg"`
}
// FileListKnowledgeItem is the item of the FileListKnowledgeResponse.
type FileListKnowledgeItem struct {
ID string `json:"id"`
Name string `json:"name"`
URL string `json:"url"`
Length int64 `json:"length"`
SentenceSize int64 `json:"sentence_size"`
CustomSeparator []string `json:"custom_separator"`
EmbeddingStat int `json:"embedding_stat"`
FailInfo *FileFailInfo `json:"failInfo"`
WordNum int64 `json:"word_num"`
ParseImage int `json:"parse_image"`
}
// FileListKnowledgeResponse is the response of the FileListService.
type FileListKnowledgeResponse struct {
Total int `json:"total"`
List []FileListKnowledgeItem `json:"list"`
}
// FileListFineTuneItem is the item of the FileListFineTuneResponse.
type FileListFineTuneItem struct {
Bytes int64 `json:"bytes"`
CreatedAt int64 `json:"created_at"`
Filename string `json:"filename"`
ID string `json:"id"`
Object string `json:"object"`
Purpose string `json:"purpose"`
}
// FileListFineTuneResponse is the response of the FileListService.
type FileListFineTuneResponse struct {
Object string `json:"object"`
Data []FileListFineTuneItem `json:"data"`
}
// FileListResponse is the response of the FileListService.
type FileListResponse struct {
FileListKnowledgeResponse
FileListFineTuneResponse
}
// NewFileListService creates a new FileListService.
func NewFileListService(client *Client) *FileListService {
return &FileListService{client: client}
}
// SetPurpose sets the purpose parameter of the FileListService.
func (s *FileListService) SetPurpose(purpose string) *FileListService {
s.purpose = purpose
return s
}
// SetKnowledgeID sets the knowledge_id parameter of the FileListService.
func (s *FileListService) SetKnowledgeID(knowledgeID string) *FileListService {
s.knowledgeID = &knowledgeID
return s
}
// SetPage sets the page parameter of the FileListService.
func (s *FileListService) SetPage(page int) *FileListService {
s.page = &page
return s
}
// SetLimit sets the limit parameter of the FileListService.
func (s *FileListService) SetLimit(limit int) *FileListService {
s.limit = &limit
return s
}
// SetAfter sets the after parameter of the FileListService.
func (s *FileListService) SetAfter(after string) *FileListService {
s.after = &after
return s
}
// SetOrder sets the order parameter of the FileListService.
func (s *FileListService) SetOrder(asc bool) *FileListService {
s.orderAsc = &asc
return s
}
// Do makes the request.
func (s *FileListService) Do(ctx context.Context) (res FileListResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
m := map[string]string{
"purpose": s.purpose,
}
if s.knowledgeID != nil {
m["knowledge_id"] = *s.knowledgeID
}
if s.page != nil {
m["page"] = strconv.Itoa(*s.page)
}
if s.limit != nil {
m["limit"] = strconv.Itoa(*s.limit)
}
if s.after != nil {
m["after"] = *s.after
}
if s.orderAsc != nil {
if *s.orderAsc {
m["order"] = "asc"
} else {
m["order"] = "desc"
}
}
if resp, err = s.client.request(ctx).
SetQueryParams(m).
SetResult(&res).
SetError(&apiError).
Get("files"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// FileDeleteService is a service to delete a file.
type FileDeleteService struct {
client *Client
documentID string
}
// NewFileDeleteService creates a new FileDeleteService.
func NewFileDeleteService(client *Client) *FileDeleteService {
return &FileDeleteService{client: client}
}
// SetDocumentID sets the document_id parameter of the FileDeleteService.
func (s *FileDeleteService) SetDocumentID(documentID string) *FileDeleteService {
s.documentID = documentID
return s
}
// Do makes the request.
func (s *FileDeleteService) Do(ctx context.Context) (err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetPathParam("document_id", s.documentID).
SetError(&apiError).
Delete("document/{document_id}"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// FileGetService is a service to get a file.
type FileGetService struct {
client *Client
documentID string
}
// FileGetResponse is the response of the FileGetService.
type FileGetResponse = FileListKnowledgeItem
// NewFileGetService creates a new FileGetService.
func NewFileGetService(client *Client) *FileGetService {
return &FileGetService{client: client}
}
// SetDocumentID sets the document_id parameter of the FileGetService.
func (s *FileGetService) SetDocumentID(documentID string) *FileGetService {
s.documentID = documentID
return s
}
// Do makes the request.
func (s *FileGetService) Do(ctx context.Context) (res FileGetResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetPathParam("document_id", s.documentID).
SetResult(&res).
SetError(&apiError).
Get("document/{document_id}"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// FileDownloadService is a service to download a file.
type FileDownloadService struct {
client *Client
fileID string
writer io.Writer
filename string
}
// NewFileDownloadService creates a new FileDownloadService.
func NewFileDownloadService(client *Client) *FileDownloadService {
return &FileDownloadService{client: client}
}
// SetFileID sets the file_id parameter of the FileDownloadService.
func (s *FileDownloadService) SetFileID(fileID string) *FileDownloadService {
s.fileID = fileID
return s
}
// SetOutput sets the output parameter of the FileDownloadService.
func (s *FileDownloadService) SetOutput(w io.Writer) *FileDownloadService {
s.writer = w
return s
}
// SetOutputFile sets the output_file parameter of the FileDownloadService.
func (s *FileDownloadService) SetOutputFile(filename string) *FileDownloadService {
s.filename = filename
return s
}
// Do makes the request.
func (s *FileDownloadService) Do(ctx context.Context) (err error) {
var resp *resty.Response
writer := s.writer
if writer == nil && s.filename != "" {
var f *os.File
if f, err = os.Create(s.filename); err != nil {
return
}
defer f.Close()
writer = f
}
if writer == nil {
return errors.New("no output specified")
}
if resp, err = s.client.request(ctx).
SetDoNotParseResponse(true).
SetPathParam("file_id", s.fileID).
Get("files/{file_id}/content"); err != nil {
return
}
defer resp.RawBody().Close()
_, err = io.Copy(writer, resp.RawBody())
return
}

71
zhipu/file_test.go Normal file
View File

@ -0,0 +1,71 @@
package zhipu
import (
"context"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
)
func TestFileServiceFineTune(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.FileCreate(FilePurposeFineTune)
s.SetLocalFile(filepath.Join("testdata", "test-file.jsonl"))
res, err := s.Do(context.Background())
require.NoError(t, err)
require.NotZero(t, res.Bytes)
require.NotZero(t, res.CreatedAt)
require.NotEmpty(t, res.ID)
}
func TestFileServiceKnowledge(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.FileCreate(FilePurposeRetrieval)
s.SetKnowledgeID(os.Getenv("TEST_KNOWLEDGE_ID"))
s.SetLocalFile(filepath.Join("testdata", "test-file.txt"))
res, err := s.Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res.SuccessInfos)
require.NotEmpty(t, res.SuccessInfos[0].DocumentID)
require.NotEmpty(t, res.SuccessInfos[0].Filename)
documentID := res.SuccessInfos[0].DocumentID
res2, err := client.FileGet(documentID).Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res2.ID)
err = client.FileEdit(documentID).SetKnowledgeType(KnowledgeTypeCustom).Do(context.Background())
require.True(t, err == nil || GetAPIErrorCode(err) == "10019")
err = client.FileDelete(res.SuccessInfos[0].DocumentID).Do(context.Background())
require.True(t, err == nil || GetAPIErrorCode(err) == "10019")
}
func TestFileListServiceKnowledge(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.FileList(FilePurposeRetrieval).SetKnowledgeID(os.Getenv("TEST_KNOWLEDGE_ID"))
res, err := s.Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res.List)
}
func TestFileListServiceFineTune(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.FileList(FilePurposeFineTune)
res, err := s.Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res.Data)
}

456
zhipu/fine_tune.go Normal file
View File

@ -0,0 +1,456 @@
package zhipu
import (
"context"
"strconv"
"github.com/go-resty/resty/v2"
)
const (
HyperParameterAuto = "auto"
FineTuneStatusCreate = "create"
FineTuneStatusValidatingFiles = "validating_files"
FineTuneStatusQueued = "queued"
FineTuneStatusRunning = "running"
FineTuneStatusSucceeded = "succeeded"
FineTuneStatusFailed = "failed"
FineTuneStatusCancelled = "cancelled"
)
// FineTuneItem is the item of the FineTune
type FineTuneItem struct {
ID string `json:"id"`
RequestID string `json:"request_id"`
FineTunedModel string `json:"fine_tuned_model"`
Status string `json:"status"`
Object string `json:"object"`
TrainingFile string `json:"training_file"`
ValidationFile string `json:"validation_file"`
Error APIError `json:"error"`
}
// FineTuneCreateService creates a new fine tune
type FineTuneCreateService struct {
client *Client
model string
trainingFile string
validationFile *string
learningRateMultiplier *StringOr[float64]
batchSize *StringOr[int]
nEpochs *StringOr[int]
suffix *string
requestID *string
}
// FineTuneCreateResponse is the response of the FineTuneCreateService
type FineTuneCreateResponse = FineTuneItem
// NewFineTuneCreateService creates a new FineTuneCreateService
func NewFineTuneCreateService(client *Client) *FineTuneCreateService {
return &FineTuneCreateService{
client: client,
}
}
// SetModel sets the model parameter
func (s *FineTuneCreateService) SetModel(model string) *FineTuneCreateService {
s.model = model
return s
}
// SetTrainingFile sets the trainingFile parameter
func (s *FineTuneCreateService) SetTrainingFile(trainingFile string) *FineTuneCreateService {
s.trainingFile = trainingFile
return s
}
// SetValidationFile sets the validationFile parameter
func (s *FineTuneCreateService) SetValidationFile(validationFile string) *FineTuneCreateService {
s.validationFile = &validationFile
return s
}
// SetLearningRateMultiplier sets the learningRateMultiplier parameter
func (s *FineTuneCreateService) SetLearningRateMultiplier(learningRateMultiplier float64) *FineTuneCreateService {
s.learningRateMultiplier = &StringOr[float64]{}
s.learningRateMultiplier.SetValue(learningRateMultiplier)
return s
}
// SetLearningRateMultiplierAuto sets the learningRateMultiplier parameter to auto
func (s *FineTuneCreateService) SetLearningRateMultiplierAuto() *FineTuneCreateService {
s.learningRateMultiplier = &StringOr[float64]{}
s.learningRateMultiplier.SetString(HyperParameterAuto)
return s
}
// SetBatchSize sets the batchSize parameter
func (s *FineTuneCreateService) SetBatchSize(batchSize int) *FineTuneCreateService {
s.batchSize = &StringOr[int]{}
s.batchSize.SetValue(batchSize)
return s
}
// SetBatchSizeAuto sets the batchSize parameter to auto
func (s *FineTuneCreateService) SetBatchSizeAuto() *FineTuneCreateService {
s.batchSize = &StringOr[int]{}
s.batchSize.SetString(HyperParameterAuto)
return s
}
// SetNEpochs sets the nEpochs parameter
func (s *FineTuneCreateService) SetNEpochs(nEpochs int) *FineTuneCreateService {
s.nEpochs = &StringOr[int]{}
s.nEpochs.SetValue(nEpochs)
return s
}
// SetNEpochsAuto sets the nEpochs parameter to auto
func (s *FineTuneCreateService) SetNEpochsAuto() *FineTuneCreateService {
s.nEpochs = &StringOr[int]{}
s.nEpochs.SetString(HyperParameterAuto)
return s
}
// SetSuffix sets the suffix parameter
func (s *FineTuneCreateService) SetSuffix(suffix string) *FineTuneCreateService {
s.suffix = &suffix
return s
}
// SetRequestID sets the requestID parameter
func (s *FineTuneCreateService) SetRequestID(requestID string) *FineTuneCreateService {
s.requestID = &requestID
return s
}
// Do makes the request
func (s *FineTuneCreateService) Do(ctx context.Context) (res FineTuneCreateResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
body := M{
"model": s.model,
"training_file": s.trainingFile,
}
if s.validationFile != nil {
body["validation_file"] = *s.validationFile
}
if s.suffix != nil {
body["suffix"] = *s.suffix
}
if s.requestID != nil {
body["request_id"] = *s.requestID
}
if s.learningRateMultiplier != nil || s.batchSize != nil || s.nEpochs != nil {
hp := M{}
if s.learningRateMultiplier != nil {
hp["learning_rate_multiplier"] = s.learningRateMultiplier
}
if s.batchSize != nil {
hp["batch_size"] = s.batchSize
}
if s.nEpochs != nil {
hp["n_epochs"] = s.nEpochs
}
body["hyperparameters"] = hp
}
if resp, err = s.client.request(ctx).
SetBody(body).
SetResult(&res).
SetError(&apiError).
Post("fine_tuning/jobs"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// FineTuneEventListService creates a new fine tune event list
type FineTuneEventListService struct {
client *Client
jobID string
limit *int
after *string
}
// FineTuneEventData is the data of the FineTuneEventItem
type FineTuneEventData struct {
Acc float64 `json:"acc"`
Loss float64 `json:"loss"`
CurrentSteps int64 `json:"current_steps"`
RemainingTime string `json:"remaining_time"`
ElapsedTime string `json:"elapsed_time"`
TotalSteps int64 `json:"total_steps"`
Epoch int64 `json:"epoch"`
TrainedTokens int64 `json:"trained_tokens"`
LearningRate float64 `json:"learning_rate"`
}
// FineTuneEventItem is the item of the FineTuneEventListResponse
type FineTuneEventItem struct {
ID string `json:"id"`
Type string `json:"type"`
Level string `json:"level"`
Message string `json:"message"`
Object string `json:"object"`
CreatedAt int64 `json:"created_at"`
Data FineTuneEventData `json:"data"`
}
// FineTuneEventListResponse is the response of the FineTuneEventListService
type FineTuneEventListResponse struct {
Data []FineTuneEventItem `json:"data"`
HasMore bool `json:"has_more"`
Object string `json:"object"`
}
// NewFineTuneEventListService creates a new FineTuneEventListService
func NewFineTuneEventListService(client *Client) *FineTuneEventListService {
return &FineTuneEventListService{
client: client,
}
}
// SetJobID sets the jobID parameter
func (s *FineTuneEventListService) SetJobID(jobID string) *FineTuneEventListService {
s.jobID = jobID
return s
}
// SetLimit sets the limit parameter
func (s *FineTuneEventListService) SetLimit(limit int) *FineTuneEventListService {
s.limit = &limit
return s
}
// SetAfter sets the after parameter
func (s *FineTuneEventListService) SetAfter(after string) *FineTuneEventListService {
s.after = &after
return s
}
// Do makes the request
func (s *FineTuneEventListService) Do(ctx context.Context) (res FineTuneEventListResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
req := s.client.request(ctx)
if s.limit != nil {
req.SetQueryParam("limit", strconv.Itoa(*s.limit))
}
if s.after != nil {
req.SetQueryParam("after", *s.after)
}
if resp, err = req.
SetPathParam("job_id", s.jobID).
SetResult(&res).
SetError(&apiError).
Get("fine_tuning/jobs/{job_id}/events"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// FineTuneGetService creates a new fine tune get
type FineTuneGetService struct {
client *Client
jobID string
}
// NewFineTuneGetService creates a new FineTuneGetService
func NewFineTuneGetService(client *Client) *FineTuneGetService {
return &FineTuneGetService{
client: client,
}
}
// SetJobID sets the jobID parameter
func (s *FineTuneGetService) SetJobID(jobID string) *FineTuneGetService {
s.jobID = jobID
return s
}
// Do makes the request
func (s *FineTuneGetService) Do(ctx context.Context) (res FineTuneItem, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetPathParam("job_id", s.jobID).
SetResult(&res).
SetError(&apiError).
Get("fine_tuning/jobs/{job_id}"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// FineTuneListService creates a new fine tune list
type FineTuneListService struct {
client *Client
limit *int
after *string
}
// FineTuneListResponse is the response of the FineTuneListService
type FineTuneListResponse struct {
Data []FineTuneItem `json:"data"`
Object string `json:"object"`
}
// NewFineTuneListService creates a new FineTuneListService
func NewFineTuneListService(client *Client) *FineTuneListService {
return &FineTuneListService{
client: client,
}
}
// SetLimit sets the limit parameter
func (s *FineTuneListService) SetLimit(limit int) *FineTuneListService {
s.limit = &limit
return s
}
// SetAfter sets the after parameter
func (s *FineTuneListService) SetAfter(after string) *FineTuneListService {
s.after = &after
return s
}
// Do makes the request
func (s *FineTuneListService) Do(ctx context.Context) (res FineTuneListResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
req := s.client.request(ctx)
if s.limit != nil {
req.SetQueryParam("limit", strconv.Itoa(*s.limit))
}
if s.after != nil {
req.SetQueryParam("after", *s.after)
}
if resp, err = req.
SetResult(&res).
SetError(&apiError).
Get("fine_tuning/jobs"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// FineTuneDeleteService creates a new fine tune delete
type FineTuneDeleteService struct {
client *Client
jobID string
}
// NewFineTuneDeleteService creates a new FineTuneDeleteService
func NewFineTuneDeleteService(client *Client) *FineTuneDeleteService {
return &FineTuneDeleteService{
client: client,
}
}
// SetJobID sets the jobID parameter
func (s *FineTuneDeleteService) SetJobID(jobID string) *FineTuneDeleteService {
s.jobID = jobID
return s
}
// Do makes the request
func (s *FineTuneDeleteService) Do(ctx context.Context) (res FineTuneItem, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetPathParam("job_id", s.jobID).
SetResult(&res).
SetError(&apiError).
Delete("fine_tuning/jobs/{job_id}"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// FineTuneCancelService creates a new fine tune cancel
type FineTuneCancelService struct {
client *Client
jobID string
}
// NewFineTuneCancelService creates a new FineTuneCancelService
func NewFineTuneCancelService(client *Client) *FineTuneCancelService {
return &FineTuneCancelService{
client: client,
}
}
// SetJobID sets the jobID parameter
func (s *FineTuneCancelService) SetJobID(jobID string) *FineTuneCancelService {
s.jobID = jobID
return s
}
// Do makes the request
func (s *FineTuneCancelService) Do(ctx context.Context) (res FineTuneItem, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetPathParam("job_id", s.jobID).
SetResult(&res).
SetError(&apiError).
Post("fine_tuning/jobs/{job_id}/cancel"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}

3
zhipu/fine_tune_test.go Normal file
View File

@ -0,0 +1,3 @@
package zhipu
// tests not available since lack of budget to test it

110
zhipu/image_generation.go Normal file
View File

@ -0,0 +1,110 @@
package zhipu
import (
"context"
"github.com/go-resty/resty/v2"
)
// ImageGenerationService creates a new image generation
type ImageGenerationService struct {
client *Client
model string
prompt string
size string
userID string
}
var (
_ BatchSupport = &ImageGenerationService{}
)
// ImageGenerationResponse is the response of the ImageGenerationService
type ImageGenerationResponse struct {
Created int64 `json:"created"`
Data []URLItem `json:"data"`
}
// NewImageGenerationService creates a new ImageGenerationService
func NewImageGenerationService(client *Client) *ImageGenerationService {
return &ImageGenerationService{
client: client,
}
}
func (s *ImageGenerationService) BatchMethod() string {
return "POST"
}
func (s *ImageGenerationService) BatchURL() string {
return BatchEndpointV4ImagesGenerations
}
func (s *ImageGenerationService) BatchBody() any {
return s.buildBody()
}
// SetModel sets the model parameter
func (s *ImageGenerationService) SetModel(model string) *ImageGenerationService {
s.model = model
return s
}
// SetPrompt sets the prompt parameter
func (s *ImageGenerationService) SetPrompt(prompt string) *ImageGenerationService {
s.prompt = prompt
return s
}
func (s *ImageGenerationService) SetSize(size string) *ImageGenerationService {
s.size = size
return s
}
// SetUserID sets the userID parameter
func (s *ImageGenerationService) SetUserID(userID string) *ImageGenerationService {
s.userID = userID
return s
}
func (s *ImageGenerationService) buildBody() M {
body := M{
"model": s.model,
"prompt": s.prompt,
}
if s.userID != "" {
body["user_id"] = s.userID
}
if s.size != "" {
body["size"] = s.size
}
return body
}
func (s *ImageGenerationService) Do(ctx context.Context) (res ImageGenerationResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
body := s.buildBody()
if resp, err = s.client.request(ctx).
SetBody(body).
SetResult(&res).
SetError(&apiError).
Post("images/generations"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}

View File

@ -0,0 +1,21 @@
package zhipu
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestImageGenerationService(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.ImageGeneration("cogview-3")
s.SetPrompt("一只可爱的小猫")
res, err := s.Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res.Data)
t.Log(res.Data[0].URL)
}

299
zhipu/knowledge.go Normal file
View File

@ -0,0 +1,299 @@
package zhipu
import (
"context"
"strconv"
"github.com/go-resty/resty/v2"
)
const (
KnowledgeEmbeddingIDEmbedding2 = 3
)
// KnowledgeCreateService creates a new knowledge
type KnowledgeCreateService struct {
client *Client
embeddingID int
name string
description *string
}
// KnowledgeCreateResponse is the response of the KnowledgeCreateService
type KnowledgeCreateResponse = IDItem
// NewKnowledgeCreateService creates a new KnowledgeCreateService
func NewKnowledgeCreateService(client *Client) *KnowledgeCreateService {
return &KnowledgeCreateService{
client: client,
}
}
// SetEmbeddingID sets the embedding id of the knowledge
func (s *KnowledgeCreateService) SetEmbeddingID(embeddingID int) *KnowledgeCreateService {
s.embeddingID = embeddingID
return s
}
// SetName sets the name of the knowledge
func (s *KnowledgeCreateService) SetName(name string) *KnowledgeCreateService {
s.name = name
return s
}
// SetDescription sets the description of the knowledge
func (s *KnowledgeCreateService) SetDescription(description string) *KnowledgeCreateService {
s.description = &description
return s
}
// Do creates the knowledge
func (s *KnowledgeCreateService) Do(ctx context.Context) (res KnowledgeCreateResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
body := M{
"name": s.name,
"embedding_id": s.embeddingID,
}
if s.description != nil {
body["description"] = *s.description
}
if resp, err = s.client.request(ctx).
SetBody(body).
SetResult(&res).
SetError(&apiError).
Post("knowledge"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// KnowledgeEditService edits a knowledge
type KnowledgeEditService struct {
client *Client
knowledgeID string
embeddingID *int
name *string
description *string
}
// NewKnowledgeEditService creates a new KnowledgeEditService
func NewKnowledgeEditService(client *Client) *KnowledgeEditService {
return &KnowledgeEditService{
client: client,
}
}
// SetKnowledgeID sets the knowledge id
func (s *KnowledgeEditService) SetKnowledgeID(knowledgeID string) *KnowledgeEditService {
s.knowledgeID = knowledgeID
return s
}
// SetName sets the name of the knowledge
func (s *KnowledgeEditService) SetName(name string) *KnowledgeEditService {
s.name = &name
return s
}
// SetEmbeddingID sets the embedding id of the knowledge
func (s *KnowledgeEditService) SetEmbeddingID(embeddingID int) *KnowledgeEditService {
s.embeddingID = &embeddingID
return s
}
// SetDescription sets the description of the knowledge
func (s *KnowledgeEditService) SetDescription(description string) *KnowledgeEditService {
s.description = &description
return s
}
// Do edits the knowledge
func (s *KnowledgeEditService) Do(ctx context.Context) (err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
body := M{}
if s.name != nil {
body["name"] = *s.name
}
if s.description != nil {
body["description"] = *s.description
}
if s.embeddingID != nil {
body["embedding_id"] = *s.embeddingID
}
if resp, err = s.client.request(ctx).
SetPathParam("knowledge_id", s.knowledgeID).
SetBody(body).
SetError(&apiError).
Put("knowledge/{knowledge_id}"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// KnowledgeListService lists the knowledge
type KnowledgeListService struct {
client *Client
page *int
size *int
}
// KnowledgeItem is an item in the knowledge list
type KnowledgeItem struct {
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Icon string `json:"icon"`
Background string `json:"background"`
EmbeddingID int `json:"embedding_id"`
CustomIdentifier string `json:"custom_identifier"`
WordNum int64 `json:"word_num"`
Length int64 `json:"length"`
DocumentSize int64 `json:"document_size"`
}
// KnowledgeListResponse is the response of the KnowledgeListService
type KnowledgeListResponse struct {
List []KnowledgeItem `json:"list"`
Total int `json:"total"`
}
// NewKnowledgeListService creates a new KnowledgeListService
func NewKnowledgeListService(client *Client) *KnowledgeListService {
return &KnowledgeListService{client: client}
}
// SetPage sets the page of the knowledge list
func (s *KnowledgeListService) SetPage(page int) *KnowledgeListService {
s.page = &page
return s
}
// SetSize sets the size of the knowledge list
func (s *KnowledgeListService) SetSize(size int) *KnowledgeListService {
s.size = &size
return s
}
// Do lists the knowledge
func (s *KnowledgeListService) Do(ctx context.Context) (res KnowledgeListResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
req := s.client.request(ctx)
if s.page != nil {
req.SetQueryParam("page", strconv.Itoa(*s.page))
}
if s.size != nil {
req.SetQueryParam("size", strconv.Itoa(*s.size))
}
if resp, err = req.
SetResult(&res).
SetError(&apiError).
Get("knowledge"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// KnowledgeDeleteService deletes a knowledge
type KnowledgeDeleteService struct {
client *Client
knowledgeID string
}
// NewKnowledgeDeleteService creates a new KnowledgeDeleteService
func NewKnowledgeDeleteService(client *Client) *KnowledgeDeleteService {
return &KnowledgeDeleteService{
client: client,
}
}
// SetKnowledgeID sets the knowledge id
func (s *KnowledgeDeleteService) SetKnowledgeID(knowledgeID string) *KnowledgeDeleteService {
s.knowledgeID = knowledgeID
return s
}
// Do deletes the knowledge
func (s *KnowledgeDeleteService) Do(ctx context.Context) (err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetPathParam("knowledge_id", s.knowledgeID).
SetError(&apiError).
Delete("knowledge/{knowledge_id}"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// KnowledgeCapacityService query the capacity of the knowledge
type KnowledgeCapacityService struct {
client *Client
}
// KnowledgeCapacityItem is an item in the knowledge capacity
type KnowledgeCapacityItem struct {
WordNum int64 `json:"word_num"`
Length int64 `json:"length"`
}
// KnowledgeCapacityResponse is the response of the KnowledgeCapacityService
type KnowledgeCapacityResponse struct {
Used KnowledgeCapacityItem `json:"used"`
Total KnowledgeCapacityItem `json:"total"`
}
// SetKnowledgeID sets the knowledge id
func NewKnowledgeCapacityService(client *Client) *KnowledgeCapacityService {
return &KnowledgeCapacityService{client: client}
}
// Do query the capacity of the knowledge
func (s *KnowledgeCapacityService) Do(ctx context.Context) (res KnowledgeCapacityResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetResult(&res).
SetError(&apiError).
Get("knowledge/capacity"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}

50
zhipu/knowledge_test.go Normal file
View File

@ -0,0 +1,50 @@
package zhipu
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestKnowledgeCapacity(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.KnowledgeCapacity()
res, err := s.Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res.Total.Length)
require.NotEmpty(t, res.Total.WordNum)
}
func TestKnowledgeServiceAll(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.KnowledgeCreate()
s.SetName("test")
s.SetDescription("test description")
s.SetEmbeddingID(KnowledgeEmbeddingIDEmbedding2)
res, err := s.Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res.ID)
s2 := client.KnowledgeList()
res2, err := s2.Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res2.List)
require.Equal(t, res.ID, res2.List[0].ID)
s3 := client.KnowledgeEdit(res.ID)
s3.SetDescription("test description 2")
s3.SetName("test 2")
s3.SetEmbeddingID(KnowledgeEmbeddingIDEmbedding2)
err = s3.Do(context.Background())
require.NoError(t, err)
s4 := client.KnowledgeDelete(res.ID)
err = s4.Do(context.Background())
require.NoError(t, err)
}

54
zhipu/string_or.go Normal file
View File

@ -0,0 +1,54 @@
package zhipu
import (
"bytes"
"encoding/json"
)
// StringOr is a struct that can be either a string or a value of type T.
type StringOr[T any] struct {
String *string
Value *T
}
var (
_ json.Marshaler = StringOr[float64]{}
_ json.Unmarshaler = &StringOr[float64]{}
)
// SetString sets the string value of the struct.
func (f *StringOr[T]) SetString(v string) {
f.String = &v
f.Value = nil
}
// SetValue sets the value of the struct.
func (f *StringOr[T]) SetValue(v T) {
f.String = nil
f.Value = &v
}
func (f StringOr[T]) MarshalJSON() ([]byte, error) {
if f.Value != nil {
return json.Marshal(f.Value)
}
return json.Marshal(f.String)
}
func (f *StringOr[T]) UnmarshalJSON(data []byte) error {
if len(data) == 0 {
return nil
}
if bytes.Equal(data, []byte("null")) {
return nil
}
if data[0] == '"' {
f.String = new(string)
f.Value = nil
return json.Unmarshal(data, f.String)
} else {
f.Value = new(T)
f.String = nil
return json.Unmarshal(data, f.Value)
}
}

37
zhipu/string_or_test.go Normal file
View File

@ -0,0 +1,37 @@
package zhipu
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestStringOr(t *testing.T) {
data := struct {
Item *StringOr[int] `json:"item,omitempty"`
}{}
data.Item = &StringOr[int]{}
data.Item.SetString("test")
b, err := json.Marshal(data)
require.NoError(t, err)
require.Equal(t, `{"item":"test"}`, string(b))
data.Item.SetValue(1)
b, err = json.Marshal(data)
require.NoError(t, err)
require.Equal(t, `{"item":1}`, string(b))
err = json.Unmarshal([]byte(`{"item":"test2"}`), &data)
require.NoError(t, err)
require.NotNil(t, data.Item.String)
require.Nil(t, data.Item.Value)
require.Equal(t, "test2", *data.Item.String)
err = json.Unmarshal([]byte(`{"item":2}`), &data)
require.NoError(t, err)
require.Nil(t, data.Item.String)
require.NotNil(t, data.Item.Value)
require.Equal(t, 2, *data.Item.Value)
}

3
zhipu/testdata/test-file.jsonl vendored Normal file
View File

@ -0,0 +1,3 @@
{"messages": [{"role": "system", "content": "你是一位乐于助人知识渊博的全能AI助手。"},{"role": "user", "content": "请介绍一下AlphaGo的基本原理。"}, {"role": "assistant", "content": "AlphaGo利用深度神经网络与树搜索算法相结合,通过对弈数据进行训练,逐步提升游戏水平,能够对围棋局面进行判断并选择最佳落子位置。"}]}
{"messages": [{"role": "system", "content": "你是一位乐于助人知识渊博的全能AI助手。"},{"role": "user", "content": "请列举一些健康饮食的重要原则。"}, {"role": "assistant", "content": "健康饮食的重要原则包括:膳食应营养均衡,合理摄入各种营养素;多吃新鲜蔬果,适量饮水;少吃糖、油、盐,忌吃过量快餐;养成规律生活作息习惯,适量运动。"}]}
{"messages": [{"role": "system", "content": "你是一位乐于助人知识渊博的全能AI助手。"},{"role": "user", "content": "请给出一些改善睡眠质量的建议。"}, {"role": "assistant", "content": "改善睡眠质量的建议包括:保证睡眠时间,避免睡眠不足;睡前放松身心,远离刺激;养成规律作息时间,不要频繁熬夜;适量运动,但不要运动过度;睡前可以喝一杯热牛奶等温和饮料。"}]}

1
zhipu/testdata/test-file.txt vendored Normal file
View File

@ -0,0 +1 @@
The quick brown fox jumps over the lazy dog

22
zhipu/util.go Normal file
View File

@ -0,0 +1,22 @@
package zhipu
// URLItem is a struct that contains a URL.
type URLItem struct {
URL string `json:"url,omitempty"`
}
// IDItem is a struct that contains an ID.
type IDItem struct {
ID string `json:"id,omitempty"`
}
// Ptr returns a pointer to the value passed in.
// Example:
//
// web_search_enable = zhipu.Ptr(false)
func Ptr[T any](v T) *T {
return &v
}
// M is a shorthand for map[string]any.
type M = map[string]any

3
zhipu/util_test.go Normal file
View File

@ -0,0 +1,3 @@
package zhipu
// nothing to test

125
zhipu/video_generation.go Normal file
View File

@ -0,0 +1,125 @@
package zhipu
import (
"context"
"github.com/go-resty/resty/v2"
)
const (
VideoGenerationTaskStatusProcessing = "PROCESSING"
VideoGenerationTaskStatusSuccess = "SUCCESS"
VideoGenerationTaskStatusFail = "FAIL"
)
// VideoGenerationService creates a new video generation
type VideoGenerationService struct {
client *Client
model string
prompt string
userID string
imageURL string
requestID string
}
var (
_ BatchSupport = &VideoGenerationService{}
)
// VideoGenerationResponse is the response of the VideoGenerationService
type VideoGenerationResponse struct {
RequestID string `json:"request_id"`
ID string `json:"id"`
Model string `json:"model"`
TaskStatus string `json:"task_status"`
}
func NewVideoGenerationService(client *Client) *VideoGenerationService {
return &VideoGenerationService{
client: client,
}
}
func (s *VideoGenerationService) BatchMethod() string {
return "POST"
}
func (s *VideoGenerationService) BatchURL() string {
return BatchEndpointV4VideosGenerations
}
func (s *VideoGenerationService) BatchBody() any {
return s.buildBody()
}
// SetModel sets the model parameter
func (s *VideoGenerationService) SetModel(model string) *VideoGenerationService {
s.model = model
return s
}
// SetPrompt sets the prompt parameter
func (s *VideoGenerationService) SetPrompt(prompt string) *VideoGenerationService {
s.prompt = prompt
return s
}
// SetUserID sets the userID parameter
func (s *VideoGenerationService) SetUserID(userID string) *VideoGenerationService {
s.userID = userID
return s
}
// SetImageURL sets the imageURL parameter
func (s *VideoGenerationService) SetImageURL(imageURL string) *VideoGenerationService {
s.imageURL = imageURL
return s
}
// SetRequestID sets the requestID parameter
func (s *VideoGenerationService) SetRequestID(requestID string) *VideoGenerationService {
s.requestID = requestID
return s
}
func (s *VideoGenerationService) buildBody() M {
body := M{
"model": s.model,
"prompt": s.prompt,
}
if s.userID != "" {
body["user_id"] = s.userID
}
if s.imageURL != "" {
body["image_url"] = s.imageURL
}
if s.requestID != "" {
body["request_id"] = s.requestID
}
return body
}
func (s *VideoGenerationService) Do(ctx context.Context) (res VideoGenerationResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
body := s.buildBody()
if resp, err = s.client.request(ctx).
SetBody(body).
SetResult(&res).
SetError(&apiError).
Post("videos/generations"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}

View File

@ -0,0 +1,38 @@
package zhipu
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestVideoGeneration(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.VideoGeneration("cogvideox")
s.SetPrompt("一只可爱的小猫")
res, err := s.Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res.TaskStatus)
require.NotEmpty(t, res.ID)
t.Log(res.ID)
for {
res, err := client.AsyncResult(res.ID).Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res.TaskStatus)
if res.TaskStatus == VideoGenerationTaskStatusSuccess {
require.NotEmpty(t, res.VideoResult)
t.Log(res.VideoResult[0].URL)
t.Log(res.VideoResult[0].CoverImageURL)
}
if res.TaskStatus != VideoGenerationTaskStatusProcessing {
break
}
time.Sleep(time.Second * 5)
}
}

BIN
zhipu/wechat-donation.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB