support js

This commit is contained in:
Star 2024-09-17 18:44:21 +08:00
parent 342f2d9c09
commit b4cddad489
64 changed files with 6465 additions and 176 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
.*
go.sum
env.yml
/tests/lib

204
README.md Normal file
View File

@ -0,0 +1,204 @@
# 低代码AI能力工具包
## 命令行工具
### Install
```shell
go install apigo.cc/ai/ai/ai@latest
```
### Usage
```shell
ai -h | --help show usage
ai -e | --export export ai.ts file to ./lib for develop
ai test | test.js run test.js, if not specified, run ./ai.js
```
### Sample
#### test.js
```javascript
import {glm} from './lib/ai'
import console from './lib/console'
function main(...args) {
if(!args[0]) throw new Error('no ask')
let r = glm.fastAsk(args[0], r => {
console.print(r)
})
console.println()
return r
}
```
#### run sample
```shell
ai test "你好"
```
### Module
#### mod/glm.js
```javascript
import {glm} from './lib/ai'
import console from './lib/console'
function chat(...args) {
if(!args[0]) throw new Error('no prompt')
return glm.fastAsk(args[0])
}
function draw(...args) {
if(!args[0]) throw new Error('no ask')
return glm.makeImage(args[0], {size:'1024x1024'})
}
module.exports = {chat, draw}
```
#### test.js
```javascript
import glm from './mod/glm'
function main(...args) {
console.println(glm.chat(args[0]))
}
```
#### run sample
```shell
ai test "你好"
```
### Configure
#### llm.yml
```yaml
openai:
apiKey: ...
zhipu:
apiKey: ...
```
#### or use env.yml
#### llm.yml
```yaml
llm:
openai:
apiKey: ...
zhipu:
apiKey: ...
```
#### encrypt apiKey
install sskey
```shell
go install github.com/ssgo/tool/sskey@latest
sskey -e 'your apiKey'
```
copy url base64 format encrypted apiKey into llm.yml or env.yml
#### config with special endpoint
```yaml
llm:
openai:
apiKey: ...
endpoint: https://api.openai.com/v1
```
#### config multi api
```yaml
llm:
glm:
apiKey: ...
llm: zhipu
glm2:
apiKey: ...
endpoint: https://......
llm: zhipu
```
## 调用 JavaScript API
### Install
```shell
go get apigo.cc/ai/ai
```
### Usage
```go
package main
import (
"apigo.cc/ai/ai/js"
"fmt"
)
func main() {
result, err := js.Run(`return ai.glm.fastAsk(args[0])`, "", "你好")
// js.RunFile
// js.StartFromFile
// js.StartFromCode
if err != nil {
fmt.Println(err.Error())
} else if result != nil {
fmt.Println(result)
}
}
```
## 调用 Go API
### Install
```shell
go get apigo.cc/ai/ai
```
### Usage
```go
package main
import (
"apigo.cc/ai/ai"
"apigo.cc/ai/ai/llm"
"fmt"
)
func main() {
ai.Init()
glm := llm.Get("zhipu")
r, usage, err := glm.FastAsk(llm.Messages().User().Text("你是什么模型").Make(), func(text string) {
fmt.Print(text)
})
if err != nil {
fmt.Println(err)
} else {
fmt.Println()
fmt.Println("result:", r)
fmt.Println("usage:", usage)
}
}
```

View File

@ -1,89 +0,0 @@
package ai
import (
"apigo.cc/ai/agent"
"github.com/ssgo/config"
"github.com/ssgo/u"
"sync"
)
var confAes = u.NewAes([]byte("?GQ$0K0GgLdO=f+~L68PLm$uhKr4'=tV"), []byte("VFs7@sK61cj^f?HZ"))
var keysIsSet = false
func SetSSKey(key, iv []byte) {
if !keysIsSet {
confAes = u.NewAes(key, iv)
keysIsSet = true
}
}
const (
TypeText = agent.TypeText
TypeImage = agent.TypeImage
TypeVideo = agent.TypeVideo
RoleSystem = agent.RoleSystem
RoleUser = agent.RoleUser
RoleAssistant = agent.RoleAssistant
RoleTool = agent.RoleTool
ToolCodeInterpreter = agent.ToolCodeInterpreter
ToolWebSearch = agent.ToolWebSearch
)
type APIConfig struct {
Endpoint string
ApiKey string
DefaultChatModelConfig ChatModelConfig
}
type AgentConfig struct {
ApiKey string
Endpoint string
Agent string
ChatConfig ChatModelConfig
}
var agentConfigs map[string]*AgentConfig
var agentConfigsLock = sync.RWMutex{}
func GetAgent(name string) agent.Agent {
ag := agent.GetAgent(name)
if ag != nil {
return ag
}
var agConf *AgentConfig
if agentConfigs == nil {
agConfs := make(map[string]*AgentConfig)
config.LoadConfig("agent", &agConfs)
agConf = agConfs[name]
if agConf != nil {
agentConfigsLock.Lock()
agentConfigs = agConfs
agentConfigsLock.Unlock()
}
} else {
agentConfigsLock.RLock()
agConf = agentConfigs[name]
agentConfigsLock.RUnlock()
}
if agConf == nil {
return nil
}
if agConf.Agent == "" {
agConf.Agent = name
}
return agent.CreateAgent(name, agConf.Agent, agent.APIConfig{
Endpoint: agConf.Endpoint,
ApiKey: confAes.DecryptUrlBase64ToString(agConf.ApiKey),
DefaultChatModelConfig: agent.ChatModelConfig{
Model: agConf.ChatConfig.Model,
MaxTokens: agConf.ChatConfig.MaxTokens,
Temperature: agConf.ChatConfig.Temperature,
TopP: agConf.ChatConfig.TopP,
Tools: agConf.ChatConfig.Tools,
},
})
}

61
ai.go Normal file
View File

@ -0,0 +1,61 @@
package ai
import (
"apigo.cc/ai/ai/llm"
_ "apigo.cc/ai/ai/llm/openai"
_ "apigo.cc/ai/ai/llm/zhipu"
"github.com/ssgo/config"
"github.com/ssgo/u"
"os"
)
var confAes = u.NewAes([]byte("?GQ$0K0GgLdO=f+~L68PLm$uhKr4'=tV"), []byte("VFs7@sK61cj^f?HZ"))
var keysIsSet = false
var isInit = false
func SetSSKey(key, iv []byte) {
if !keysIsSet {
confAes = u.NewAes(key, iv)
keysIsSet = true
}
}
func Init() {
InitFrom("")
}
func InitFrom(filePath string) {
if !isInit {
isInit = true
list := map[string]*struct {
Endpoint string
ApiKey string
DefaultChatModelConfig llm.ChatConfig
Llm string
}{}
savedPath := ""
if filePath != "" {
curPath, _ := os.Getwd()
if curPath != filePath {
savedPath = curPath
_ = os.Chdir(filePath)
config.ResetConfigEnv()
}
}
_ = config.LoadConfig("llm", &list)
if savedPath != "" {
_ = os.Chdir(savedPath)
}
for name, llmConf := range list {
if llmConf.Llm == "" {
llmConf.Llm = name
}
llmConf.ApiKey = confAes.DecryptUrlBase64ToString(llmConf.ApiKey)
llm.Create(name, llmConf.Llm, llm.Config{
Endpoint: llmConf.Endpoint,
ApiKey: llmConf.ApiKey,
ChatConfig: llmConf.DefaultChatModelConfig,
})
}
}
}

59
ai/ai.go Normal file
View File

@ -0,0 +1,59 @@
package main
import (
"apigo.cc/ai/ai/js"
"fmt"
"github.com/ssgo/u"
"os"
"strings"
)
func main() {
if len(os.Args) > 1 && (os.Args[1] == "-e" || os.Args[1] == "export") {
imports, err := js.ExportForDev()
if err != nil {
fmt.Println(err.Error())
} else {
fmt.Println(`exported to ./lib
example:
`)
fmt.Println(u.Cyan(imports + `
// import customMod from './customMod'
function main(...args) {
// TODO
return null
}
`))
}
return
} else {
jsFile := "ai.js"
if len(os.Args) > 1 {
jsFile = os.Args[1]
if !strings.HasSuffix(jsFile, ".js") && !u.FileExists(jsFile) {
jsFile = jsFile + ".js"
}
}
if u.FileExists(jsFile) {
args := make([]any, len(os.Args)-2)
for i := 2; i < len(os.Args); i++ {
args[i-2] = os.Args[i]
}
result, err := js.RunFile(jsFile, args...)
if err != nil {
fmt.Println(err.Error())
} else if result != nil {
fmt.Println(u.JsonP(result))
}
return
}
}
fmt.Println(`Usage:
ai -h | --help show usage
ai -e | --export export ai.ts file for develop
ai test | test.js run test.js, if not specified, run ai.js
`)
return
}

24
go.mod
View File

@ -3,9 +3,29 @@ module apigo.cc/ai/ai
go 1.22
require (
apigo.cc/ai/agent v0.0.1
github.com/dop251/goja v0.0.0-20240828124009-016eb7256539
github.com/go-resty/resty/v2 v2.15.0
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/sashabaranov/go-openai v1.29.2
github.com/ssgo/config v1.7.7
github.com/ssgo/log v1.7.7
github.com/ssgo/u v1.7.7
github.com/stretchr/testify v1.9.0
)
require gopkg.in/yaml.v3 v3.0.1 // indirect
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dlclark/regexp2 v1.11.4 // indirect
github.com/dop251/goja_nodejs v0.0.0-20240728170619-29b559befffc // indirect
github.com/go-sourcemap/sourcemap v2.1.4+incompatible // indirect
github.com/google/pprof v0.0.0-20240910150728-a0b0bb1d4134 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/ssgo/standard v1.7.7 // indirect
golang.org/x/net v0.29.0 // indirect
golang.org/x/text v0.18.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
replace (
github.com/ssgo/config v1.7.7 => ../../ssgo/config
)

291
js/ai.go Normal file
View File

@ -0,0 +1,291 @@
package js
import (
"apigo.cc/ai/ai/llm"
"github.com/dop251/goja"
"github.com/ssgo/u"
"reflect"
"strings"
)
type ChatResult struct {
llm.TokenUsage
Result string
Error string
}
type AIGCResult struct {
Result string
Preview string
Results []string
Previews []string
Error string
}
func requireAI(lm llm.LLM) map[string]any {
return map[string]any{
"ask": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
conf, cb := getAskArgs(args.This, vm, args.Arguments)
result, usage, err := lm.Ask(makeChatMessages(args.Arguments), conf, cb)
return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)})
},
"fastAsk": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
_, cb := getAskArgs(args.This, vm, args.Arguments)
result, usage, err := lm.FastAsk(makeChatMessages(args.Arguments), cb)
return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)})
},
"longAsk": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
_, cb := getAskArgs(args.This, vm, args.Arguments)
result, usage, err := lm.LongAsk(makeChatMessages(args.Arguments), cb)
return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)})
},
"batterAsk": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
_, cb := getAskArgs(args.This, vm, args.Arguments)
result, usage, err := lm.BatterAsk(makeChatMessages(args.Arguments), cb)
return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)})
},
"bestAsk": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
_, cb := getAskArgs(args.This, vm, args.Arguments)
result, usage, err := lm.BestAsk(makeChatMessages(args.Arguments), cb)
return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)})
},
"multiAsk": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
_, cb := getAskArgs(args.This, vm, args.Arguments)
result, usage, err := lm.MultiAsk(makeChatMessages(args.Arguments), cb)
return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)})
},
"bestMultiAsk": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
_, cb := getAskArgs(args.This, vm, args.Arguments)
result, usage, err := lm.BestMultiAsk(makeChatMessages(args.Arguments), cb)
return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)})
},
"codeInterpreterAsk": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
_, cb := getAskArgs(args.This, vm, args.Arguments)
result, usage, err := lm.CodeInterpreterAsk(makeChatMessages(args.Arguments), cb)
return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)})
},
"webSearchAsk": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
_, cb := getAskArgs(args.This, vm, args.Arguments)
result, usage, err := lm.WebSearchAsk(makeChatMessages(args.Arguments), cb)
return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)})
},
"makeImage": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
prompt, conf := getAIGCArgs(args.Arguments)
results, err := lm.MakeImage(prompt, conf)
return makeAIGCResult(vm, results, nil, err)
},
"fastMakeImage": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
prompt, conf := getAIGCArgs(args.Arguments)
results, err := lm.FastMakeImage(prompt, conf)
return makeAIGCResult(vm, results, nil, err)
},
"bestMakeImage": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
prompt, conf := getAIGCArgs(args.Arguments)
results, err := lm.BestMakeImage(prompt, conf)
return makeAIGCResult(vm, results, nil, err)
},
"makeVideo": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
prompt, conf := getAIGCArgs(args.Arguments)
results, previews, err := lm.MakeVideo(prompt, conf)
return makeAIGCResult(vm, results, previews, err)
},
"fastMakeVideo": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
prompt, conf := getAIGCArgs(args.Arguments)
results, previews, err := lm.FastMakeVideo(prompt, conf)
return makeAIGCResult(vm, results, previews, err)
},
"bestMakeVideo": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
prompt, conf := getAIGCArgs(args.Arguments)
results, previews, err := lm.BestMakeVideo(prompt, conf)
return makeAIGCResult(vm, results, previews, err)
},
"support": lm.Support(),
}
}
func getErrorStr(err error) string {
if err != nil {
return err.Error()
}
return ""
}
func makeAIGCResult(vm *goja.Runtime, results []string, previews []string, err error) goja.Value {
result := ""
preview := ""
if len(results) > 0 {
result = results[0]
} else {
results = make([]string, 0)
}
if len(previews) > 0 {
preview = previews[0]
} else {
previews = make([]string, 0)
}
return vm.ToValue(AIGCResult{
Result: result,
Preview: preview,
Results: results,
Previews: previews,
Error: getErrorStr(err),
})
}
func getAIGCArgs(args []goja.Value) (string, llm.GCConfig) {
prompt := ""
var config llm.GCConfig
if len(args) > 0 {
prompt = u.String(args[0].Export())
if len(args) > 1 {
u.Convert(args[1].Export(), &config)
}
}
return prompt, config
}
func getAskArgs(thisArg goja.Value, vm *goja.Runtime, args []goja.Value) (llm.ChatConfig, func(string)) {
var chatConfig llm.ChatConfig
var callback func(answer string)
if len(args) > 0 {
for i := 1; i < len(args); i++ {
if cb, ok := goja.AssertFunction(args[i]); ok {
callback = func(answer string) {
_, _ = cb(thisArg, vm.ToValue(answer))
}
} else {
switch args[i].ExportType().Kind() {
case reflect.Map, reflect.Struct:
u.Convert(args[i].Export(), &chatConfig)
default:
chatConfig.Model = u.String(args[i].Export())
}
}
}
}
return chatConfig, callback
}
func makeChatMessages(args []goja.Value) []llm.ChatMessage {
out := make([]llm.ChatMessage, 0)
if len(args) > 0 {
v := args[0].Export()
vv := reflect.ValueOf(v)
t := args[0].ExportType()
lastRoleIsUser := false
switch t.Kind() {
// 数组,根据成员类型处理
// 字符串:
// 含有媒体:单条多模态消息
// 无媒体:多条文本消息
// 数组:多条消息(第一个成员不是 role 则自动生成)
// 对象:多条消息(无 role 则自动生成)(支持 content 或 contents
// 结构:转换为 llm.ChatMessage
// 对象:单条消息(支持 content 或 contents
// 结构:转换为 llm.ChatMessage
// 字符串:单条文本消息
case reflect.Slice:
hasSub := false
hasMulti := false
for i := 0; i < vv.Len(); i++ {
if vv.Index(i).Kind() == reflect.Slice || vv.Index(i).Kind() == reflect.Map || vv.Index(i).Kind() == reflect.Struct {
hasSub = true
break
}
if vv.Index(i).Kind() == reflect.String {
str := vv.Index(i).String()
if strings.HasPrefix(str, "data:") || strings.HasPrefix(str, "https://") || strings.HasPrefix(str, "http://") {
hasMulti = true
}
}
}
if hasSub || !hasMulti {
// 有子对象或纯文本数组
var defaultRole string
for i := 0; i < vv.Len(); i++ {
lastRoleIsUser = !lastRoleIsUser
if lastRoleIsUser {
defaultRole = llm.RoleUser
} else {
defaultRole = llm.RoleAssistant
}
vv2 := vv.Index(i)
switch vv2.Kind() {
case reflect.Slice:
out = append(out, makeChatMessageFromSlice(vv2, defaultRole))
case reflect.Map:
out = append(out, makeChatMessageFromSlice(vv2, defaultRole))
case reflect.Struct:
item := llm.ChatMessage{}
u.Convert(vv2.Interface(), &item)
out = append(out, item)
default:
out = append(out, llm.ChatMessage{Role: llm.RoleUser, Contents: []llm.ChatMessageContent{makeChatMessageContent(u.String(vv2.Interface()))}})
}
lastRoleIsUser = out[len(out)-1].Role != llm.RoleUser
}
} else {
// 单条多模态消息
out = append(out, makeChatMessageFromSlice(vv, llm.RoleUser))
}
case reflect.Map:
out = append(out, makeChatMessageFromMap(vv, llm.RoleUser))
case reflect.Struct:
item := llm.ChatMessage{}
u.Convert(v, &item)
out = append(out, item)
default:
out = append(out, llm.ChatMessage{Role: llm.RoleUser, Contents: []llm.ChatMessageContent{makeChatMessageContent(u.String(v))}})
}
}
return out
}
func makeChatMessageFromSlice(vv reflect.Value, defaultRole string) llm.ChatMessage {
role := u.String(vv.Index(0).Interface())
j := 0
if role == llm.RoleUser || role == llm.RoleAssistant || role == llm.RoleSystem || role == llm.RoleTool {
j = 1
} else {
role = defaultRole
}
contents := make([]llm.ChatMessageContent, 0)
for ; j < vv.Len(); j++ {
contents = append(contents, makeChatMessageContent(u.String(vv.Index(j).Interface())))
}
return llm.ChatMessage{Role: role, Contents: contents}
}
func makeChatMessageFromMap(vv reflect.Value, defaultRole string) llm.ChatMessage {
role := u.String(vv.MapIndex(reflect.ValueOf("role")).Interface())
if role == "" {
role = defaultRole
}
contents := make([]llm.ChatMessageContent, 0)
content := u.String(vv.MapIndex(reflect.ValueOf("content")).Interface())
if content != "" {
contents = append(contents, makeChatMessageContent(content))
} else {
contentsV := vv.MapIndex(reflect.ValueOf("contents"))
if contentsV.IsValid() && contentsV.Kind() == reflect.Slice {
for i := 0; i < contentsV.Len(); i++ {
contents = append(contents, makeChatMessageContent(u.String(contentsV.Index(i).Interface())))
}
}
}
return llm.ChatMessage{Role: role, Contents: contents}
}
func makeChatMessageContent(contnet string) llm.ChatMessageContent {
if strings.HasPrefix(contnet, "data:image/") || ((strings.HasPrefix(contnet, "https://") || strings.HasPrefix(contnet, "http://")) && (strings.HasSuffix(contnet, ".png") || strings.HasSuffix(contnet, ".jpg") || strings.HasSuffix(contnet, ".jpeg") || strings.HasSuffix(contnet, ".gif") || strings.HasSuffix(contnet, ".svg"))) {
return llm.ChatMessageContent{Type: llm.TypeImage, Content: contnet}
} else if strings.HasPrefix(contnet, "data:video/") || ((strings.HasPrefix(contnet, "https://") || strings.HasPrefix(contnet, "http://")) && (strings.HasSuffix(contnet, ".mp4") || strings.HasSuffix(contnet, ".mov") || strings.HasSuffix(contnet, ".m4v") || strings.HasSuffix(contnet, ".avi") || strings.HasSuffix(contnet, ".wmv"))) {
return llm.ChatMessageContent{Type: llm.TypeVideo, Content: contnet}
}
return llm.ChatMessageContent{Type: llm.TypeText, Content: contnet}
}

81
js/console.go Normal file
View File

@ -0,0 +1,81 @@
package js
import (
"fmt"
"github.com/dop251/goja"
"github.com/ssgo/u"
"strings"
)
func requireConsole() map[string]any {
return map[string]any{
"print": func(args goja.FunctionCall) goja.Value {
consolePrint(args, "print", nil)
return nil
},
"println": func(args goja.FunctionCall) goja.Value {
consolePrint(args, "log", nil)
return nil
},
"log": func(args goja.FunctionCall) goja.Value {
consolePrint(args, "log", nil)
return nil
},
"info": func(args goja.FunctionCall) goja.Value {
consolePrint(args, "info", nil)
return nil
},
"warn": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
consolePrint(args, "warn", vm)
return nil
},
"error": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
consolePrint(args, "error", vm)
return nil
},
"debug": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
consolePrint(args, "debug", vm)
return nil
},
"input": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
consolePrint(args, "print", nil)
line := ""
_, _ = fmt.Scanln(&line)
return vm.ToValue(line)
},
}
}
func consolePrint(args goja.FunctionCall, typ string, vm *goja.Runtime) {
arr := make([]any, len(args.Arguments))
textColor := u.TextNone
switch typ {
case "info":
textColor = u.TextCyan
case "warn":
textColor = u.TextYellow
case "error":
textColor = u.TextRed
}
for i, arg := range args.Arguments {
if textColor != u.TextNone {
arr[i] = u.Color(u.StringP(arg.Export()), textColor, u.BgNone)
} else {
arr[i] = u.StringP(arg.Export())
}
}
if (typ == "warn" || typ == "error" || typ == "debug") && vm != nil {
callStacks := make([]string, 0)
for _, stack := range vm.CaptureCallStack(0, nil) {
callStacks = append(callStacks, u.Color(" "+stack.Position().String(), textColor, u.BgNone))
}
fmt.Println(arr...)
fmt.Println(strings.Join(callStacks, "\n"))
} else if typ == "print" {
fmt.Print(arr...)
} else {
fmt.Println(arr...)
}
}

48
js/file.go Normal file
View File

@ -0,0 +1,48 @@
package js
import (
"errors"
"github.com/dop251/goja"
"github.com/ssgo/u"
)
func requireFile() map[string]any {
return map[string]any{
"read": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
if len(args.Arguments) < 1 {
return vm.NewGoError(errors.New("arguments need 1 given " + u.String(len(args.Arguments))))
}
if r, err := u.ReadFile(u.String(args.Arguments[0].Export())); err == nil {
return vm.ToValue(r)
} else {
return vm.NewGoError(err)
}
},
"write": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
if len(args.Arguments) < 2 {
return vm.NewGoError(errors.New("arguments need 2 given " + u.String(len(args.Arguments))))
}
if err := u.WriteFileBytes(u.String(args.Arguments[0].Export()), u.Bytes(args.Arguments[0].Export())); err == nil {
return nil
} else {
return vm.NewGoError(err)
}
},
"dir": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
if len(args.Arguments) < 1 {
return vm.NewGoError(errors.New("arguments need 1 given " + u.String(len(args.Arguments))))
}
if r, err := u.ReadDir(u.String(args.Arguments[0].Export())); err == nil {
return vm.ToValue(r)
} else {
return vm.NewGoError(err)
}
},
"stat": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
if len(args.Arguments) < 1 {
return vm.NewGoError(errors.New("arguments need 1 given " + u.String(len(args.Arguments))))
}
return vm.ToValue(u.GetFileInfo(u.String(args.Arguments[0].Export())))
},
}
}

238
js/js.go Normal file
View File

@ -0,0 +1,238 @@
package js
import (
"apigo.cc/ai/ai"
"apigo.cc/ai/ai/llm"
"bytes"
_ "embed"
"encoding/json"
"errors"
"fmt"
"github.com/dop251/goja"
"github.com/dop251/goja_nodejs/require"
"github.com/ssgo/u"
"path/filepath"
"regexp"
"strings"
"text/template"
)
//go:embed lib/ai.ts
var aiTS string
//go:embed lib/console.ts
var consoleTS string
//go:embed lib/file.ts
var fileTS string
func RunFile(file string, args ...any) (any, error) {
return Run(u.ReadFileN(file), file, args...)
}
func Run(code string, refFile string, args ...any) (any, error) {
var r any
js, err := StartFromCode(code, refFile)
if err == nil {
r, err = js.Run(args...)
}
return r, err
}
var importModMatcher = regexp.MustCompile(`(?im)^\s*import\s+(.+?)\s+from\s+['"](.+?)['"]`)
var importLibMatcher = regexp.MustCompile(`(?im)^\s*(import)\s+(.+?)\s+from\s+['"][./\\\w:]+lib[/\\](.+?)(\.ts)?['"]`)
var requireLibMatcher = regexp.MustCompile(`(?im)^\s*(const|let|var)\s+(.+?)\s*=\s*require\s*\(\s*['"][./\\\w:]+lib[/\\](.+?)(\.ts)?['"]\s*\)`)
var checkMainMatcher = regexp.MustCompile(`(?im)^\s*function\s+main\s*\(`)
type JS struct {
vm *goja.Runtime
required map[string]bool
file string
srcCode string
code string
}
func (js *JS) requireMod(name string) error {
var err error
if name == "console" || name == "" {
if !js.required["console"] {
js.required["console"] = true
err = js.vm.Set("console", requireConsole())
}
}
if err == nil && (name == "file" || name == "") {
if !js.required["file"] {
js.required["file"] = true
err = js.vm.Set("file", requireFile())
}
}
if err == nil && (name == "ai" || name == "") {
if !js.required["ai"] {
js.required["ai"] = true
aiList := make(map[string]any)
for name, lm := range llm.List() {
aiList[name] = requireAI(lm)
}
err = js.vm.Set("ai", aiList)
}
}
return err
}
func (js *JS) makeImport(matcher *regexp.Regexp, code string) (string, int, error) {
var modErr error
importCount := 0
code = matcher.ReplaceAllStringFunc(code, func(str string) string {
if m := matcher.FindStringSubmatch(str); m != nil && len(m) > 3 {
optName := m[1]
if optName == "import" {
optName = "let"
}
varName := m[2]
modName := m[3]
importCount++
if modErr == nil {
if err := js.requireMod(modName); err != nil {
modErr = err
}
}
if varName != modName {
return fmt.Sprintf("%s %s = %s", optName, varName, modName)
}
}
return ""
})
return code, importCount, modErr
}
func StartFromFile(file string) (*JS, error) {
return StartFromCode(u.ReadFileN(file), file)
}
func StartFromCode(code, refFile string) (*JS, error) {
if refFile == "" {
refFile = "main.js"
}
if absFile, err := filepath.Abs(refFile); err == nil {
refFile = absFile
}
ai.InitFrom(filepath.Dir(refFile))
js := &JS{
vm: goja.New(),
required: map[string]bool{},
file: refFile,
srcCode: code,
code: code,
}
// 按需加载引用
var importCount int
var modErr error
js.code, importCount, modErr = js.makeImport(importLibMatcher, js.code)
if modErr == nil {
importCount1 := importCount
js.code, importCount, modErr = js.makeImport(requireLibMatcher, js.code)
importCount += importCount1
}
// 将 import 转换为 require
js.code = importModMatcher.ReplaceAllString(js.code, "let $1 = require('$2')")
// 如果没有import默认import所有
if modErr == nil && importCount == 0 {
modErr = js.requireMod("")
}
if modErr != nil {
return nil, modErr
}
//fmt.Println(u.BCyan(js.code))
// 处理模块引用
require.NewRegistryWithLoader(func(path string) ([]byte, error) {
refPath := filepath.Join(filepath.Dir(js.file), path)
if !strings.HasSuffix(refPath, ".js") && !u.FileExists(refPath) {
refPath += ".js"
}
modCode, err := u.ReadFile(refPath)
if err != nil {
return nil, err
}
modCode, _, _ = js.makeImport(importLibMatcher, modCode)
modCode, _, _ = js.makeImport(requireLibMatcher, modCode)
return []byte(modCode), modErr
}).Enable(js.vm)
// 初始化主函数
if !checkMainMatcher.MatchString(js.code) {
js.code = "function main(...args){" + js.code + "}"
}
if _, err := js.vm.RunScript("main", js.code); err != nil {
return nil, err
}
return js, nil
}
func (js *JS) Run(args ...any) (any, error) {
// 解析参数
for i, arg := range args {
if str, ok := arg.(string); ok {
var v interface{}
if err := json.Unmarshal([]byte(str), &v); err == nil {
args[i] = v
}
}
}
if err := js.vm.Set("__args", args); err != nil {
return nil, err
}
jsResult, err := js.vm.RunScript(js.file, "main(...__args)")
var result any
if err == nil {
if jsResult != nil && !jsResult.Equals(goja.Undefined()) {
result = jsResult.Export()
}
}
return result, err
}
type Exports struct {
LLMList []string
}
func ExportForDev() (string, error) {
ai.Init()
if len(llm.List()) == 0 && !u.FileExists("env.yml") && !u.FileExists("env.json") && !u.FileExists("llm.yml") && !u.FileExists("llm.json") {
return "", errors.New("no llm config found, please run `ai -e` on env.yml or llm.yml path")
}
exports := Exports{}
for name, _ := range llm.List() {
exports.LLMList = append(exports.LLMList, name)
}
exportFile := filepath.Join("lib", "ai.ts")
var tpl *template.Template
var err error
if tpl, err = template.New(exportFile).Parse(aiTS); err == nil {
buf := bytes.NewBuffer(make([]byte, 0))
if err = tpl.Execute(buf, exports); err == nil {
err = u.WriteFileBytes(exportFile, buf.Bytes())
}
}
if err != nil {
return "", err
}
_ = u.WriteFile(filepath.Join("lib", "console.ts"), consoleTS)
_ = u.WriteFile(filepath.Join("lib", "file.ts"), fileTS)
return `import {` + strings.Join(exports.LLMList, ", ") + `} from './lib/ai'
import console from './lib/console'
import file from './lib/file'`, nil
}

81
js/lib/ai.ts Normal file
View File

@ -0,0 +1,81 @@
// just for develop
{{range .LLMList}}
let {{.}}: LLM
{{end}}
export default {
{{range .LLMList}}
{{.}},
{{end}}
}
interface ChatModelConfig {
model: string
ratio: number
maxTokens: number
temperature: number
topP: number
tools: Object
}
interface ChatResult {
result: string
askTokens: number
answerTokens: number
totalTokens: number
error: string
}
interface GCResult {
result: string
preview: string
results: Array<string>
previews: Array<string>
error: string
}
interface Support {
ask: boolean
askWithImage: boolean
askWithVideo: boolean
askWithCodeInterpreter: boolean
askWithWebSearch: boolean
makeImage: boolean
makeVideo: boolean
models: Array<string>
}
interface LLM {
ask(messages: any, config?: ChatModelConfig, callback?: (answer: string) => void): ChatResult
fastAsk(messages: any, callback?: (answer: string) => void): ChatResult
longAsk(messages: any, callback?: (answer: string) => void): ChatResult
batterAsk(messages: any, callback?: (answer: string) => void): ChatResult
bestAsk(messages: any, callback?: (answer: string) => void): ChatResult
multiAsk(messages: any, callback?: (answer: string) => void): ChatResult
bestMultiAsk(messages: any, callback?: (answer: string) => void): ChatResult
codeInterpreterAsk(messages: any, callback?: (answer: string) => void): ChatResult
webSearchAsk(messages: any, callback?: (answer: string) => void): ChatResult
makeImage(model: string, prompt: string, size?: string, refImage?: string): GCResult
fastMakeImage(prompt: string, size?: string, refImage?: string): GCResult
bestMakeImage(prompt: string, size?: string, refImage?: string): GCResult
makeVideo(arg2: string, arg3: string, arg4: string, arg5: string): GCResult
fastMakeVideo(prompt: string, size?: string, refImage?: string): GCResult
bestMakeVideo(prompt: string, size?: string, refImage?: string): GCResult
support: Support
}

37
js/lib/console.ts Normal file
View File

@ -0,0 +1,37 @@
// just for develop
export default {
print,
println,
log,
debug,
info,
warn,
error,
input,
}
function print(...data: any[]): void {
}
function println(...data: any[]): void {
}
function log(...data: any[]): void {
}
function debug(...data: any[]): void {
}
function info(...data: any[]): void {
}
function warn(...data: any[]): void {
}
function error(...data: any[]): void {
}
function input(...data: any[]): string {
return ''
}

32
js/lib/file.ts Normal file
View File

@ -0,0 +1,32 @@
// just for develop
export default {
read,
write,
dir,
stat
}
function read(filename: string): string {
return ''
}
function write(filename: string, data: any): void {
}
function dir(filename: string): Array<FileInfo> {
return null
}
function stat(filename: string): FileInfo {
return null
}
interface FileInfo {
Name: string
FullName: string
IsDir: boolean
Size: number
ModTime: number
}

View File

@ -1,15 +1,68 @@
package ai
package llm
import "apigo.cc/ai/agent"
type ChatMessage struct {
Role string
Contents []ChatMessageContent
}
type ChatMessage = agent.ChatMessage
type ChatMessageContent = agent.ChatMessageContent
type ChatModelConfig struct {
Model string
MaxTokens int
Temperature float64
TopP float64
Tools map[string]any
type ChatMessageContent struct {
Type string // text, image, audio, video
Content string
}
type ChatConfig struct {
defaultConfig *ChatConfig
Model string
Ratio float64
MaxTokens int
Temperature float64
TopP float64
Tools map[string]any
}
func (chatConfig *ChatConfig) SetDefault(config *ChatConfig) {
chatConfig.defaultConfig = config
}
func (chatConfig *ChatConfig) GetModel() string {
if chatConfig.Model == "" && chatConfig.defaultConfig != nil {
return chatConfig.defaultConfig.Model
}
return chatConfig.Model
}
func (chatConfig *ChatConfig) GetMaxTokens() int {
if chatConfig.MaxTokens == 0 && chatConfig.defaultConfig != nil {
return chatConfig.defaultConfig.MaxTokens
}
return chatConfig.MaxTokens
}
func (chatConfig *ChatConfig) GetTemperature() float64 {
if chatConfig.Temperature == 0 && chatConfig.defaultConfig != nil {
return chatConfig.defaultConfig.Temperature
}
return chatConfig.Temperature
}
func (chatConfig *ChatConfig) GetTopP() float64 {
if chatConfig.TopP == 0 && chatConfig.defaultConfig != nil {
return chatConfig.defaultConfig.TopP
}
return chatConfig.TopP
}
func (chatConfig *ChatConfig) GetTools() map[string]any {
if chatConfig.Tools == nil && chatConfig.defaultConfig != nil {
return chatConfig.defaultConfig.Tools
}
return chatConfig.Tools
}
type TokenUsage struct {
AskTokens int64
AnswerTokens int64
TotalTokens int64
}
type MessagesMaker struct {

33
llm/gc.go Normal file
View File

@ -0,0 +1,33 @@
package llm
type GCConfig struct {
defaultConfig *GCConfig
Model string
Size string
Ref string
}
func (gcConfig *GCConfig) SetDefault(config *GCConfig) {
gcConfig.defaultConfig = config
}
func (gcConfig *GCConfig) GetModel() string {
if gcConfig.Model == "" && gcConfig.defaultConfig != nil {
return gcConfig.defaultConfig.Model
}
return gcConfig.Model
}
func (gcConfig *GCConfig) GetSize() string {
if gcConfig.Size == "" && gcConfig.defaultConfig != nil {
return gcConfig.defaultConfig.Size
}
return gcConfig.Size
}
func (gcConfig *GCConfig) GetRef() string {
if gcConfig.Ref == "" && gcConfig.defaultConfig != nil {
return gcConfig.defaultConfig.Ref
}
return gcConfig.Ref
}

96
llm/llm.go Normal file
View File

@ -0,0 +1,96 @@
package llm
import "sync"
const (
TypeText = "text"
TypeImage = "image"
TypeVideo = "video"
RoleSystem = "system"
RoleUser = "user"
RoleAssistant = "assistant"
RoleTool = "tool"
ToolCodeInterpreter = "codeInterpreter"
ToolWebSearch = "webSearch"
)
type Support struct {
Ask bool
AskWithImage bool
AskWithVideo bool
AskWithCodeInterpreter bool
AskWithWebSearch bool
MakeImage bool
MakeVideo bool
Models []string
}
type Config struct {
Endpoint string
ApiKey string
ChatConfig ChatConfig
GCConfig GCConfig
}
type LLM interface {
Support() Support
Ask(messages []ChatMessage, config ChatConfig, callback func(answer string)) (string, TokenUsage, error)
FastAsk(messages []ChatMessage, callback func(answer string)) (string, TokenUsage, error)
LongAsk(messages []ChatMessage, callback func(answer string)) (string, TokenUsage, error)
BatterAsk(messages []ChatMessage, callback func(answer string)) (string, TokenUsage, error)
BestAsk(messages []ChatMessage, callback func(answer string)) (string, TokenUsage, error)
MultiAsk(messages []ChatMessage, callback func(answer string)) (string, TokenUsage, error)
BestMultiAsk(messages []ChatMessage, callback func(answer string)) (string, TokenUsage, error)
CodeInterpreterAsk(messages []ChatMessage, callback func(answer string)) (string, TokenUsage, error)
WebSearchAsk(messages []ChatMessage, callback func(answer string)) (string, TokenUsage, error)
MakeImage(prompt string, config GCConfig) ([]string, error)
FastMakeImage(prompt string, config GCConfig) ([]string, error)
BestMakeImage(prompt string, config GCConfig) ([]string, error)
MakeVideo(prompt string, config GCConfig) ([]string, []string, error)
FastMakeVideo(prompt string, config GCConfig) ([]string, []string, error)
BestMakeVideo(prompt string, config GCConfig) ([]string, []string, error)
}
var llmMakers = map[string]func(Config) LLM{}
var llmMakersLock = sync.RWMutex{}
var llms = map[string]LLM{}
var llmsLock = sync.RWMutex{}
func Register(llmId string, maker func(Config) LLM) {
llmMakersLock.Lock()
llmMakers[llmId] = maker
llmMakersLock.Unlock()
}
func Create(name, llmId string, config Config) LLM {
llmMakersLock.RLock()
maker := llmMakers[llmId]
llmMakersLock.RUnlock()
if maker != nil {
llm := maker(config)
llmsLock.Lock()
llms[name] = llm
llmsLock.Unlock()
return llm
}
return nil
}
func Get(name string) LLM {
llmsLock.RLock()
llm := llms[name]
llmsLock.RUnlock()
return llm
}
func List() map[string]LLM {
list := map[string]LLM{}
llmsLock.RLock()
for name, llm := range llms {
list[name] = llm
}
llmsLock.RUnlock()
return list
}

165
llm/openai/chat.go Normal file
View File

@ -0,0 +1,165 @@
package openai
import (
"apigo.cc/ai/ai/llm"
"context"
"github.com/sashabaranov/go-openai"
"github.com/ssgo/log"
"strings"
)
func (lm *LLM) FastAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
return lm.Ask(messages, llm.ChatConfig{
Model: ModelGPT_4o_mini_2024_07_18,
}, callback)
}
func (lm *LLM) LongAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
return lm.Ask(messages, llm.ChatConfig{
Model: ModelGPT_4_32k_0613,
}, callback)
}
func (lm *LLM) BatterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
return lm.Ask(messages, llm.ChatConfig{
Model: ModelGPT_4_turbo,
}, callback)
}
func (lm *LLM) BestAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
return lm.Ask(messages, llm.ChatConfig{
Model: ModelGPT_4o_2024_08_06,
}, callback)
}
func (lm *LLM) MultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
return lm.Ask(messages, llm.ChatConfig{
Model: ModelGPT_4o_mini_2024_07_18,
}, callback)
}
func (lm *LLM) BestMultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
return lm.Ask(messages, llm.ChatConfig{
Model: ModelGPT_4o_2024_08_06,
}, callback)
}
func (lm *LLM) CodeInterpreterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
return lm.Ask(messages, llm.ChatConfig{
Model: ModelGPT_4o,
Tools: map[string]any{llm.ToolCodeInterpreter: nil},
}, callback)
}
func (lm *LLM) WebSearchAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
return lm.Ask(messages, llm.ChatConfig{
Model: ModelGPT_4o_mini_2024_07_18,
Tools: map[string]any{llm.ToolWebSearch: nil},
}, callback)
}
func (lm *LLM) Ask(messages []llm.ChatMessage, config llm.ChatConfig, callback func(answer string)) (string, llm.TokenUsage, error) {
openaiConf := openai.DefaultConfig(lm.config.ApiKey)
if lm.config.Endpoint != "" {
openaiConf.BaseURL = lm.config.Endpoint
}
config.SetDefault(&lm.config.ChatConfig)
agentMessages := make([]openai.ChatCompletionMessage, len(messages))
for i, msg := range messages {
var contents []openai.ChatMessagePart
if msg.Contents != nil {
contents = make([]openai.ChatMessagePart, len(msg.Contents))
for j, inPart := range msg.Contents {
part := openai.ChatMessagePart{}
part.Type = TypeMap[inPart.Type]
switch inPart.Type {
case llm.TypeText:
part.Text = inPart.Content
case llm.TypeImage:
part.ImageURL = &openai.ChatMessageImageURL{
URL: inPart.Content,
Detail: openai.ImageURLDetailAuto,
}
}
contents[j] = part
}
}
agentMessages[i] = openai.ChatCompletionMessage{
Role: RoleMap[msg.Role],
MultiContent: contents,
}
}
opt := openai.ChatCompletionRequest{
Model: config.GetModel(),
Messages: agentMessages,
MaxTokens: config.GetMaxTokens(),
Temperature: float32(config.GetTemperature()),
TopP: float32(config.GetTopP()),
StreamOptions: &openai.StreamOptions{
IncludeUsage: true,
},
}
for name := range config.GetTools() {
switch name {
case llm.ToolCodeInterpreter:
opt.Tools = append(opt.Tools, openai.Tool{Type: "code_interpreter"})
case llm.ToolWebSearch:
}
}
c := openai.NewClientWithConfig(openaiConf)
if callback != nil {
opt.Stream = true
r, err := c.CreateChatCompletionStream(context.Background(), opt)
if err == nil {
results := make([]string, 0)
usage := llm.TokenUsage{}
for {
if r2, err := r.Recv(); err == nil {
if r2.Choices != nil {
for _, ch := range r2.Choices {
text := ch.Delta.Content
callback(text)
results = append(results, text)
}
}
if r2.Usage != nil {
usage.AskTokens += int64(r2.Usage.PromptTokens)
usage.AnswerTokens += int64(r2.Usage.CompletionTokens)
usage.TotalTokens += int64(r2.Usage.TotalTokens)
}
} else {
break
}
}
_ = r.Close()
return strings.Join(results, ""), usage, nil
} else {
log.DefaultLogger.Error(err.Error())
return "", llm.TokenUsage{}, err
}
} else {
r, err := c.CreateChatCompletion(context.Background(), opt)
if err == nil {
results := make([]string, 0)
if r.Choices != nil {
for _, ch := range r.Choices {
results = append(results, ch.Message.Content)
}
}
return strings.Join(results, ""), llm.TokenUsage{
AskTokens: int64(r.Usage.PromptTokens),
AnswerTokens: int64(r.Usage.CompletionTokens),
TotalTokens: int64(r.Usage.TotalTokens),
}, nil
} else {
//fmt.Println(u.BMagenta(err.Error()), u.BMagenta(u.JsonP(r)))
return "", llm.TokenUsage{}, err
}
}
}

81
llm/openai/config.go Normal file
View File

@ -0,0 +1,81 @@
package openai
import (
"apigo.cc/ai/ai/llm"
"github.com/sashabaranov/go-openai"
)
type LLM struct {
config llm.Config
}
var TypeMap = map[string]openai.ChatMessagePartType{
llm.TypeText: openai.ChatMessagePartTypeText,
llm.TypeImage: openai.ChatMessagePartTypeImageURL,
//llm.TypeVideo: "video_url",
}
var RoleMap = map[string]string{
llm.RoleSystem: openai.ChatMessageRoleSystem,
llm.RoleUser: openai.ChatMessageRoleUser,
llm.RoleAssistant: openai.ChatMessageRoleAssistant,
llm.RoleTool: openai.ChatMessageRoleTool,
}
const (
ModelGPT_4_32k_0613 = "gpt-4-32k-0613"
ModelGPT_4_32k_0314 = "gpt-4-32k-0314"
ModelGPT_4_32k = "gpt-4-32k"
ModelGPT_4_0613 = "gpt-4-0613"
ModelGPT_4_0314 = "gpt-4-0314"
ModelGPT_4o = "gpt-4o"
ModelGPT_4o_2024_05_13 = "gpt-4o-2024-05-13"
ModelGPT_4o_2024_08_06 = "gpt-4o-2024-08-06"
ModelGPT_4o_mini = "gpt-4o-mini"
ModelGPT_4o_mini_2024_07_18 = "gpt-4o-mini-2024-07-18"
ModelGPT_4_turbo = "gpt-4-turbo"
ModelGPT_4_turbo_2024_04_09 = "gpt-4-turbo-2024-04-09"
ModelGPT_4_0125_preview = "gpt-4-0125-preview"
ModelGPT_4_1106_preview = "gpt-4-1106-preview"
ModelGPT_4_turbo_preview = "gpt-4-turbo-preview"
ModelGPT_4_vision_preview = "gpt-4-vision-preview"
ModelGPT_4 = "gpt-4"
ModelGPT_3_5_turbo_0125 = "gpt-3.5-turbo-0125"
ModelGPT_3_5_turbo_1106 = "gpt-3.5-turbo-1106"
ModelGPT_3_5_turbo_0613 = "gpt-3.5-turbo-0613"
ModelGPT_3_5_turbo_0301 = "gpt-3.5-turbo-0301"
ModelGPT_3_5_turbo_16k = "gpt-3.5-turbo-16k"
ModelGPT_3_5_turbo_16k_0613 = "gpt-3.5-turbo-16k-0613"
ModelGPT_3_5_turbo = "gpt-3.5-turbo"
ModelGPT_3_5_turbo_instruct = "gpt-3.5-turbo-instruct"
ModelDavinci_002 = "davinci-002"
ModelCurie = "curie"
ModelCurie_002 = "curie-002"
ModelAda_002 = "ada-002"
ModelBabbage_002 = "babbage-002"
ModelCode_davinci_002 = "code-davinci-002"
ModelCode_cushman_001 = "code-cushman-001"
ModelCode_davinci_001 = "code-davinci-001"
ModelDallE2Std = "dall-e-2"
ModelDallE2HD = "dall-e-2-hd"
ModelDallE3Std = "dall-e-3"
ModelDallE3HD = "dall-e-3-hd"
)
func (ag *LLM) Support() llm.Support {
return llm.Support{
Ask: true,
AskWithImage: true,
AskWithVideo: false,
AskWithCodeInterpreter: true,
AskWithWebSearch: false,
MakeImage: true,
MakeVideo: false,
Models: []string{ModelGPT_4_32k_0613, ModelGPT_4_32k_0314, ModelGPT_4_32k, ModelGPT_4_0613, ModelGPT_4_0314, ModelGPT_4o, ModelGPT_4o_2024_05_13, ModelGPT_4o_2024_08_06, ModelGPT_4o_mini, ModelGPT_4o_mini_2024_07_18, ModelGPT_4_turbo, ModelGPT_4_turbo_2024_04_09, ModelGPT_4_0125_preview, ModelGPT_4_1106_preview, ModelGPT_4_turbo_preview, ModelGPT_4_vision_preview, ModelGPT_4, ModelGPT_3_5_turbo_0125, ModelGPT_3_5_turbo_1106, ModelGPT_3_5_turbo_0613, ModelGPT_3_5_turbo_0301, ModelGPT_3_5_turbo_16k, ModelGPT_3_5_turbo_16k_0613, ModelGPT_3_5_turbo, ModelGPT_3_5_turbo_instruct, ModelDavinci_002, ModelCurie, ModelCurie_002, ModelAda_002, ModelBabbage_002, ModelCode_davinci_002, ModelCode_cushman_001, ModelCode_davinci_001, ModelDallE2Std, ModelDallE2HD, ModelDallE3Std, ModelDallE3HD},
}
}
func init() {
llm.Register("openai", func(config llm.Config) llm.LLM {
return &LLM{config: config}
})
}

73
llm/openai/gc.go Normal file
View File

@ -0,0 +1,73 @@
package openai
import (
"apigo.cc/ai/ai/llm"
"context"
"github.com/sashabaranov/go-openai"
"strings"
)
// func (lm *LLM) FastMakeImage(prompt, size, refImage string) ([]string, error) {
// return lm.MakeImage(ModelDallE3Std, prompt, size, refImage)
// }
//
// func (lm *LLM) BestMakeImage(prompt, size, refImage string) ([]string, error) {
// return lm.MakeImage(ModelDallE3HD, prompt, size, refImage)
// }
//
// func (lm *LLM) MakeImage(model, prompt, size, refImage string) ([]string, error) {
func (lm *LLM) FastMakeImage(prompt string, config llm.GCConfig) ([]string, error) {
config.Model = ModelDallE3Std
return lm.MakeImage(prompt, config)
}
func (lm *LLM) BestMakeImage(prompt string, config llm.GCConfig) ([]string, error) {
config.Model = ModelDallE3HD
return lm.MakeImage(prompt, config)
}
func (lm *LLM) MakeImage(prompt string, config llm.GCConfig) ([]string, error) {
openaiConf := openai.DefaultConfig(lm.config.ApiKey)
if lm.config.Endpoint != "" {
openaiConf.BaseURL = lm.config.Endpoint
}
c := openai.NewClientWithConfig(openaiConf)
style := openai.CreateImageStyleVivid
if (!strings.Contains(prompt, "vivid") || !strings.Contains(prompt, "生动的")) && (strings.Contains(prompt, "natural") || strings.Contains(prompt, "自然的")) {
style = openai.CreateImageStyleNatural
}
quality := openai.CreateImageQualityStandard
if strings.HasSuffix(config.Model, "-hd") {
quality = openai.CreateImageQualityHD
config.Model = config.Model[0 : len(config.Model)-3]
}
r, err := c.CreateImage(context.Background(), openai.ImageRequest{
Prompt: prompt,
Model: config.Model,
Quality: quality,
Size: config.Size,
Style: style,
ResponseFormat: openai.CreateImageResponseFormatURL,
})
if err == nil {
results := make([]string, 0)
for _, item := range r.Data {
results = append(results, item.URL)
}
return results, nil
} else {
return nil, err
}
}
func (lm *LLM) FastMakeVideo(prompt string, config llm.GCConfig) ([]string, []string, error) {
return lm.MakeVideo(prompt, config)
}
func (lm *LLM) BestMakeVideo(prompt string, config llm.GCConfig) ([]string, []string, error) {
return lm.MakeVideo(prompt, config)
}
func (lm *LLM) MakeVideo(prompt string, config llm.GCConfig) ([]string, []string, error) {
return nil, nil, nil
}

137
llm/zhipu/chat.go Normal file
View File

@ -0,0 +1,137 @@
package zhipu
import (
"apigo.cc/ai/ai/llm"
"apigo.cc/ai/ai/llm/zhipu/zhipu"
"context"
"strings"
)
func (lm *LLM) FastAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
return lm.Ask(messages, llm.ChatConfig{
Model: ModelGLM4Flash,
}, callback)
}
func (lm *LLM) LongAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
return lm.Ask(messages, llm.ChatConfig{
Model: ModelGLM4Long,
}, callback)
}
func (lm *LLM) BatterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
return lm.Ask(messages, llm.ChatConfig{
Model: ModelGLM4Plus,
}, callback)
}
func (lm *LLM) BestAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
return lm.Ask(messages, llm.ChatConfig{
Model: ModelGLM40520,
}, callback)
}
func (lm *LLM) MultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
return lm.Ask(messages, llm.ChatConfig{
Model: ModelGLM4VPlus,
}, callback)
}
func (lm *LLM) BestMultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
return lm.Ask(messages, llm.ChatConfig{
Model: ModelGLM4V,
}, callback)
}
func (lm *LLM) CodeInterpreterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
return lm.Ask(messages, llm.ChatConfig{
Model: ModelGLM4AllTools,
Tools: map[string]any{llm.ToolCodeInterpreter: nil},
}, callback)
}
func (lm *LLM) WebSearchAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
return lm.Ask(messages, llm.ChatConfig{
Model: ModelGLM4AllTools,
Tools: map[string]any{llm.ToolWebSearch: nil},
}, callback)
}
func (lm *LLM) Ask(messages []llm.ChatMessage, config llm.ChatConfig, callback func(answer string)) (string, llm.TokenUsage, error) {
config.SetDefault(&lm.config.ChatConfig)
c, err := zhipu.NewClient(zhipu.WithAPIKey(lm.config.ApiKey), zhipu.WithBaseURL(lm.config.Endpoint))
if err != nil {
return "", llm.TokenUsage{}, err
}
cc := c.ChatCompletion(config.GetModel())
for _, msg := range messages {
var contents []zhipu.ChatCompletionMultiContent
if msg.Contents != nil {
contents = make([]zhipu.ChatCompletionMultiContent, len(msg.Contents))
for j, inPart := range msg.Contents {
part := zhipu.ChatCompletionMultiContent{}
part.Type = NameMap[inPart.Type]
switch inPart.Type {
case llm.TypeText:
part.Text = inPart.Content
case llm.TypeImage:
part.ImageURL = &zhipu.URLItem{URL: inPart.Content}
case llm.TypeVideo:
part.VideoURL = &zhipu.URLItem{URL: inPart.Content}
}
contents[j] = part
}
}
cc.AddMessage(zhipu.ChatCompletionMultiMessage{
Role: NameMap[msg.Role],
Content: contents,
})
}
for name := range config.GetTools() {
switch name {
case llm.ToolCodeInterpreter:
cc.AddTool(zhipu.ChatCompletionToolCodeInterpreter{})
case llm.ToolWebSearch:
cc.AddTool(zhipu.ChatCompletionToolWebBrowser{})
}
}
if config.GetMaxTokens() != 0 {
cc.SetMaxTokens(config.GetMaxTokens())
}
if config.GetTemperature() != 0 {
cc.SetTemperature(config.GetTemperature())
}
if config.GetTopP() != 0 {
cc.SetTopP(config.GetTopP())
}
if callback != nil {
cc.SetStreamHandler(func(r2 zhipu.ChatCompletionResponse) error {
if r2.Choices != nil {
for _, ch := range r2.Choices {
text := ch.Delta.Content
callback(text)
}
}
return nil
})
}
if r, err := cc.Do(context.Background()); err == nil {
results := make([]string, 0)
if r.Choices != nil {
for _, ch := range r.Choices {
results = append(results, ch.Message.Content)
}
}
return strings.Join(results, ""), llm.TokenUsage{
AskTokens: r.Usage.PromptTokens,
AnswerTokens: r.Usage.CompletionTokens,
TotalTokens: r.Usage.TotalTokens,
}, nil
} else {
return "", llm.TokenUsage{}, err
}
}

60
llm/zhipu/config.go Normal file
View File

@ -0,0 +1,60 @@
package zhipu
import (
"apigo.cc/ai/ai/llm"
"apigo.cc/ai/ai/llm/zhipu/zhipu"
)
type LLM struct {
config llm.Config
}
var NameMap = map[string]string{
llm.TypeText: zhipu.MultiContentTypeText,
llm.TypeImage: zhipu.MultiContentTypeImageURL,
llm.TypeVideo: zhipu.MultiContentTypeVideoURL,
llm.RoleSystem: zhipu.RoleSystem,
llm.RoleUser: zhipu.RoleUser,
llm.RoleAssistant: zhipu.RoleAssistant,
llm.RoleTool: zhipu.RoleTool,
}
const (
ModelGLM4Plus = "GLM-4-Plus"
ModelGLM40520 = "GLM-4-0520"
ModelGLM4Long = "GLM-4-Long"
ModelGLM4AirX = "GLM-4-AirX"
ModelGLM4Air = "GLM-4-Air"
ModelGLM4Flash = "GLM-4-Flash"
ModelGLM4AllTools = "GLM-4-AllTools"
ModelGLM4 = "GLM-4"
ModelGLM4VPlus = "GLM-4V-Plus"
ModelGLM4V = "GLM-4V"
ModelCogVideoX = "CogVideoX"
ModelCogView3Plus = "CogView-3-Plus"
ModelCogView3 = "CogView-3"
ModelEmbedding3 = "Embedding-3"
ModelEmbedding2 = "Embedding-2"
ModelCharGLM3 = "CharGLM-3"
ModelEmohaa = "Emohaa"
ModelCodeGeeX4 = "CodeGeeX-4"
)
func (lm *LLM) Support() llm.Support {
return llm.Support{
Ask: true,
AskWithImage: true,
AskWithVideo: true,
AskWithCodeInterpreter: true,
AskWithWebSearch: true,
MakeImage: true,
MakeVideo: true,
Models: []string{ModelGLM4Plus, ModelGLM40520, ModelGLM4Long, ModelGLM4AirX, ModelGLM4Air, ModelGLM4Flash, ModelGLM4AllTools, ModelGLM4, ModelGLM4VPlus, ModelGLM4V, ModelCogVideoX, ModelCogView3Plus, ModelCogView3, ModelEmbedding3, ModelEmbedding2, ModelCharGLM3, ModelEmohaa, ModelCodeGeeX4},
}
}
func init() {
llm.Register("zhipu", func(config llm.Config) llm.LLM {
return &LLM{config: config}
})
}

88
llm/zhipu/gc.go Normal file
View File

@ -0,0 +1,88 @@
package zhipu
import (
"apigo.cc/ai/ai/llm"
"apigo.cc/ai/ai/llm/zhipu/zhipu"
"context"
"errors"
"time"
)
func (lm *LLM) FastMakeImage(prompt string, config llm.GCConfig) ([]string, error) {
config.Model = ModelCogView3Plus
return lm.MakeImage(prompt, config)
}
func (lm *LLM) BestMakeImage(prompt string, config llm.GCConfig) ([]string, error) {
config.Model = ModelCogView3
return lm.MakeImage(prompt, config)
}
func (lm *LLM) MakeImage(prompt string, config llm.GCConfig) ([]string, error) {
c, err := zhipu.NewClient(zhipu.WithAPIKey(lm.config.ApiKey), zhipu.WithBaseURL(lm.config.Endpoint))
if err != nil {
return nil, err
}
cc := c.ImageGeneration(config.Model).SetPrompt(prompt)
if config.Size != "" {
cc.SetSize(config.Size)
}
if r, err := cc.Do(context.Background()); err == nil {
results := make([]string, 0)
for _, item := range r.Data {
results = append(results, item.URL)
}
return results, nil
} else {
return nil, err
}
}
func (lm *LLM) FastMakeVideo(prompt string, config llm.GCConfig) ([]string, []string, error) {
config.Model = ModelCogVideoX
return lm.MakeVideo(prompt, config)
}
func (lm *LLM) BestMakeVideo(prompt string, config llm.GCConfig) ([]string, []string, error) {
config.Model = ModelCogVideoX
return lm.MakeVideo(prompt, config)
}
func (lm *LLM) MakeVideo(prompt string, config llm.GCConfig) ([]string, []string, error) {
c, err := zhipu.NewClient(zhipu.WithAPIKey(lm.config.ApiKey), zhipu.WithBaseURL(lm.config.Endpoint))
if err != nil {
return nil, nil, err
}
cc := c.VideoGeneration(config.Model).SetPrompt(prompt)
if config.Ref != "" {
cc.SetImageURL(config.Ref)
}
if resp, err := cc.Do(context.Background()); err == nil {
for i := 0; i < 1200; i++ {
r, err := c.AsyncResult(resp.ID).Do(context.Background())
if err != nil {
return nil, nil, err
}
if r.TaskStatus == zhipu.VideoGenerationTaskStatusSuccess {
covers := make([]string, 0)
results := make([]string, 0)
for _, item := range r.VideoResult {
results = append(results, item.URL)
covers = append(covers, item.CoverImageURL)
}
return results, covers, nil
}
if r.TaskStatus == zhipu.VideoGenerationTaskStatusFail {
return nil, nil, errors.New("fail on task " + resp.ID)
}
time.Sleep(3 * time.Second)
}
return nil, nil, errors.New("timeout on task " + resp.ID)
} else {
return nil, nil, err
}
}

View File

@ -0,0 +1,108 @@
# Changelog
All notable changes to this project will be documented in this file. See [conventional commits](https://www.conventionalcommits.org/) for commit guidelines.
- - -
## v0.1.2 - 2024-08-15
#### Bug Fixes
- add FinishReasonStopSequence - (01b4201) - GUO YANKE
#### Documentation
- update README.md [skip ci] - (e48a88b) - GUO YANKE
#### Features
- add videos/generations - (7261999) - GUO YANKE
#### Miscellaneous Chores
- relaxing go version to 1.18 - (6acc17c) - GUO YANKE
- - -
## v0.1.1 - 2024-07-17
#### Documentation
- update README.md [skip ci] - (695432a) - GUO YANKE
#### Features
- add support for GLM-4-AllTools - (9627a36) - GUO YANKE
- - -
## v0.1.0 - 2024-06-28
#### Bug Fixes
- rename client function for batch list - (40ac05f) - GUO YANKE
#### Documentation
- update README.md [skip ci] - (6ce5754) - GUO YANKE
#### Features
- add knowledge capacity service - (4ce62b3) - GUO YANKE
#### Refactoring
- update batch service - (b92d438) - GUO YANKE
- update chat completion service - (19dd77f) - GUO YANKE
- update embedding service - (c1bbc2d) - GUO YANKE
- update file services - (7ef4d87) - GUO YANKE
- update fine tune services, using APIError - (15aed88) - GUO YANKE
- update fine tune services - (664523b) - GUO YANKE
- update image generation service - (a18e028) - GUO YANKE
- update knowledge services - (c7bfb73) - GUO YANKE
- - -
## v0.0.6 - 2024-06-28
#### Features
- add batch support for result reader - (c062095) - GUO YANKE
- add fine tune services - (f172f51) - GUO YANKE
- add knowledge service - (09792b5) - GUO YANKE
- - -
## v0.0.5 - 2024-06-28
#### Bug Fixes
- api error parsing - (60a17f4) - GUO YANKE
#### Features
- add batch service - (389aec3) - GUO YANKE
- add batch support for chat completions, image generations and embeddings - (c017ffd) - GUO YANKE
- add file edit/get/delete service - (8a4d309) - GUO YANKE
- add file create serivce - (6d2140b) - GUO YANKE
- - -
## v0.0.4 - 2024-06-26
#### Bug Fixes
- remove Client.R(), hide resty for future removal - (dc2a4ca) - GUO YANKE
#### Features
- add meta support for charglm - (fdd20e7) - GUO YANKE
- add client option to custom http client - (c62d6a9) - GUO YANKE
- - -
## v0.0.3 - 2024-06-26
#### Features
- add image generation service - (9f3f54f) - GUO YANKE
- add support for vision models - (2dcd82a) - GUO YANKE
- add embedding service - (f57806a) - GUO YANKE
- - -
## v0.0.2 - 2024-06-26
#### Bug Fixes
- **(deps)** update golang-jwt/jwt to v5 - (2f76a57) - GUO YANKE
#### Features
- add constants for roles - (3d08a72) - GUO YANKE
- - -
## v0.0.1 - 2024-06-26
#### Bug Fixes
- add json tag "omitempty" to various types - (bf81097) - GUO YANKE
#### Continuous Integration
- add github action workflows for testing - (5a64987) - GUO YANKE
#### Documentation
- update README.md [skip ci] - (d504f57) - GUO YANKE
#### Features
- add chat completion in stream mode - (130fe1d) - GUO YANKE
- add chat completion in non-stream mode - (2326e37) - GUO YANKE
- support debug option while creating client - (0f104d8) - GUO YANKE
- add APIError and APIErrorResponse - (1886d85) - GUO YANKE
- add client struct - (710d8e8) - GUO YANKE
#### Refactoring
- change signature of Client#createJWT since there is no reason to fail - (f0d7887) - GUO YANKE
#### Tests
- add client_test.go - (a3fc217) - GUO YANKE
- - -
Changelog generated by [cocogitto](https://github.com/cocogitto/cocogitto).

21
llm/zhipu/zhipu/LICENSE Normal file
View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 Yanke G.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

280
llm/zhipu/zhipu/README.md Normal file
View File

@ -0,0 +1,280 @@
# zhipu
[![Go Reference](https://pkg.go.dev/badge/github.com/yankeguo/zhipu.svg)](https://pkg.go.dev/github.com/yankeguo/zhipu)
[![go](https://github.com/yankeguo/zhipu/actions/workflows/go.yml/badge.svg)](https://github.com/yankeguo/zhipu/actions/workflows/go.yml)
[中文文档](README.zh.md)
A 3rd-Party Golang Client Library for Zhipu AI Platform
## Usage
### Install the package
```bash
go get -u github.com/yankeguo/zhipu
```
### Create a client
```go
// this will use environment variables ZHIPUAI_API_KEY
client, err := zhipu.NewClient()
// or you can specify the API key
client, err = zhipu.NewClient(zhipu.WithAPIKey("your api key"))
```
### Use the client
**ChatCompletion**
```go
service := client.ChatCompletion("glm-4-flash").
AddMessage(zhipu.ChatCompletionMessage{
Role: "user",
Content: "你好",
})
res, err := service.Do(context.Background())
if err != nil {
zhipu.GetAPIErrorCode(err) // get the API error code
} else {
println(res.Choices[0].Message.Content)
}
```
**ChatCompletion (Stream)**
```go
service := client.ChatCompletion("glm-4-flash").
AddMessage(zhipu.ChatCompletionMessage{
Role: "user",
Content: "你好",
}).SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
println(chunk.Choices[0].Delta.Content)
return nil
})
res, err := service.Do(context.Background())
if err != nil {
zhipu.GetAPIErrorCode(err) // get the API error code
} else {
// this package will combine the stream chunks and build a final result mimicking the non-streaming API
println(res.Choices[0].Message.Content)
}
```
**ChatCompletion (Stream with GLM-4-AllTools)**
```go
// CodeInterpreter
s := client.ChatCompletion("GLM-4-AllTools")
s.AddMessage(zhipu.ChatCompletionMultiMessage{
Role: "user",
Content: []zhipu.ChatCompletionMultiContent{
{
Type: "text",
Text: "计算[5,10,20,700,99,310,978,100]的平均值和方差。",
},
},
})
s.AddTool(zhipu.ChatCompletionToolCodeInterpreter{
Sandbox: zhipu.Ptr(CodeInterpreterSandboxAuto),
})
s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
for _, c := range chunk.Choices {
for _, tc := range c.Delta.ToolCalls {
if tc.Type == ToolTypeCodeInterpreter && tc.CodeInterpreter != nil {
if tc.CodeInterpreter.Input != "" {
// DO SOMETHING
}
if len(tc.CodeInterpreter.Outputs) > 0 {
// DO SOMETHING
}
}
}
}
return nil
})
// WebBrowser
// CAUTION: NOT 'WebSearch'
s := client.ChatCompletion("GLM-4-AllTools")
s.AddMessage(zhipu.ChatCompletionMultiMessage{
Role: "user",
Content: []zhipu.ChatCompletionMultiContent{
{
Type: "text",
Text: "搜索下本周深圳天气如何",
},
},
})
s.AddTool(zhipu.ChatCompletionToolWebBrowser{})
s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
for _, c := range chunk.Choices {
for _, tc := range c.Delta.ToolCalls {
if tc.Type == ToolTypeWebBrowser && tc.WebBrowser != nil {
if tc.WebBrowser.Input != "" {
// DO SOMETHING
}
if len(tc.WebBrowser.Outputs) > 0 {
// DO SOMETHING
}
}
}
}
return nil
})
s.Do(context.Background())
// DrawingTool
s := client.ChatCompletion("GLM-4-AllTools")
s.AddMessage(zhipu.ChatCompletionMultiMessage{
Role: "user",
Content: []zhipu.ChatCompletionMultiContent{
{
Type: "text",
Text: "画一个正弦函数图像",
},
},
})
s.AddTool(zhipu.ChatCompletionToolDrawingTool{})
s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
for _, c := range chunk.Choices {
for _, tc := range c.Delta.ToolCalls {
if tc.Type == ToolTypeDrawingTool && tc.DrawingTool != nil {
if tc.DrawingTool.Input != "" {
// DO SOMETHING
}
if len(tc.DrawingTool.Outputs) > 0 {
// DO SOMETHING
}
}
}
}
return nil
})
s.Do(context.Background())
```
**Embedding**
```go
service := client.Embedding("embedding-v2").SetInput("你好呀")
service.Do(context.Background())
```
**Image Generation**
```go
service := client.ImageGeneration("cogview-3").SetPrompt("一只可爱的小猫咪")
service.Do(context.Background())
```
**Video Generation**
```go
service := client.VideoGeneration("cogvideox").SetPrompt("一只可爱的小猫咪")
resp, err := service.Do(context.Background())
for {
result, err := client.AsyncResult(resp.ID).Do(context.Background())
if result.TaskStatus == zhipu.VideoGenerationTaskStatusSuccess {
_ = result.VideoResult[0].URL
_ = result.VideoResult[0].CoverImageURL
break
}
if result.TaskStatus != zhipu.VideoGenerationTaskStatusProcessing {
break
}
time.Sleep(5 * time.Second)
}
```
**Upload File (Retrieval)**
```go
service := client.FileCreate(zhipu.FilePurposeRetrieval)
service.SetLocalFile(filepath.Join("testdata", "test-file.txt"))
service.SetKnowledgeID("your-knowledge-id")
service.Do(context.Background())
```
**Upload File (Fine-Tune)**
```go
service := client.FileCreate(zhipu.FilePurposeFineTune)
service.SetLocalFile(filepath.Join("testdata", "test-file.jsonl"))
service.Do(context.Background())
```
**Batch Create**
```go
service := client.BatchCreate().
SetInputFileID("fileid").
SetCompletionWindow(zhipu.BatchCompletionWindow24h).
SetEndpoint(BatchEndpointV4ChatCompletions)
service.Do(context.Background())
```
**Knowledge Base**
```go
client.KnowledgeCreate("")
client.KnowledgeEdit("")
```
**Fine Tune**
```go
client.FineTuneCreate("")
```
### Batch Support
**Batch File Writer**
```go
f, err := os.OpenFile("batch.jsonl", os.O_CREATE|os.O_WRONLY, 0644)
bw := zhipu.NewBatchFileWriter(f)
bw.Add("action_1", client.ChatCompletion("glm-4-flash").
AddMessage(zhipu.ChatCompletionMessage{
Role: "user",
Content: "你好",
}))
bw.Add("action_2", client.Embedding("embedding-v2").SetInput("你好呀"))
bw.Add("action_3", client.ImageGeneration("cogview-3").SetPrompt("一只可爱的小猫咪"))
```
**Batch Result Reader**
```go
br := zhipu.NewBatchResultReader[zhipu.ChatCompletionResponse](r)
for {
var res zhipu.BatchResult[zhipu.ChatCompletionResponse]
err := br.Read(&res)
if err != nil {
break
}
}
```
## Donation
Executing unit tests will actually call the ChatGLM API and consume my quota. Please donate and thank you for your support!
<img src="./wechat-donation.png" width="180"/>
## Credits
GUO YANKE, MIT License

View File

@ -0,0 +1,278 @@
# zhipu
[![Go Reference](https://pkg.go.dev/badge/github.com/yankeguo/zhipu.svg)](https://pkg.go.dev/github.com/yankeguo/zhipu)
[![go](https://github.com/yankeguo/zhipu/actions/workflows/go.yml/badge.svg)](https://github.com/yankeguo/zhipu/actions/workflows/go.yml)
Zhipu AI 平台第三方 Golang 客户端库
## 用法
### 安装库
```bash
go get -u github.com/yankeguo/zhipu
```
### 创建客户端
```go
// 默认使用环境变量 ZHIPUAI_API_KEY
client, err := zhipu.NewClient()
// 或者手动指定密钥
client, err = zhipu.NewClient(zhipu.WithAPIKey("your api key"))
```
### 使用客户端
**ChatCompletion(大语言模型)**
```go
service := client.ChatCompletion("glm-4-flash").
AddMessage(zhipu.ChatCompletionMessage{
Role: "user",
Content: "你好",
})
res, err := service.Do(context.Background())
if err != nil {
zhipu.GetAPIErrorCode(err) // get the API error code
} else {
println(res.Choices[0].Message.Content)
}
```
**ChatCompletion(流式调用大语言模型)**
```go
service := client.ChatCompletion("glm-4-flash").
AddMessage(zhipu.ChatCompletionMessage{
Role: "user",
Content: "你好",
}).SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
println(chunk.Choices[0].Delta.Content)
return nil
})
res, err := service.Do(context.Background())
if err != nil {
zhipu.GetAPIErrorCode(err) // get the API error code
} else {
// this package will combine the stream chunks and build a final result mimicking the non-streaming API
println(res.Choices[0].Message.Content)
}
```
**ChatCompletion(流式调用大语言工具模型GLM-4-AllTools)**
```go
// CodeInterpreter
s := client.ChatCompletion("GLM-4-AllTools")
s.AddMessage(zhipu.ChatCompletionMultiMessage{
Role: "user",
Content: []zhipu.ChatCompletionMultiContent{
{
Type: "text",
Text: "计算[5,10,20,700,99,310,978,100]的平均值和方差。",
},
},
})
s.AddTool(zhipu.ChatCompletionToolCodeInterpreter{
Sandbox: zhipu.Ptr(CodeInterpreterSandboxAuto),
})
s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
for _, c := range chunk.Choices {
for _, tc := range c.Delta.ToolCalls {
if tc.Type == ToolTypeCodeInterpreter && tc.CodeInterpreter != nil {
if tc.CodeInterpreter.Input != "" {
// DO SOMETHING
}
if len(tc.CodeInterpreter.Outputs) > 0 {
// DO SOMETHING
}
}
}
}
return nil
})
// WebBrowser
// CAUTION: NOT 'WebSearch'
s := client.ChatCompletion("GLM-4-AllTools")
s.AddMessage(zhipu.ChatCompletionMultiMessage{
Role: "user",
Content: []zhipu.ChatCompletionMultiContent{
{
Type: "text",
Text: "搜索下本周深圳天气如何",
},
},
})
s.AddTool(zhipu.ChatCompletionToolWebBrowser{})
s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
for _, c := range chunk.Choices {
for _, tc := range c.Delta.ToolCalls {
if tc.Type == ToolTypeWebBrowser && tc.WebBrowser != nil {
if tc.WebBrowser.Input != "" {
// DO SOMETHING
}
if len(tc.WebBrowser.Outputs) > 0 {
// DO SOMETHING
}
}
}
}
return nil
})
s.Do(context.Background())
// DrawingTool
s := client.ChatCompletion("GLM-4-AllTools")
s.AddMessage(zhipu.ChatCompletionMultiMessage{
Role: "user",
Content: []zhipu.ChatCompletionMultiContent{
{
Type: "text",
Text: "画一个正弦函数图像",
},
},
})
s.AddTool(zhipu.ChatCompletionToolDrawingTool{})
s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
for _, c := range chunk.Choices {
for _, tc := range c.Delta.ToolCalls {
if tc.Type == ToolTypeDrawingTool && tc.DrawingTool != nil {
if tc.DrawingTool.Input != "" {
// DO SOMETHING
}
if len(tc.DrawingTool.Outputs) > 0 {
// DO SOMETHING
}
}
}
}
return nil
})
s.Do(context.Background())
```
**Embedding**
```go
service := client.Embedding("embedding-v2").SetInput("你好呀")
service.Do(context.Background())
```
**ImageGeneration(图像生成)**
```go
service := client.ImageGeneration("cogview-3").SetPrompt("一只可爱的小猫咪")
service.Do(context.Background())
```
**VideoGeneration(视频生成)**
```go
service := client.VideoGeneration("cogvideox").SetPrompt("一只可爱的小猫咪")
resp, err := service.Do(context.Background())
for {
result, err := client.AsyncResult(resp.ID).Do(context.Background())
if result.TaskStatus == zhipu.VideoGenerationTaskStatusSuccess {
_ = result.VideoResult[0].URL
_ = result.VideoResult[0].CoverImageURL
break
}
if result.TaskStatus != zhipu.VideoGenerationTaskStatusProcessing {
break
}
time.Sleep(5 * time.Second)
}
```
**UploadFile(上传文件用于取回)**
```go
service := client.FileCreate(zhipu.FilePurposeRetrieval)
service.SetLocalFile(filepath.Join("testdata", "test-file.txt"))
service.SetKnowledgeID("your-knowledge-id")
service.Do(context.Background())
```
**UploadFile(上传文件用于微调)**
```go
service := client.FileCreate(zhipu.FilePurposeFineTune)
service.SetLocalFile(filepath.Join("testdata", "test-file.jsonl"))
service.Do(context.Background())
```
**BatchCreate(创建批量任务)**
```go
service := client.BatchCreate().
SetInputFileID("fileid").
SetCompletionWindow(zhipu.BatchCompletionWindow24h).
SetEndpoint(BatchEndpointV4ChatCompletions)
service.Do(context.Background())
```
**KnowledgeBase(知识库)**
```go
client.KnowledgeCreate("")
client.KnowledgeEdit("")
```
**FineTune(微调)**
```go
client.FineTuneCreate("")
```
### 批量任务辅助工具
**批量任务文件创建**
```go
f, err := os.OpenFile("batch.jsonl", os.O_CREATE|os.O_WRONLY, 0644)
bw := zhipu.NewBatchFileWriter(f)
bw.Add("action_1", client.ChatCompletion("glm-4-flash").
AddMessage(zhipu.ChatCompletionMessage{
Role: "user",
Content: "你好",
}))
bw.Add("action_2", client.Embedding("embedding-v2").SetInput("你好呀"))
bw.Add("action_3", client.ImageGeneration("cogview-3").SetPrompt("一只可爱的小猫咪"))
```
**批量任务结果解析**
```go
br := zhipu.NewBatchResultReader[zhipu.ChatCompletionResponse](r)
for {
var res zhipu.BatchResult[zhipu.ChatCompletionResponse]
err := br.Read(&res)
if err != nil {
break
}
}
```
## 赞助
执行单元测试会真实调用GLM接口消耗我充值的额度开发不易请微信扫码捐赠感谢您的支持
<img src="./wechat-donation.png" width="180"/>
## 许可证
GUO YANKE, MIT License

View File

@ -0,0 +1,63 @@
package zhipu
import (
"context"
"github.com/go-resty/resty/v2"
)
// AsyncResultService creates a new async result get service
type AsyncResultService struct {
client *Client
id string
}
// AsyncResultVideo is the video result of the AsyncResultService
type AsyncResultVideo struct {
URL string `json:"url"`
CoverImageURL string `json:"cover_image_url"`
}
// AsyncResultResponse is the response of the AsyncResultService
type AsyncResultResponse struct {
Model string `json:"model"`
TaskStatus string `json:"task_status"`
RequestID string `json:"request_id"`
ID string `json:"id"`
VideoResult []AsyncResultVideo `json:"video_result"`
}
// NewAsyncResultService creates a new async result get service
func NewAsyncResultService(client *Client) *AsyncResultService {
return &AsyncResultService{
client: client,
}
}
// SetID sets the id parameter
func (s *AsyncResultService) SetID(id string) *AsyncResultService {
s.id = id
return s
}
func (s *AsyncResultService) Do(ctx context.Context) (res AsyncResultResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetResult(&res).
SetError(&apiError).
Get("async-result/" + s.id); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}

View File

@ -0,0 +1 @@
package zhipu

258
llm/zhipu/zhipu/batch.go Normal file
View File

@ -0,0 +1,258 @@
package zhipu
import (
"context"
"encoding/json"
"strconv"
"github.com/go-resty/resty/v2"
)
const (
BatchEndpointV4ChatCompletions = "/v4/chat/completions"
BatchEndpointV4ImagesGenerations = "/v4/images/generations"
BatchEndpointV4Embeddings = "/v4/embeddings"
BatchEndpointV4VideosGenerations = "/v4/videos/generations"
BatchCompletionWindow24h = "24h"
)
// BatchRequestCounts represents the counts of the batch requests.
type BatchRequestCounts struct {
Total int64 `json:"total"`
Completed int64 `json:"completed"`
Failed int64 `json:"failed"`
}
// BatchItem represents a batch item.
type BatchItem struct {
ID string `json:"id"`
Object any `json:"object"`
Endpoint string `json:"endpoint"`
InputFileID string `json:"input_file_id"`
CompletionWindow string `json:"completion_window"`
Status string `json:"status"`
OutputFileID string `json:"output_file_id"`
ErrorFileID string `json:"error_file_id"`
CreatedAt int64 `json:"created_at"`
InProgressAt int64 `json:"in_progress_at"`
ExpiresAt int64 `json:"expires_at"`
FinalizingAt int64 `json:"finalizing_at"`
CompletedAt int64 `json:"completed_at"`
FailedAt int64 `json:"failed_at"`
ExpiredAt int64 `json:"expired_at"`
CancellingAt int64 `json:"cancelling_at"`
CancelledAt int64 `json:"cancelled_at"`
RequestCounts BatchRequestCounts `json:"request_counts"`
Metadata json.RawMessage `json:"metadata"`
}
// BatchCreateService is a service to create a batch.
type BatchCreateService struct {
client *Client
inputFileID string
endpoint string
completionWindow string
metadata any
}
// NewBatchCreateService creates a new BatchCreateService.
func NewBatchCreateService(client *Client) *BatchCreateService {
return &BatchCreateService{client: client}
}
// SetInputFileID sets the input file id for the batch.
func (s *BatchCreateService) SetInputFileID(inputFileID string) *BatchCreateService {
s.inputFileID = inputFileID
return s
}
// SetEndpoint sets the endpoint for the batch.
func (s *BatchCreateService) SetEndpoint(endpoint string) *BatchCreateService {
s.endpoint = endpoint
return s
}
// SetCompletionWindow sets the completion window for the batch.
func (s *BatchCreateService) SetCompletionWindow(window string) *BatchCreateService {
s.completionWindow = window
return s
}
// SetMetadata sets the metadata for the batch.
func (s *BatchCreateService) SetMetadata(metadata any) *BatchCreateService {
s.metadata = metadata
return s
}
// Do executes the batch create service.
func (s *BatchCreateService) Do(ctx context.Context) (res BatchItem, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetBody(M{
"input_file_id": s.inputFileID,
"endpoint": s.endpoint,
"completion_window": s.completionWindow,
"metadata": s.metadata,
}).
SetResult(&res).
SetError(&apiError).
Post("batches"); err != nil {
return
}
if resp.IsError() {
err = apiError
}
return
}
// BatchGetService is a service to get a batch.
type BatchGetService struct {
client *Client
batchID string
}
// BatchGetResponse represents the response of the batch get service.
type BatchGetResponse = BatchItem
// NewBatchGetService creates a new BatchGetService.
func NewBatchGetService(client *Client) *BatchGetService {
return &BatchGetService{client: client}
}
// SetBatchID sets the batch id for the batch get service.
func (s *BatchGetService) SetBatchID(batchID string) *BatchGetService {
s.batchID = batchID
return s
}
// Do executes the batch get service.
func (s *BatchGetService) Do(ctx context.Context) (res BatchGetResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetPathParam("batch_id", s.batchID).
SetResult(&res).
SetError(&apiError).
Get("batches/{batch_id}"); err != nil {
return
}
if resp.IsError() {
err = apiError
}
return
}
// BatchCancelService is a service to cancel a batch.
type BatchCancelService struct {
client *Client
batchID string
}
// NewBatchCancelService creates a new BatchCancelService.
func NewBatchCancelService(client *Client) *BatchCancelService {
return &BatchCancelService{client: client}
}
// SetBatchID sets the batch id for the batch cancel service.
func (s *BatchCancelService) SetBatchID(batchID string) *BatchCancelService {
s.batchID = batchID
return s
}
// Do executes the batch cancel service.
func (s *BatchCancelService) Do(ctx context.Context) (err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetPathParam("batch_id", s.batchID).
SetBody(M{}).
SetError(&apiError).
Post("batches/{batch_id}/cancel"); err != nil {
return
}
if resp.IsError() {
err = apiError
}
return
}
// BatchListService is a service to list batches.
type BatchListService struct {
client *Client
after *string
limit *int
}
// BatchListResponse represents the response of the batch list service.
type BatchListResponse struct {
Object string `json:"object"`
Data []BatchItem `json:"data"`
FirstID string `json:"first_id"`
LastID string `json:"last_id"`
HasMore bool `json:"has_more"`
}
// NewBatchListService creates a new BatchListService.
func NewBatchListService(client *Client) *BatchListService {
return &BatchListService{client: client}
}
// SetAfter sets the after cursor for the batch list service.
func (s *BatchListService) SetAfter(after string) *BatchListService {
s.after = &after
return s
}
// SetLimit sets the limit for the batch list service.
func (s *BatchListService) SetLimit(limit int) *BatchListService {
s.limit = &limit
return s
}
// Do executes the batch list service.
func (s *BatchListService) Do(ctx context.Context) (res BatchListResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
req := s.client.request(ctx)
if s.after != nil {
req.SetQueryParam("after", *s.after)
}
if s.limit != nil {
req.SetQueryParam("limit", strconv.Itoa(*s.limit))
}
if resp, err = req.
SetResult(&res).
SetError(&apiError).
Get("batches"); err != nil {
return
}
if resp.IsError() {
err = apiError
}
return
}

View File

@ -0,0 +1,63 @@
package zhipu
import (
"encoding/json"
"io"
)
// BatchSupport is the interface for services with batch support.
type BatchSupport interface {
BatchMethod() string
BatchURL() string
BatchBody() any
}
// BatchFileWriter is a writer for batch files.
type BatchFileWriter struct {
w io.Writer
je *json.Encoder
}
// NewBatchFileWriter creates a new BatchFileWriter.
func NewBatchFileWriter(w io.Writer) *BatchFileWriter {
return &BatchFileWriter{w: w, je: json.NewEncoder(w)}
}
// Write writes a batch file.
func (b *BatchFileWriter) Write(customID string, s BatchSupport) error {
return b.je.Encode(M{
"custom_id": customID,
"method": s.BatchMethod(),
"url": s.BatchURL(),
"body": s.BatchBody(),
})
}
// BatchResultResponse is the response of a batch result.
type BatchResultResponse[T any] struct {
StatusCode int `json:"status_code"`
Body T `json:"body"`
}
// BatchResult is the result of a batch.
type BatchResult[T any] struct {
ID string `json:"id"`
CustomID string `json:"custom_id"`
Response BatchResultResponse[T] `json:"response"`
}
// BatchResultReader reads batch results.
type BatchResultReader[T any] struct {
r io.Reader
jd *json.Decoder
}
// NewBatchResultReader creates a new BatchResultReader.
func NewBatchResultReader[T any](r io.Reader) *BatchResultReader[T] {
return &BatchResultReader[T]{r: r, jd: json.NewDecoder(r)}
}
// Read reads a batch result.
func (r *BatchResultReader[T]) Read(out *BatchResult[T]) error {
return r.jd.Decode(out)
}

View File

@ -0,0 +1,73 @@
package zhipu
import (
"bytes"
"io"
"testing"
"github.com/stretchr/testify/require"
)
func TestBatchFileWriter(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
buf := &bytes.Buffer{}
w := NewBatchFileWriter(buf)
err = w.Write("batch-1", client.ChatCompletion("a").AddMessage(ChatCompletionMessage{
Role: "user", Content: "hello",
}))
require.NoError(t, err)
err = w.Write("batch-2", client.Embedding("c").SetInput("whoa"))
require.NoError(t, err)
err = w.Write("batch-3", client.ImageGeneration("d").SetPrompt("whoa"))
require.NoError(t, err)
require.Equal(t, `{"body":{"messages":[{"role":"user","content":"hello"}],"model":"a"},"custom_id":"batch-1","method":"POST","url":"/v4/chat/completions"}
{"body":{"input":"whoa","model":"c"},"custom_id":"batch-2","method":"POST","url":"/v4/embeddings"}
{"body":{"model":"d","prompt":"whoa"},"custom_id":"batch-3","method":"POST","url":"/v4/images/generations"}
`, buf.String())
}
func TestBatchResultReader(t *testing.T) {
result := `
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":26,"prompt_tokens":89,"total_tokens":115},"model":"glm-4","id":"8668357533850320547","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"订单处理慢\"\n}\n'''"}}],"request_id":"615-request-1"}},"custom_id":"request-1","id":"batch_1791490810192076800"}
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":22,"prompt_tokens":94,"total_tokens":116},"model":"glm-4","id":"8668368425887509080","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"产品缺陷\"\n}\n'''"}}],"request_id":"616-request-2"}},"custom_id":"request-2","id":"batch_1791490810192076800"}
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":25,"prompt_tokens":86,"total_tokens":111},"model":"glm-4","id":"8668355815863214980","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"正面\",\n \"特定问题标注\": \"性价比\"\n}\n'''"}}],"request_id":"617-request-3"}},"custom_id":"request-3","id":"batch_1791490810192076800"}
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":28,"prompt_tokens":89,"total_tokens":117},"model":"glm-4","id":"8668355815863214981","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"说明文档不清晰\"\n}\n'''"}}],"request_id":"618-request-4"}},"custom_id":"request-4","id":"batch_1791490810192076800"}
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":26,"prompt_tokens":88,"total_tokens":114},"model":"glm-4","id":"8668357533850320546","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"中性\",\n \"特定问题标注\": \"价格问题\"\n}\n'''"}}],"request_id":"619-request-5"}},"custom_id":"request-5","id":"batch_1791490810192076800"}
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":26,"prompt_tokens":90,"total_tokens":116},"model":"glm-4","id":"8668356159460662846","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"配送延迟\"\n}\n'''"}}],"request_id":"620-request-6"}},"custom_id":"request-6","id":"batch_1791490810192076800"}
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":27,"prompt_tokens":88,"total_tokens":115},"model":"glm-4","id":"8668357671289274638","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"产品描述不符\"\n}\n'''"}}],"request_id":"621-request-7"}},"custom_id":"request-7","id":"batch_1791490810192076800"}
{"response":{"status_code":200,"body":{"created":1715959702,"usage":{"completion_tokens":26,"prompt_tokens":87,"total_tokens":113},"model":"glm-4","id":"8668355644064514872","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"正面\",\n \"特定问题标注\": \"客服态度\"\n}\n'''"}}],"request_id":"622-request-8"}},"custom_id":"request-8","id":"batch_1791490810192076800"}
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":29,"prompt_tokens":90,"total_tokens":119},"model":"glm-4","id":"8668357671289274639","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"包装问题, 产品损坏\"\n}\n'''"}}],"request_id":"623-request-9"}},"custom_id":"request-9","id":"batch_1791490810192076800"}
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":27,"prompt_tokens":87,"total_tokens":114},"model":"glm-4","id":"8668355644064514871","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"正面\",\n \"特定问题标注\": \"产品描述不符\"\n}\n'''"}}],"request_id":"624-request-10"}},"custom_id":"request-10","id":"batch_1791490810192076800"}
`
brr := NewBatchResultReader[ChatCompletionResponse](bytes.NewReader([]byte(result)))
var count int
for {
var res BatchResult[ChatCompletionResponse]
err := brr.Read(&res)
if err != nil {
if err == io.EOF {
err = nil
}
require.Equal(t, 10, count)
require.NoError(t, err)
break
}
require.Equal(t, 200, res.Response.StatusCode)
count++
}
}

View File

@ -0,0 +1,59 @@
package zhipu
import (
"bytes"
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestBatchServiceAll(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
buf := &bytes.Buffer{}
bfw := NewBatchFileWriter(buf)
err = bfw.Write("batch_1", client.ChatCompletion("glm-4-flash").AddMessage(ChatCompletionMessage{
Role: RoleUser, Content: "你好呀",
}))
require.NoError(t, err)
err = bfw.Write("batch_2", client.ChatCompletion("glm-4-flash").AddMessage(ChatCompletionMessage{
Role: RoleUser, Content: "你叫什么名字",
}))
require.NoError(t, err)
res, err := client.FileCreate(FilePurposeBatch).SetFile(bytes.NewReader(buf.Bytes()), "batch.jsonl").Do(context.Background())
require.NoError(t, err)
fileID := res.FileCreateFineTuneResponse.ID
require.NotEmpty(t, fileID)
res1, err := client.BatchCreate().
SetInputFileID(fileID).
SetCompletionWindow(BatchCompletionWindow24h).
SetEndpoint(BatchEndpointV4ChatCompletions).Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res1.ID)
res2, err := client.BatchGet(res1.ID).Do(context.Background())
require.NoError(t, err)
require.Equal(t, res2.ID, res1.ID)
res3, err := client.BatchList().Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res3.Data)
err = client.BatchCancel(res1.ID).Do(context.Background())
require.NoError(t, err)
}
func TestBatchListService(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
res, err := client.BatchList().Do(context.Background())
require.NoError(t, err)
t.Log(res)
}

View File

@ -0,0 +1,577 @@
package zhipu
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"io"
"github.com/go-resty/resty/v2"
)
const (
RoleSystem = "system"
RoleUser = "user"
RoleAssistant = "assistant"
RoleTool = "tool"
ToolChoiceAuto = "auto"
FinishReasonStop = "stop"
FinishReasonStopSequence = "stop_sequence"
FinishReasonToolCalls = "tool_calls"
FinishReasonLength = "length"
FinishReasonSensitive = "sensitive"
FinishReasonNetworkError = "network_error"
ToolTypeFunction = "function"
ToolTypeWebSearch = "web_search"
ToolTypeRetrieval = "retrieval"
MultiContentTypeText = "text"
MultiContentTypeImageURL = "image_url"
MultiContentTypeVideoURL = "video_url"
// New in GLM-4-AllTools
ToolTypeCodeInterpreter = "code_interpreter"
ToolTypeDrawingTool = "drawing_tool"
ToolTypeWebBrowser = "web_browser"
CodeInterpreterSandboxNone = "none"
CodeInterpreterSandboxAuto = "auto"
ChatCompletionStatusFailed = "failed"
ChatCompletionStatusCompleted = "completed"
ChatCompletionStatusRequiresAction = "requires_action"
)
// ChatCompletionTool is the interface for chat completion tool
type ChatCompletionTool interface {
isChatCompletionTool()
}
// ChatCompletionToolFunction is the function for chat completion
type ChatCompletionToolFunction struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters any `json:"parameters"`
}
func (ChatCompletionToolFunction) isChatCompletionTool() {}
// ChatCompletionToolRetrieval is the retrieval for chat completion
type ChatCompletionToolRetrieval struct {
KnowledgeID string `json:"knowledge_id"`
PromptTemplate string `json:"prompt_template,omitempty"`
}
func (ChatCompletionToolRetrieval) isChatCompletionTool() {}
// ChatCompletionToolWebSearch is the web search for chat completion
type ChatCompletionToolWebSearch struct {
Enable *bool `json:"enable,omitempty"`
SearchQuery string `json:"search_query,omitempty"`
SearchResult bool `json:"search_result,omitempty"`
}
func (ChatCompletionToolWebSearch) isChatCompletionTool() {}
// ChatCompletionToolCodeInterpreter is the code interpreter for chat completion
// only in GLM-4-AllTools
type ChatCompletionToolCodeInterpreter struct {
Sandbox *string `json:"sandbox,omitempty"`
}
func (ChatCompletionToolCodeInterpreter) isChatCompletionTool() {}
// ChatCompletionToolDrawingTool is the drawing tool for chat completion
// only in GLM-4-AllTools
type ChatCompletionToolDrawingTool struct {
// no fields
}
func (ChatCompletionToolDrawingTool) isChatCompletionTool() {}
// ChatCompletionToolWebBrowser is the web browser for chat completion
type ChatCompletionToolWebBrowser struct {
// no fields
}
func (ChatCompletionToolWebBrowser) isChatCompletionTool() {}
// ChatCompletionUsage is the usage for chat completion
type ChatCompletionUsage struct {
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
TotalTokens int64 `json:"total_tokens"`
}
// ChatCompletionWebSearch is the web search result for chat completion
type ChatCompletionWebSearch struct {
Icon string `json:"icon"`
Title string `json:"title"`
Link string `json:"link"`
Media string `json:"media"`
Content string `json:"content"`
}
// ChatCompletionToolCallFunction is the function for chat completion tool call
type ChatCompletionToolCallFunction struct {
Name string `json:"name"`
Arguments json.RawMessage `json:"arguments"`
}
// ChatCompletionToolCallCodeInterpreterOutput is the output for chat completion tool call code interpreter
type ChatCompletionToolCallCodeInterpreterOutput struct {
Type string `json:"type"`
Logs string `json:"logs"`
File string `json:"file"`
}
// ChatCompletionToolCallCodeInterpreter is the code interpreter for chat completion tool call
type ChatCompletionToolCallCodeInterpreter struct {
Input string `json:"input"`
Outputs []ChatCompletionToolCallCodeInterpreterOutput `json:"outputs"`
}
// ChatCompletionToolCallDrawingToolOutput is the output for chat completion tool call drawing tool
type ChatCompletionToolCallDrawingToolOutput struct {
Image string `json:"image"`
}
// ChatCompletionToolCallDrawingTool is the drawing tool for chat completion tool call
type ChatCompletionToolCallDrawingTool struct {
Input string `json:"input"`
Outputs []ChatCompletionToolCallDrawingToolOutput `json:"outputs"`
}
// ChatCompletionToolCallWebBrowserOutput is the output for chat completion tool call web browser
type ChatCompletionToolCallWebBrowserOutput struct {
Title string `json:"title"`
Link string `json:"link"`
Content string `json:"content"`
}
// ChatCompletionToolCallWebBrowser is the web browser for chat completion tool call
type ChatCompletionToolCallWebBrowser struct {
Input string `json:"input"`
Outputs []ChatCompletionToolCallWebBrowserOutput `json:"outputs"`
}
// ChatCompletionToolCall is the tool call for chat completion
type ChatCompletionToolCall struct {
ID string `json:"id"`
Type string `json:"type"`
Function *ChatCompletionToolCallFunction `json:"function,omitempty"`
CodeInterpreter *ChatCompletionToolCallCodeInterpreter `json:"code_interpreter,omitempty"`
DrawingTool *ChatCompletionToolCallDrawingTool `json:"drawing_tool,omitempty"`
WebBrowser *ChatCompletionToolCallWebBrowser `json:"web_browser,omitempty"`
}
type ChatCompletionMessageType interface {
isChatCompletionMessageType()
}
// ChatCompletionMessage is the message for chat completion
type ChatCompletionMessage struct {
Role string `json:"role"`
Content string `json:"content,omitempty"`
ToolCalls []ChatCompletionToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
func (ChatCompletionMessage) isChatCompletionMessageType() {}
type ChatCompletionMultiContent struct {
Type string `json:"type"`
Text string `json:"text"`
ImageURL *URLItem `json:"image_url,omitempty"`
VideoURL *URLItem `json:"video_url,omitempty"`
}
// ChatCompletionMultiMessage is the multi message for chat completion
type ChatCompletionMultiMessage struct {
Role string `json:"role"`
Content []ChatCompletionMultiContent `json:"content"`
}
func (ChatCompletionMultiMessage) isChatCompletionMessageType() {}
// ChatCompletionMeta is the meta for chat completion
type ChatCompletionMeta struct {
UserInfo string `json:"user_info"`
BotInfo string `json:"bot_info"`
UserName string `json:"user_name"`
BotName string `json:"bot_name"`
}
// ChatCompletionChoice is the choice for chat completion
type ChatCompletionChoice struct {
Index int `json:"index"`
FinishReason string `json:"finish_reason"`
Delta ChatCompletionMessage `json:"delta"` // stream mode
Message ChatCompletionMessage `json:"message"` // non-stream mode
}
// ChatCompletionResponse is the response for chat completion
type ChatCompletionResponse struct {
ID string `json:"id"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatCompletionChoice `json:"choices"`
Usage ChatCompletionUsage `json:"usage"`
WebSearch []ChatCompletionWebSearch `json:"web_search"`
// Status is the status of the chat completion, only in GLM-4-AllTools
Status string `json:"status"`
}
// ChatCompletionStreamHandler is the handler for chat completion stream
type ChatCompletionStreamHandler func(chunk ChatCompletionResponse) error
var (
chatCompletionStreamPrefix = []byte("data:")
chatCompletionStreamDone = []byte("[DONE]")
)
// chatCompletionReduceResponse reduce the chunk to the response
func chatCompletionReduceResponse(out *ChatCompletionResponse, chunk ChatCompletionResponse) {
if len(out.Choices) == 0 {
out.Choices = append(out.Choices, ChatCompletionChoice{})
}
// basic
out.ID = chunk.ID
out.Created = chunk.Created
out.Model = chunk.Model
// choices
if len(chunk.Choices) != 0 {
oc := &out.Choices[0]
cc := chunk.Choices[0]
oc.Index = cc.Index
if cc.Delta.Role != "" {
oc.Message.Role = cc.Delta.Role
}
oc.Message.Content += cc.Delta.Content
oc.Message.ToolCalls = append(oc.Message.ToolCalls, cc.Delta.ToolCalls...)
if cc.FinishReason != "" {
oc.FinishReason = cc.FinishReason
}
}
// usage
if chunk.Usage.CompletionTokens != 0 {
out.Usage.CompletionTokens = chunk.Usage.CompletionTokens
}
if chunk.Usage.PromptTokens != 0 {
out.Usage.PromptTokens = chunk.Usage.PromptTokens
}
if chunk.Usage.TotalTokens != 0 {
out.Usage.TotalTokens = chunk.Usage.TotalTokens
}
// web search
out.WebSearch = append(out.WebSearch, chunk.WebSearch...)
}
// chatCompletionDecodeStream decode the sse stream of chat completion
func chatCompletionDecodeStream(r io.Reader, fn func(chunk ChatCompletionResponse) error) (err error) {
br := bufio.NewReader(r)
for {
var line []byte
if line, err = br.ReadBytes('\n'); err != nil {
if errors.Is(err, io.EOF) {
err = nil
}
break
}
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
if !bytes.HasPrefix(line, chatCompletionStreamPrefix) {
continue
}
data := bytes.TrimSpace(line[len(chatCompletionStreamPrefix):])
if bytes.Equal(data, chatCompletionStreamDone) {
break
}
if len(data) == 0 {
continue
}
var chunk ChatCompletionResponse
if err = json.Unmarshal(data, &chunk); err != nil {
return
}
if err = fn(chunk); err != nil {
return
}
}
return
}
// ChatCompletionStreamService is the service for chat completion stream
type ChatCompletionService struct {
client *Client
model string
requestID *string
doSample *bool
temperature *float64
topP *float64
maxTokens *int
stop []string
toolChoice *string
userID *string
meta *ChatCompletionMeta
messages []any
tools []any
streamHandler ChatCompletionStreamHandler
}
var (
_ BatchSupport = &ChatCompletionService{}
)
// NewChatCompletionService creates a new ChatCompletionService.
func NewChatCompletionService(client *Client) *ChatCompletionService {
return &ChatCompletionService{
client: client,
}
}
func (s *ChatCompletionService) BatchMethod() string {
return "POST"
}
func (s *ChatCompletionService) BatchURL() string {
return BatchEndpointV4ChatCompletions
}
func (s *ChatCompletionService) BatchBody() any {
return s.buildBody()
}
// SetModel set the model of the chat completion
func (s *ChatCompletionService) SetModel(model string) *ChatCompletionService {
s.model = model
return s
}
// SetMeta set the meta of the chat completion, optional
func (s *ChatCompletionService) SetMeta(meta ChatCompletionMeta) *ChatCompletionService {
s.meta = &meta
return s
}
// SetRequestID set the request id of the chat completion, optional
func (s *ChatCompletionService) SetRequestID(requestID string) *ChatCompletionService {
s.requestID = &requestID
return s
}
// SetTemperature set the temperature of the chat completion, optional
func (s *ChatCompletionService) SetDoSample(doSample bool) *ChatCompletionService {
s.doSample = &doSample
return s
}
// SetTemperature set the temperature of the chat completion, optional
func (s *ChatCompletionService) SetTemperature(temperature float64) *ChatCompletionService {
s.temperature = &temperature
return s
}
// SetTopP set the top p of the chat completion, optional
func (s *ChatCompletionService) SetTopP(topP float64) *ChatCompletionService {
s.topP = &topP
return s
}
// SetMaxTokens set the max tokens of the chat completion, optional
func (s *ChatCompletionService) SetMaxTokens(maxTokens int) *ChatCompletionService {
s.maxTokens = &maxTokens
return s
}
// SetStop set the stop of the chat completion, optional
func (s *ChatCompletionService) SetStop(stop ...string) *ChatCompletionService {
s.stop = stop
return s
}
// SetToolChoice set the tool choice of the chat completion, optional
func (s *ChatCompletionService) SetToolChoice(toolChoice string) *ChatCompletionService {
s.toolChoice = &toolChoice
return s
}
// SetUserID set the user id of the chat completion, optional
func (s *ChatCompletionService) SetUserID(userID string) *ChatCompletionService {
s.userID = &userID
return s
}
// SetStreamHandler set the stream handler of the chat completion, optional
// this will enable the stream mode
func (s *ChatCompletionService) SetStreamHandler(handler ChatCompletionStreamHandler) *ChatCompletionService {
s.streamHandler = handler
return s
}
// AddMessage add the message to the chat completion
func (s *ChatCompletionService) AddMessage(messages ...ChatCompletionMessageType) *ChatCompletionService {
for _, message := range messages {
s.messages = append(s.messages, message)
}
return s
}
// AddFunction add the function to the chat completion
func (s *ChatCompletionService) AddTool(tools ...ChatCompletionTool) *ChatCompletionService {
for _, tool := range tools {
switch tool := tool.(type) {
case ChatCompletionToolFunction:
s.tools = append(s.tools, map[string]any{
"type": ToolTypeFunction,
ToolTypeFunction: tool,
})
case ChatCompletionToolRetrieval:
s.tools = append(s.tools, map[string]any{
"type": ToolTypeRetrieval,
ToolTypeRetrieval: tool,
})
case ChatCompletionToolWebSearch:
s.tools = append(s.tools, map[string]any{
"type": ToolTypeWebSearch,
ToolTypeWebSearch: tool,
})
case ChatCompletionToolCodeInterpreter:
s.tools = append(s.tools, map[string]any{
"type": ToolTypeCodeInterpreter,
ToolTypeCodeInterpreter: tool,
})
case ChatCompletionToolDrawingTool:
s.tools = append(s.tools, map[string]any{
"type": ToolTypeDrawingTool,
ToolTypeDrawingTool: tool,
})
case ChatCompletionToolWebBrowser:
s.tools = append(s.tools, map[string]any{
"type": ToolTypeWebBrowser,
ToolTypeWebBrowser: tool,
})
}
}
return s
}
func (s *ChatCompletionService) buildBody() M {
body := map[string]any{
"model": s.model,
"messages": s.messages,
}
if s.requestID != nil {
body["request_id"] = *s.requestID
}
if s.doSample != nil {
body["do_sample"] = *s.doSample
}
if s.temperature != nil {
body["temperature"] = *s.temperature
}
if s.topP != nil {
body["top_p"] = *s.topP
}
if s.maxTokens != nil {
body["max_tokens"] = *s.maxTokens
}
if len(s.stop) != 0 {
body["stop"] = s.stop
}
if len(s.tools) != 0 {
body["tools"] = s.tools
}
if s.toolChoice != nil {
body["tool_choice"] = *s.toolChoice
}
if s.userID != nil {
body["user_id"] = *s.userID
}
if s.meta != nil {
body["meta"] = s.meta
}
return body
}
// Do send the request of the chat completion and return the response
func (s *ChatCompletionService) Do(ctx context.Context) (res ChatCompletionResponse, err error) {
body := s.buildBody()
streamHandler := s.streamHandler
if streamHandler == nil {
var (
resp *resty.Response
apiError APIErrorResponse
)
//fmt.Println(u.BMagenta(u.JsonP(body)), 111)
if resp, err = s.client.request(ctx).SetBody(body).SetResult(&res).SetError(&apiError).Post("chat/completions"); err != nil {
//fmt.Println(u.BRed(err.Error()), 2221)
return
}
if resp.IsError() {
err = apiError
//fmt.Println(u.BRed(err.Error()), 2222)
return
}
//fmt.Println(u.BGreen(u.JsonP(resp.Result())), resp.Status(), resp.Status(), 333)
return
}
// stream mode
body["stream"] = true
var resp *resty.Response
if resp, err = s.client.request(ctx).SetBody(body).SetDoNotParseResponse(true).Post("chat/completions"); err != nil {
return
}
defer resp.RawBody().Close()
if resp.IsError() {
err = errors.New(resp.Status())
return
}
var choice ChatCompletionChoice
if err = chatCompletionDecodeStream(resp.RawBody(), func(chunk ChatCompletionResponse) error {
// reduce the chunk to the response
chatCompletionReduceResponse(&res, chunk)
// invoke the stream handler
return streamHandler(chunk)
}); err != nil {
return
}
res.Choices = append(res.Choices, choice)
return
}

View File

@ -0,0 +1,251 @@
package zhipu
import (
"context"
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestChatCompletionService(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.ChatCompletion("glm-4-flash")
s.AddMessage(ChatCompletionMessage{
Role: RoleUser,
Content: "你好呀",
})
res, err := s.Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res.Choices)
choice := res.Choices[0]
require.Equal(t, FinishReasonStop, choice.FinishReason)
require.NotEmpty(t, choice.Message.Content)
}
func TestChatCompletionServiceCharGLM(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.ChatCompletion("charglm-3")
s.SetMeta(
ChatCompletionMeta{
UserName: "啵酱",
UserInfo: "啵酱是小少爷",
BotName: "塞巴斯酱",
BotInfo: "塞巴斯酱是一个冷酷的恶魔管家",
},
).AddMessage(ChatCompletionMessage{
Role: RoleUser,
Content: "早上好",
})
res, err := s.Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res.Choices)
choice := res.Choices[0]
require.Contains(t, []string{FinishReasonStop, FinishReasonStopSequence}, choice.FinishReason)
require.NotEmpty(t, choice.Message.Content)
}
func TestChatCompletionServiceAllToolsCodeInterpreter(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.ChatCompletion("GLM-4-AllTools")
s.AddMessage(ChatCompletionMultiMessage{
Role: "user",
Content: []ChatCompletionMultiContent{
{
Type: "text",
Text: "计算[5,10,20,700,99,310,978,100]的平均值和方差。",
},
},
})
s.AddTool(ChatCompletionToolCodeInterpreter{
Sandbox: Ptr(CodeInterpreterSandboxAuto),
})
foundInterpreterInput := false
foundInterpreterOutput := false
s.SetStreamHandler(func(chunk ChatCompletionResponse) error {
for _, c := range chunk.Choices {
for _, tc := range c.Delta.ToolCalls {
if tc.Type == ToolTypeCodeInterpreter && tc.CodeInterpreter != nil {
if tc.CodeInterpreter.Input != "" {
foundInterpreterInput = true
}
if len(tc.CodeInterpreter.Outputs) > 0 {
foundInterpreterOutput = true
}
}
}
}
buf, _ := json.MarshalIndent(chunk, "", " ")
t.Log(string(buf))
return nil
})
res, err := s.Do(context.Background())
require.True(t, foundInterpreterInput)
require.True(t, foundInterpreterOutput)
require.NotNil(t, res)
require.NoError(t, err)
}
func TestChatCompletionServiceAllToolsDrawingTool(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.ChatCompletion("GLM-4-AllTools")
s.AddMessage(ChatCompletionMultiMessage{
Role: "user",
Content: []ChatCompletionMultiContent{
{
Type: "text",
Text: "画一个正弦函数图像",
},
},
})
s.AddTool(ChatCompletionToolDrawingTool{})
foundInput := false
foundOutput := false
outputImage := ""
s.SetStreamHandler(func(chunk ChatCompletionResponse) error {
for _, c := range chunk.Choices {
for _, tc := range c.Delta.ToolCalls {
if tc.Type == ToolTypeDrawingTool && tc.DrawingTool != nil {
if tc.DrawingTool.Input != "" {
foundInput = true
}
if len(tc.DrawingTool.Outputs) > 0 {
foundOutput = true
}
for _, output := range tc.DrawingTool.Outputs {
if output.Image != "" {
outputImage = output.Image
}
}
}
}
}
buf, _ := json.MarshalIndent(chunk, "", " ")
t.Log(string(buf))
return nil
})
res, err := s.Do(context.Background())
require.True(t, foundInput)
require.True(t, foundOutput)
require.NotEmpty(t, outputImage)
t.Log(outputImage)
require.NotNil(t, res)
require.NoError(t, err)
}
func TestChatCompletionServiceAllToolsWebBrowser(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.ChatCompletion("GLM-4-AllTools")
s.AddMessage(ChatCompletionMultiMessage{
Role: "user",
Content: []ChatCompletionMultiContent{
{
Type: "text",
Text: "搜索下本周深圳天气如何",
},
},
})
s.AddTool(ChatCompletionToolWebBrowser{})
foundInput := false
foundOutput := false
outputContent := ""
s.SetStreamHandler(func(chunk ChatCompletionResponse) error {
for _, c := range chunk.Choices {
for _, tc := range c.Delta.ToolCalls {
if tc.Type == ToolTypeWebBrowser && tc.WebBrowser != nil {
if tc.WebBrowser.Input != "" {
foundInput = true
}
if len(tc.WebBrowser.Outputs) > 0 {
foundOutput = true
}
for _, output := range tc.WebBrowser.Outputs {
if output.Content != "" {
outputContent = output.Content
}
}
}
}
}
buf, _ := json.MarshalIndent(chunk, "", " ")
t.Log(string(buf))
return nil
})
res, err := s.Do(context.Background())
require.True(t, foundInput)
require.True(t, foundOutput)
require.NotEmpty(t, outputContent)
t.Log(outputContent)
require.NotNil(t, res)
require.NoError(t, err)
}
func TestChatCompletionServiceStream(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
var content string
s := client.ChatCompletion("glm-4-flash").AddMessage(ChatCompletionMessage{
Role: RoleUser,
Content: "你好呀",
}).SetStreamHandler(func(chunk ChatCompletionResponse) error {
content += chunk.Choices[0].Delta.Content
return nil
})
res, err := s.Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res.Choices)
choice := res.Choices[0]
require.Equal(t, FinishReasonStop, choice.FinishReason)
require.NotEmpty(t, choice.Message.Content)
require.Equal(t, content, choice.Message.Content)
}
func TestChatCompletionServiceVision(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.ChatCompletion("glm-4v")
s.AddMessage(ChatCompletionMultiMessage{
Role: RoleUser,
Content: []ChatCompletionMultiContent{
{
Type: MultiContentTypeText,
Text: "图里有什么",
},
{
Type: MultiContentTypeImageURL,
ImageURL: &URLItem{
URL: "https://img1.baidu.com/it/u=1369931113,3388870256&fm=253&app=138&size=w931&n=0&f=JPEG&fmt=auto?sec=1703696400&t=f3028c7a1dca43a080aeb8239f09cc2f",
},
},
},
})
res, err := s.Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res.Choices)
require.NotZero(t, res.Usage.CompletionTokens)
choice := res.Choices[0]
require.Equal(t, FinishReasonStop, choice.FinishReason)
require.NotEmpty(t, choice.Message.Content)
}

291
llm/zhipu/zhipu/client.go Normal file
View File

@ -0,0 +1,291 @@
package zhipu
import (
"context"
"errors"
"net/http"
"os"
"strconv"
"strings"
"time"
"github.com/go-resty/resty/v2"
"github.com/golang-jwt/jwt/v5"
)
const (
envAPIKey = "ZHIPUAI_API_KEY"
envBaseURL = "ZHIPUAI_BASE_URL"
envDebug = "ZHIPUAI_DEBUG"
defaultBaseURL = "https://open.bigmodel.cn/api/paas/v4"
)
var (
// ErrAPIKeyMissing is the error when the api key is missing
ErrAPIKeyMissing = errors.New("zhipu: api key is missing")
// ErrAPIKeyMalformed is the error when the api key is malformed
ErrAPIKeyMalformed = errors.New("zhipu: api key is malformed")
)
type clientOptions struct {
baseURL string
apiKey string
client *http.Client
debug *bool
}
// ClientOption is a function that configures the client
type ClientOption func(opts *clientOptions)
// WithAPIKey set the api key of the client
func WithAPIKey(apiKey string) ClientOption {
return func(opts *clientOptions) {
opts.apiKey = apiKey
}
}
// WithBaseURL set the base url of the client
func WithBaseURL(baseURL string) ClientOption {
return func(opts *clientOptions) {
opts.baseURL = baseURL
}
}
// WithHTTPClient set the http client of the client
func WithHTTPClient(client *http.Client) ClientOption {
return func(opts *clientOptions) {
opts.client = client
}
}
// WithDebug set the debug mode of the client
func WithDebug(debug bool) ClientOption {
return func(opts *clientOptions) {
opts.debug = new(bool)
*opts.debug = debug
}
}
// Client is the client for zhipu ai platform
type Client struct {
client *resty.Client
debug bool
keyID string
keySecret []byte
}
func (c *Client) createJWT() string {
timestamp := time.Now().UnixMilli()
exp := timestamp + time.Hour.Milliseconds()*24*7
t := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"api_key": c.keyID,
"timestamp": timestamp,
"exp": exp,
})
t.Header = map[string]interface{}{
"alg": "HS256",
"sign_type": "SIGN",
}
token, err := t.SignedString(c.keySecret)
if err != nil {
panic(err)
}
return token
}
// request creates a new resty request with the jwt token and context
func (c *Client) request(ctx context.Context) *resty.Request {
return c.client.R().SetContext(ctx).SetHeader("Authorization", c.createJWT())
}
// NewClient creates a new client
// It will read the api key from the environment variable ZHIPUAI_API_KEY
// It will read the base url from the environment variable ZHIPUAI_BASE_URL
func NewClient(optFns ...ClientOption) (client *Client, err error) {
var opts clientOptions
for _, optFn := range optFns {
optFn(&opts)
}
// base url
if opts.baseURL == "" {
opts.baseURL = strings.TrimSpace(os.Getenv(envBaseURL))
}
if opts.baseURL == "" {
opts.baseURL = defaultBaseURL
}
// api key
if opts.apiKey == "" {
opts.apiKey = strings.TrimSpace(os.Getenv(envAPIKey))
}
if opts.apiKey == "" {
err = ErrAPIKeyMissing
return
}
// debug
if opts.debug == nil {
if debugStr := strings.TrimSpace(os.Getenv(envDebug)); debugStr != "" {
if debug, err1 := strconv.ParseBool(debugStr); err1 == nil {
opts.debug = &debug
}
}
}
keyComponents := strings.SplitN(opts.apiKey, ".", 2)
if len(keyComponents) != 2 {
err = ErrAPIKeyMalformed
return
}
client = &Client{
keyID: keyComponents[0],
keySecret: []byte(keyComponents[1]),
}
if opts.client == nil {
client.client = resty.New()
} else {
client.client = resty.NewWithClient(opts.client)
}
client.client = client.client.SetBaseURL(opts.baseURL)
if opts.debug != nil {
client.client.SetDebug(*opts.debug)
client.debug = *opts.debug
}
return
}
// BatchCreate creates a new BatchCreateService.
func (c *Client) BatchCreate() *BatchCreateService {
return NewBatchCreateService(c)
}
// BatchGet creates a new BatchGetService.
func (c *Client) BatchGet(batchID string) *BatchGetService {
return NewBatchGetService(c).SetBatchID(batchID)
}
// BatchCancel creates a new BatchCancelService.
func (c *Client) BatchCancel(batchID string) *BatchCancelService {
return NewBatchCancelService(c).SetBatchID(batchID)
}
// BatchList creates a new BatchListService.
func (c *Client) BatchList() *BatchListService {
return NewBatchListService(c)
}
// ChatCompletion creates a new ChatCompletionService.
func (c *Client) ChatCompletion(model string) *ChatCompletionService {
return NewChatCompletionService(c).SetModel(model)
}
// Embedding embeds a list of text into a vector space.
func (c *Client) Embedding(model string) *EmbeddingService {
return NewEmbeddingService(c).SetModel(model)
}
// FileCreate creates a new FileCreateService.
func (c *Client) FileCreate(purpose string) *FileCreateService {
return NewFileCreateService(c).SetPurpose(purpose)
}
// FileEditService creates a new FileEditService.
func (c *Client) FileEdit(documentID string) *FileEditService {
return NewFileEditService(c).SetDocumentID(documentID)
}
// FileList creates a new FileListService.
func (c *Client) FileList(purpose string) *FileListService {
return NewFileListService(c).SetPurpose(purpose)
}
// FileDeleteService creates a new FileDeleteService.
func (c *Client) FileDelete(documentID string) *FileDeleteService {
return NewFileDeleteService(c).SetDocumentID(documentID)
}
// FileGetService creates a new FileGetService.
func (c *Client) FileGet(documentID string) *FileGetService {
return NewFileGetService(c).SetDocumentID(documentID)
}
// FileDownload creates a new FileDownloadService.
func (c *Client) FileDownload(fileID string) *FileDownloadService {
return NewFileDownloadService(c).SetFileID(fileID)
}
// FineTuneCreate creates a new fine tune create service
func (c *Client) FineTuneCreate(model string) *FineTuneCreateService {
return NewFineTuneCreateService(c).SetModel(model)
}
// FineTuneEventList creates a new fine tune event list service
func (c *Client) FineTuneEventList(jobID string) *FineTuneEventListService {
return NewFineTuneEventListService(c).SetJobID(jobID)
}
// FineTuneGet creates a new fine tune get service
func (c *Client) FineTuneGet(jobID string) *FineTuneGetService {
return NewFineTuneGetService(c).SetJobID(jobID)
}
// FineTuneList creates a new fine tune list service
func (c *Client) FineTuneList() *FineTuneListService {
return NewFineTuneListService(c)
}
// FineTuneDelete creates a new fine tune delete service
func (c *Client) FineTuneDelete(jobID string) *FineTuneDeleteService {
return NewFineTuneDeleteService(c).SetJobID(jobID)
}
// FineTuneCancel creates a new fine tune cancel service
func (c *Client) FineTuneCancel(jobID string) *FineTuneCancelService {
return NewFineTuneCancelService(c).SetJobID(jobID)
}
// ImageGeneration creates a new image generation service
func (c *Client) ImageGeneration(model string) *ImageGenerationService {
return NewImageGenerationService(c).SetModel(model)
}
// KnowledgeCreate creates a new knowledge create service
func (c *Client) KnowledgeCreate() *KnowledgeCreateService {
return NewKnowledgeCreateService(c)
}
// KnowledgeEdit creates a new knowledge edit service
func (c *Client) KnowledgeEdit(knowledgeID string) *KnowledgeEditService {
return NewKnowledgeEditService(c).SetKnowledgeID(knowledgeID)
}
// KnowledgeList list all the knowledge
func (c *Client) KnowledgeList() *KnowledgeListService {
return NewKnowledgeListService(c)
}
// KnowledgeDelete creates a new knowledge delete service
func (c *Client) KnowledgeDelete(knowledgeID string) *KnowledgeDeleteService {
return NewKnowledgeDeleteService(c).SetKnowledgeID(knowledgeID)
}
// KnowledgeGet creates a new knowledge get service
func (c *Client) KnowledgeCapacity() *KnowledgeCapacityService {
return NewKnowledgeCapacityService(c)
}
// VideoGeneration creates a new video generation service
func (c *Client) VideoGeneration(model string) *VideoGenerationService {
return NewVideoGenerationService(c).SetModel(model)
}
// AsyncResult creates a new async result get service
func (c *Client) AsyncResult(id string) *AsyncResultService {
return NewAsyncResultService(c).SetID(id)
}

View File

@ -0,0 +1,17 @@
package zhipu
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestClientR(t *testing.T) {
c, err := NewClient()
require.NoError(t, err)
// the only free api is to list fine-tuning jobs
res, err := c.request(context.Background()).Get("fine_tuning/jobs")
require.NoError(t, err)
require.True(t, res.IsSuccess())
}

25
llm/zhipu/zhipu/cog.toml Normal file
View File

@ -0,0 +1,25 @@
from_latest_tag = false
ignore_merge_commits = false
disable_changelog = false
disable_bump_commit = false
generate_mono_repository_global_tag = true
branch_whitelist = []
skip_ci = "[skip ci]"
skip_untracked = false
pre_bump_hooks = []
post_bump_hooks = []
pre_package_bump_hooks = []
post_package_bump_hooks = []
tag_prefix = "v"
[git_hooks]
[commit_types]
[changelog]
path = "CHANGELOG.md"
authors = []
[bump_profiles]
[packages]

View File

@ -0,0 +1,87 @@
package zhipu
import (
"context"
"github.com/go-resty/resty/v2"
)
// EmbeddingData is the data for each embedding.
type EmbeddingData struct {
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
Object string `json:"object"`
}
// EmbeddingResponse is the response from the embedding service.
type EmbeddingResponse struct {
Model string `json:"model"`
Data []EmbeddingData `json:"data"`
Object string `json:"object"`
Usage ChatCompletionUsage `json:"usage"`
}
// EmbeddingService embeds a list of text into a vector space.
type EmbeddingService struct {
client *Client
model string
input string
}
var (
_ BatchSupport = &EmbeddingService{}
)
// NewEmbeddingService creates a new EmbeddingService.
func NewEmbeddingService(client *Client) *EmbeddingService {
return &EmbeddingService{client: client}
}
func (s *EmbeddingService) BatchMethod() string {
return "POST"
}
func (s *EmbeddingService) BatchURL() string {
return BatchEndpointV4Embeddings
}
func (s *EmbeddingService) BatchBody() any {
return s.buildBody()
}
// SetModel sets the model to use for the embedding.
func (s *EmbeddingService) SetModel(model string) *EmbeddingService {
s.model = model
return s
}
// SetInput sets the input text to embed.
func (s *EmbeddingService) SetInput(input string) *EmbeddingService {
s.input = input
return s
}
func (s *EmbeddingService) buildBody() M {
return M{"model": s.model, "input": s.input}
}
func (s *EmbeddingService) Do(ctx context.Context) (res EmbeddingResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetBody(s.buildBody()).
SetResult(&res).
SetError(&apiError).
Post("embeddings"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}

View File

@ -0,0 +1,21 @@
package zhipu
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestEmbeddingService(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
service := client.Embedding("embedding-2")
resp, err := service.SetInput("你好").Do(context.Background())
require.NoError(t, err)
require.NotZero(t, resp.Usage.TotalTokens)
require.NotEmpty(t, resp.Data)
require.NotEmpty(t, resp.Data[0].Embedding)
}

58
llm/zhipu/zhipu/error.go Normal file
View File

@ -0,0 +1,58 @@
package zhipu
type APIError struct {
Code string `json:"code"`
Message string `json:"message"`
}
func (e APIError) Error() string {
return e.Message
}
type APIErrorResponse struct {
APIError `json:"error"`
}
func (e APIErrorResponse) Error() string {
return e.APIError.Error()
}
// GetAPIErrorCode returns the error code of an API error.
func GetAPIErrorCode(err error) string {
if err == nil {
return ""
}
if e, ok := err.(APIError); ok {
return e.Code
}
if e, ok := err.(APIErrorResponse); ok {
return e.Code
}
if e, ok := err.(*APIError); ok && e != nil {
return e.Code
}
if e, ok := err.(*APIErrorResponse); ok && e != nil {
return e.Code
}
return ""
}
// GetAPIErrorMessage returns the error message of an API error.
func GetAPIErrorMessage(err error) string {
if err == nil {
return ""
}
if e, ok := err.(APIError); ok {
return e.Message
}
if e, ok := err.(APIErrorResponse); ok {
return e.Message
}
if e, ok := err.(*APIError); ok && e != nil {
return e.Message
}
if e, ok := err.(*APIErrorResponse); ok && e != nil {
return e.Message
}
return err.Error()
}

View File

@ -0,0 +1,38 @@
package zhipu
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestAPIError(t *testing.T) {
err := APIError{
Code: "code",
Message: "message",
}
require.Equal(t, "message", err.Error())
require.Equal(t, "code", GetAPIErrorCode(err))
require.Equal(t, "message", GetAPIErrorMessage(err))
}
func TestAPIErrorResponse(t *testing.T) {
err := APIErrorResponse{
APIError: APIError{
Code: "code",
Message: "message",
},
}
require.Equal(t, "message", err.Error())
require.Equal(t, "code", GetAPIErrorCode(err))
require.Equal(t, "message", GetAPIErrorMessage(err))
}
func TestAPIErrorResponseFromDoc(t *testing.T) {
var res APIErrorResponse
err := json.Unmarshal([]byte(`{"error":{"code":"1002","message":"Authorization Token非法请确认Authorization Token正确传递。"}}`), &res)
require.NoError(t, err)
require.Equal(t, "1002", res.Code)
require.Equal(t, "1002", GetAPIErrorCode(res))
}

541
llm/zhipu/zhipu/file.go Normal file
View File

@ -0,0 +1,541 @@
package zhipu
import (
"context"
"errors"
"io"
"os"
"path/filepath"
"strconv"
"github.com/go-resty/resty/v2"
)
const (
FilePurposeFineTune = "fine-tune"
FilePurposeRetrieval = "retrieval"
FilePurposeBatch = "batch"
KnowledgeTypeArticle = 1
KnowledgeTypeQADocument = 2
KnowledgeTypeQASpreadsheet = 3
KnowledgeTypeProductDatabaseSpreadsheet = 4
KnowledgeTypeCustom = 5
)
// FileCreateService is a service to create a file.
type FileCreateService struct {
client *Client
purpose string
localFile string
file io.Reader
filename string
customSeparator *string
sentenceSize *int
knowledgeID *string
}
// FileCreateKnowledgeSuccessInfo is the success info of the FileCreateKnowledgeResponse.
type FileCreateKnowledgeSuccessInfo struct {
Filename string `json:"fileName"`
DocumentID string `json:"documentId"`
}
// FileCreateKnowledgeFailedInfo is the failed info of the FileCreateKnowledgeResponse.
type FileCreateKnowledgeFailedInfo struct {
Filename string `json:"fileName"`
FailReason string `json:"failReason"`
}
// FileCreateKnowledgeResponse is the response of the FileCreateService.
type FileCreateKnowledgeResponse struct {
SuccessInfos []FileCreateKnowledgeSuccessInfo `json:"successInfos"`
FailedInfos []FileCreateKnowledgeFailedInfo `json:"failedInfos"`
}
// FileCreateFineTuneResponse is the response of the FileCreateService.
type FileCreateFineTuneResponse struct {
Bytes int64 `json:"bytes"`
CreatedAt int64 `json:"created_at"`
Filename string `json:"filename"`
Object string `json:"object"`
Purpose string `json:"purpose"`
ID string `json:"id"`
}
// FileCreateResponse is the response of the FileCreateService.
type FileCreateResponse struct {
FileCreateFineTuneResponse
FileCreateKnowledgeResponse
}
// NewFileCreateService creates a new FileCreateService.
func NewFileCreateService(client *Client) *FileCreateService {
return &FileCreateService{client: client}
}
// SetLocalFile sets the local_file parameter of the FileCreateService.
func (s *FileCreateService) SetLocalFile(localFile string) *FileCreateService {
s.localFile = localFile
return s
}
// SetFile sets the file parameter of the FileCreateService.
func (s *FileCreateService) SetFile(file io.Reader, filename string) *FileCreateService {
s.file = file
s.filename = filename
return s
}
// SetPurpose sets the purpose parameter of the FileCreateService.
func (s *FileCreateService) SetPurpose(purpose string) *FileCreateService {
s.purpose = purpose
return s
}
// SetCustomSeparator sets the custom_separator parameter of the FileCreateService.
func (s *FileCreateService) SetCustomSeparator(customSeparator string) *FileCreateService {
s.customSeparator = &customSeparator
return s
}
// SetSentenceSize sets the sentence_size parameter of the FileCreateService.
func (s *FileCreateService) SetSentenceSize(sentenceSize int) *FileCreateService {
s.sentenceSize = &sentenceSize
return s
}
// SetKnowledgeID sets the knowledge_id parameter of the FileCreateService.
func (s *FileCreateService) SetKnowledgeID(knowledgeID string) *FileCreateService {
s.knowledgeID = &knowledgeID
return s
}
// Do makes the request.
func (s *FileCreateService) Do(ctx context.Context) (res FileCreateResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
body := map[string]string{"purpose": s.purpose}
if s.customSeparator != nil {
body["custom_separator"] = *s.customSeparator
}
if s.sentenceSize != nil {
body["sentence_size"] = strconv.Itoa(*s.sentenceSize)
}
if s.knowledgeID != nil {
body["knowledge_id"] = *s.knowledgeID
}
file, filename := s.file, s.filename
if file == nil && s.localFile != "" {
var f *os.File
if f, err = os.Open(s.localFile); err != nil {
return
}
defer f.Close()
file = f
filename = filepath.Base(s.localFile)
}
if file == nil {
err = errors.New("no file specified")
return
}
if resp, err = s.client.request(ctx).
SetFileReader("file", filename, file).
SetMultipartFormData(body).
SetResult(&res).
SetError(&apiError).
Post("files"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// FileEditService is a service to edit a file.
type FileEditService struct {
client *Client
documentID string
knowledgeType *int
customSeparator []string
sentenceSize *int
}
// NewFileEditService creates a new FileEditService.
func NewFileEditService(client *Client) *FileEditService {
return &FileEditService{client: client}
}
// SetDocumentID sets the document_id parameter of the FileEditService.
func (s *FileEditService) SetDocumentID(documentID string) *FileEditService {
s.documentID = documentID
return s
}
// SetKnowledgeType sets the knowledge_type parameter of the FileEditService.
func (s *FileEditService) SetKnowledgeType(knowledgeType int) *FileEditService {
s.knowledgeType = &knowledgeType
return s
}
// SetSentenceSize sets the sentence_size parameter of the FileEditService.
func (s *FileEditService) SetCustomSeparator(customSeparator ...string) *FileEditService {
s.customSeparator = customSeparator
return s
}
// SetSentenceSize sets the sentence_size parameter of the FileEditService.
func (s *FileEditService) SetSentenceSize(sentenceSize int) *FileEditService {
s.sentenceSize = &sentenceSize
return s
}
// Do makes the request.
func (s *FileEditService) Do(ctx context.Context) (err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
body := M{}
if s.knowledgeType != nil {
body["knowledge_type"] = strconv.Itoa(*s.knowledgeType)
}
if len(s.customSeparator) > 0 {
body["custom_separator"] = s.customSeparator
}
if s.sentenceSize != nil {
body["sentence_size"] = strconv.Itoa(*s.sentenceSize)
}
if resp, err = s.client.request(ctx).
SetPathParam("document_id", s.documentID).
SetBody(body).
SetError(&apiError).
Put("document/{document_id}"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// FileListService is a service to list files.
type FileListService struct {
client *Client
purpose string
knowledgeID *string
page *int
limit *int
after *string
orderAsc *bool
}
// FileFailInfo is the failed info of the FileListKnowledgeItem.
type FileFailInfo struct {
EmbeddingCode int `json:"embedding_code"`
EmbeddingMsg string `json:"embedding_msg"`
}
// FileListKnowledgeItem is the item of the FileListKnowledgeResponse.
type FileListKnowledgeItem struct {
ID string `json:"id"`
Name string `json:"name"`
URL string `json:"url"`
Length int64 `json:"length"`
SentenceSize int64 `json:"sentence_size"`
CustomSeparator []string `json:"custom_separator"`
EmbeddingStat int `json:"embedding_stat"`
FailInfo *FileFailInfo `json:"failInfo"`
WordNum int64 `json:"word_num"`
ParseImage int `json:"parse_image"`
}
// FileListKnowledgeResponse is the response of the FileListService.
type FileListKnowledgeResponse struct {
Total int `json:"total"`
List []FileListKnowledgeItem `json:"list"`
}
// FileListFineTuneItem is the item of the FileListFineTuneResponse.
type FileListFineTuneItem struct {
Bytes int64 `json:"bytes"`
CreatedAt int64 `json:"created_at"`
Filename string `json:"filename"`
ID string `json:"id"`
Object string `json:"object"`
Purpose string `json:"purpose"`
}
// FileListFineTuneResponse is the response of the FileListService.
type FileListFineTuneResponse struct {
Object string `json:"object"`
Data []FileListFineTuneItem `json:"data"`
}
// FileListResponse is the response of the FileListService.
type FileListResponse struct {
FileListKnowledgeResponse
FileListFineTuneResponse
}
// NewFileListService creates a new FileListService.
func NewFileListService(client *Client) *FileListService {
return &FileListService{client: client}
}
// SetPurpose sets the purpose parameter of the FileListService.
func (s *FileListService) SetPurpose(purpose string) *FileListService {
s.purpose = purpose
return s
}
// SetKnowledgeID sets the knowledge_id parameter of the FileListService.
func (s *FileListService) SetKnowledgeID(knowledgeID string) *FileListService {
s.knowledgeID = &knowledgeID
return s
}
// SetPage sets the page parameter of the FileListService.
func (s *FileListService) SetPage(page int) *FileListService {
s.page = &page
return s
}
// SetLimit sets the limit parameter of the FileListService.
func (s *FileListService) SetLimit(limit int) *FileListService {
s.limit = &limit
return s
}
// SetAfter sets the after parameter of the FileListService.
func (s *FileListService) SetAfter(after string) *FileListService {
s.after = &after
return s
}
// SetOrder sets the order parameter of the FileListService.
func (s *FileListService) SetOrder(asc bool) *FileListService {
s.orderAsc = &asc
return s
}
// Do makes the request.
func (s *FileListService) Do(ctx context.Context) (res FileListResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
m := map[string]string{
"purpose": s.purpose,
}
if s.knowledgeID != nil {
m["knowledge_id"] = *s.knowledgeID
}
if s.page != nil {
m["page"] = strconv.Itoa(*s.page)
}
if s.limit != nil {
m["limit"] = strconv.Itoa(*s.limit)
}
if s.after != nil {
m["after"] = *s.after
}
if s.orderAsc != nil {
if *s.orderAsc {
m["order"] = "asc"
} else {
m["order"] = "desc"
}
}
if resp, err = s.client.request(ctx).
SetQueryParams(m).
SetResult(&res).
SetError(&apiError).
Get("files"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// FileDeleteService is a service to delete a file.
type FileDeleteService struct {
client *Client
documentID string
}
// NewFileDeleteService creates a new FileDeleteService.
func NewFileDeleteService(client *Client) *FileDeleteService {
return &FileDeleteService{client: client}
}
// SetDocumentID sets the document_id parameter of the FileDeleteService.
func (s *FileDeleteService) SetDocumentID(documentID string) *FileDeleteService {
s.documentID = documentID
return s
}
// Do makes the request.
func (s *FileDeleteService) Do(ctx context.Context) (err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetPathParam("document_id", s.documentID).
SetError(&apiError).
Delete("document/{document_id}"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// FileGetService is a service to get a file.
type FileGetService struct {
client *Client
documentID string
}
// FileGetResponse is the response of the FileGetService.
type FileGetResponse = FileListKnowledgeItem
// NewFileGetService creates a new FileGetService.
func NewFileGetService(client *Client) *FileGetService {
return &FileGetService{client: client}
}
// SetDocumentID sets the document_id parameter of the FileGetService.
func (s *FileGetService) SetDocumentID(documentID string) *FileGetService {
s.documentID = documentID
return s
}
// Do makes the request.
func (s *FileGetService) Do(ctx context.Context) (res FileGetResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetPathParam("document_id", s.documentID).
SetResult(&res).
SetError(&apiError).
Get("document/{document_id}"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// FileDownloadService is a service to download a file.
type FileDownloadService struct {
client *Client
fileID string
writer io.Writer
filename string
}
// NewFileDownloadService creates a new FileDownloadService.
func NewFileDownloadService(client *Client) *FileDownloadService {
return &FileDownloadService{client: client}
}
// SetFileID sets the file_id parameter of the FileDownloadService.
func (s *FileDownloadService) SetFileID(fileID string) *FileDownloadService {
s.fileID = fileID
return s
}
// SetOutput sets the output parameter of the FileDownloadService.
func (s *FileDownloadService) SetOutput(w io.Writer) *FileDownloadService {
s.writer = w
return s
}
// SetOutputFile sets the output_file parameter of the FileDownloadService.
func (s *FileDownloadService) SetOutputFile(filename string) *FileDownloadService {
s.filename = filename
return s
}
// Do makes the request.
func (s *FileDownloadService) Do(ctx context.Context) (err error) {
var resp *resty.Response
writer := s.writer
if writer == nil && s.filename != "" {
var f *os.File
if f, err = os.Create(s.filename); err != nil {
return
}
defer f.Close()
writer = f
}
if writer == nil {
return errors.New("no output specified")
}
if resp, err = s.client.request(ctx).
SetDoNotParseResponse(true).
SetPathParam("file_id", s.fileID).
Get("files/{file_id}/content"); err != nil {
return
}
defer resp.RawBody().Close()
_, err = io.Copy(writer, resp.RawBody())
return
}

View File

@ -0,0 +1,71 @@
package zhipu
import (
"context"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
)
func TestFileServiceFineTune(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.FileCreate(FilePurposeFineTune)
s.SetLocalFile(filepath.Join("testdata", "test-file.jsonl"))
res, err := s.Do(context.Background())
require.NoError(t, err)
require.NotZero(t, res.Bytes)
require.NotZero(t, res.CreatedAt)
require.NotEmpty(t, res.ID)
}
func TestFileServiceKnowledge(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.FileCreate(FilePurposeRetrieval)
s.SetKnowledgeID(os.Getenv("TEST_KNOWLEDGE_ID"))
s.SetLocalFile(filepath.Join("testdata", "test-file.txt"))
res, err := s.Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res.SuccessInfos)
require.NotEmpty(t, res.SuccessInfos[0].DocumentID)
require.NotEmpty(t, res.SuccessInfos[0].Filename)
documentID := res.SuccessInfos[0].DocumentID
res2, err := client.FileGet(documentID).Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res2.ID)
err = client.FileEdit(documentID).SetKnowledgeType(KnowledgeTypeCustom).Do(context.Background())
require.True(t, err == nil || GetAPIErrorCode(err) == "10019")
err = client.FileDelete(res.SuccessInfos[0].DocumentID).Do(context.Background())
require.True(t, err == nil || GetAPIErrorCode(err) == "10019")
}
func TestFileListServiceKnowledge(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.FileList(FilePurposeRetrieval).SetKnowledgeID(os.Getenv("TEST_KNOWLEDGE_ID"))
res, err := s.Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res.List)
}
func TestFileListServiceFineTune(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.FileList(FilePurposeFineTune)
res, err := s.Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res.Data)
}

View File

@ -0,0 +1,456 @@
package zhipu
import (
"context"
"strconv"
"github.com/go-resty/resty/v2"
)
const (
HyperParameterAuto = "auto"
FineTuneStatusCreate = "create"
FineTuneStatusValidatingFiles = "validating_files"
FineTuneStatusQueued = "queued"
FineTuneStatusRunning = "running"
FineTuneStatusSucceeded = "succeeded"
FineTuneStatusFailed = "failed"
FineTuneStatusCancelled = "cancelled"
)
// FineTuneItem is the item of the FineTune
type FineTuneItem struct {
ID string `json:"id"`
RequestID string `json:"request_id"`
FineTunedModel string `json:"fine_tuned_model"`
Status string `json:"status"`
Object string `json:"object"`
TrainingFile string `json:"training_file"`
ValidationFile string `json:"validation_file"`
Error APIError `json:"error"`
}
// FineTuneCreateService creates a new fine tune
type FineTuneCreateService struct {
client *Client
model string
trainingFile string
validationFile *string
learningRateMultiplier *StringOr[float64]
batchSize *StringOr[int]
nEpochs *StringOr[int]
suffix *string
requestID *string
}
// FineTuneCreateResponse is the response of the FineTuneCreateService
type FineTuneCreateResponse = FineTuneItem
// NewFineTuneCreateService creates a new FineTuneCreateService
func NewFineTuneCreateService(client *Client) *FineTuneCreateService {
return &FineTuneCreateService{
client: client,
}
}
// SetModel sets the model parameter
func (s *FineTuneCreateService) SetModel(model string) *FineTuneCreateService {
s.model = model
return s
}
// SetTrainingFile sets the trainingFile parameter
func (s *FineTuneCreateService) SetTrainingFile(trainingFile string) *FineTuneCreateService {
s.trainingFile = trainingFile
return s
}
// SetValidationFile sets the validationFile parameter
func (s *FineTuneCreateService) SetValidationFile(validationFile string) *FineTuneCreateService {
s.validationFile = &validationFile
return s
}
// SetLearningRateMultiplier sets the learningRateMultiplier parameter
func (s *FineTuneCreateService) SetLearningRateMultiplier(learningRateMultiplier float64) *FineTuneCreateService {
s.learningRateMultiplier = &StringOr[float64]{}
s.learningRateMultiplier.SetValue(learningRateMultiplier)
return s
}
// SetLearningRateMultiplierAuto sets the learningRateMultiplier parameter to auto
func (s *FineTuneCreateService) SetLearningRateMultiplierAuto() *FineTuneCreateService {
s.learningRateMultiplier = &StringOr[float64]{}
s.learningRateMultiplier.SetString(HyperParameterAuto)
return s
}
// SetBatchSize sets the batchSize parameter
func (s *FineTuneCreateService) SetBatchSize(batchSize int) *FineTuneCreateService {
s.batchSize = &StringOr[int]{}
s.batchSize.SetValue(batchSize)
return s
}
// SetBatchSizeAuto sets the batchSize parameter to auto
func (s *FineTuneCreateService) SetBatchSizeAuto() *FineTuneCreateService {
s.batchSize = &StringOr[int]{}
s.batchSize.SetString(HyperParameterAuto)
return s
}
// SetNEpochs sets the nEpochs parameter
func (s *FineTuneCreateService) SetNEpochs(nEpochs int) *FineTuneCreateService {
s.nEpochs = &StringOr[int]{}
s.nEpochs.SetValue(nEpochs)
return s
}
// SetNEpochsAuto sets the nEpochs parameter to auto
func (s *FineTuneCreateService) SetNEpochsAuto() *FineTuneCreateService {
s.nEpochs = &StringOr[int]{}
s.nEpochs.SetString(HyperParameterAuto)
return s
}
// SetSuffix sets the suffix parameter
func (s *FineTuneCreateService) SetSuffix(suffix string) *FineTuneCreateService {
s.suffix = &suffix
return s
}
// SetRequestID sets the requestID parameter
func (s *FineTuneCreateService) SetRequestID(requestID string) *FineTuneCreateService {
s.requestID = &requestID
return s
}
// Do makes the request
func (s *FineTuneCreateService) Do(ctx context.Context) (res FineTuneCreateResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
body := M{
"model": s.model,
"training_file": s.trainingFile,
}
if s.validationFile != nil {
body["validation_file"] = *s.validationFile
}
if s.suffix != nil {
body["suffix"] = *s.suffix
}
if s.requestID != nil {
body["request_id"] = *s.requestID
}
if s.learningRateMultiplier != nil || s.batchSize != nil || s.nEpochs != nil {
hp := M{}
if s.learningRateMultiplier != nil {
hp["learning_rate_multiplier"] = s.learningRateMultiplier
}
if s.batchSize != nil {
hp["batch_size"] = s.batchSize
}
if s.nEpochs != nil {
hp["n_epochs"] = s.nEpochs
}
body["hyperparameters"] = hp
}
if resp, err = s.client.request(ctx).
SetBody(body).
SetResult(&res).
SetError(&apiError).
Post("fine_tuning/jobs"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// FineTuneEventListService creates a new fine tune event list
type FineTuneEventListService struct {
client *Client
jobID string
limit *int
after *string
}
// FineTuneEventData is the data of the FineTuneEventItem
type FineTuneEventData struct {
Acc float64 `json:"acc"`
Loss float64 `json:"loss"`
CurrentSteps int64 `json:"current_steps"`
RemainingTime string `json:"remaining_time"`
ElapsedTime string `json:"elapsed_time"`
TotalSteps int64 `json:"total_steps"`
Epoch int64 `json:"epoch"`
TrainedTokens int64 `json:"trained_tokens"`
LearningRate float64 `json:"learning_rate"`
}
// FineTuneEventItem is the item of the FineTuneEventListResponse
type FineTuneEventItem struct {
ID string `json:"id"`
Type string `json:"type"`
Level string `json:"level"`
Message string `json:"message"`
Object string `json:"object"`
CreatedAt int64 `json:"created_at"`
Data FineTuneEventData `json:"data"`
}
// FineTuneEventListResponse is the response of the FineTuneEventListService
type FineTuneEventListResponse struct {
Data []FineTuneEventItem `json:"data"`
HasMore bool `json:"has_more"`
Object string `json:"object"`
}
// NewFineTuneEventListService creates a new FineTuneEventListService
func NewFineTuneEventListService(client *Client) *FineTuneEventListService {
return &FineTuneEventListService{
client: client,
}
}
// SetJobID sets the jobID parameter
func (s *FineTuneEventListService) SetJobID(jobID string) *FineTuneEventListService {
s.jobID = jobID
return s
}
// SetLimit sets the limit parameter
func (s *FineTuneEventListService) SetLimit(limit int) *FineTuneEventListService {
s.limit = &limit
return s
}
// SetAfter sets the after parameter
func (s *FineTuneEventListService) SetAfter(after string) *FineTuneEventListService {
s.after = &after
return s
}
// Do makes the request
func (s *FineTuneEventListService) Do(ctx context.Context) (res FineTuneEventListResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
req := s.client.request(ctx)
if s.limit != nil {
req.SetQueryParam("limit", strconv.Itoa(*s.limit))
}
if s.after != nil {
req.SetQueryParam("after", *s.after)
}
if resp, err = req.
SetPathParam("job_id", s.jobID).
SetResult(&res).
SetError(&apiError).
Get("fine_tuning/jobs/{job_id}/events"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// FineTuneGetService creates a new fine tune get
type FineTuneGetService struct {
client *Client
jobID string
}
// NewFineTuneGetService creates a new FineTuneGetService
func NewFineTuneGetService(client *Client) *FineTuneGetService {
return &FineTuneGetService{
client: client,
}
}
// SetJobID sets the jobID parameter
func (s *FineTuneGetService) SetJobID(jobID string) *FineTuneGetService {
s.jobID = jobID
return s
}
// Do makes the request
func (s *FineTuneGetService) Do(ctx context.Context) (res FineTuneItem, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetPathParam("job_id", s.jobID).
SetResult(&res).
SetError(&apiError).
Get("fine_tuning/jobs/{job_id}"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// FineTuneListService creates a new fine tune list
type FineTuneListService struct {
client *Client
limit *int
after *string
}
// FineTuneListResponse is the response of the FineTuneListService
type FineTuneListResponse struct {
Data []FineTuneItem `json:"data"`
Object string `json:"object"`
}
// NewFineTuneListService creates a new FineTuneListService
func NewFineTuneListService(client *Client) *FineTuneListService {
return &FineTuneListService{
client: client,
}
}
// SetLimit sets the limit parameter
func (s *FineTuneListService) SetLimit(limit int) *FineTuneListService {
s.limit = &limit
return s
}
// SetAfter sets the after parameter
func (s *FineTuneListService) SetAfter(after string) *FineTuneListService {
s.after = &after
return s
}
// Do makes the request
func (s *FineTuneListService) Do(ctx context.Context) (res FineTuneListResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
req := s.client.request(ctx)
if s.limit != nil {
req.SetQueryParam("limit", strconv.Itoa(*s.limit))
}
if s.after != nil {
req.SetQueryParam("after", *s.after)
}
if resp, err = req.
SetResult(&res).
SetError(&apiError).
Get("fine_tuning/jobs"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// FineTuneDeleteService creates a new fine tune delete
type FineTuneDeleteService struct {
client *Client
jobID string
}
// NewFineTuneDeleteService creates a new FineTuneDeleteService
func NewFineTuneDeleteService(client *Client) *FineTuneDeleteService {
return &FineTuneDeleteService{
client: client,
}
}
// SetJobID sets the jobID parameter
func (s *FineTuneDeleteService) SetJobID(jobID string) *FineTuneDeleteService {
s.jobID = jobID
return s
}
// Do makes the request
func (s *FineTuneDeleteService) Do(ctx context.Context) (res FineTuneItem, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetPathParam("job_id", s.jobID).
SetResult(&res).
SetError(&apiError).
Delete("fine_tuning/jobs/{job_id}"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// FineTuneCancelService creates a new fine tune cancel
type FineTuneCancelService struct {
client *Client
jobID string
}
// NewFineTuneCancelService creates a new FineTuneCancelService
func NewFineTuneCancelService(client *Client) *FineTuneCancelService {
return &FineTuneCancelService{
client: client,
}
}
// SetJobID sets the jobID parameter
func (s *FineTuneCancelService) SetJobID(jobID string) *FineTuneCancelService {
s.jobID = jobID
return s
}
// Do makes the request
func (s *FineTuneCancelService) Do(ctx context.Context) (res FineTuneItem, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetPathParam("job_id", s.jobID).
SetResult(&res).
SetError(&apiError).
Post("fine_tuning/jobs/{job_id}/cancel"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}

View File

@ -0,0 +1,3 @@
package zhipu
// tests not available since lack of budget to test it

View File

@ -0,0 +1,110 @@
package zhipu
import (
"context"
"github.com/go-resty/resty/v2"
)
// ImageGenerationService creates a new image generation
type ImageGenerationService struct {
client *Client
model string
prompt string
size string
userID string
}
var (
_ BatchSupport = &ImageGenerationService{}
)
// ImageGenerationResponse is the response of the ImageGenerationService
type ImageGenerationResponse struct {
Created int64 `json:"created"`
Data []URLItem `json:"data"`
}
// NewImageGenerationService creates a new ImageGenerationService
func NewImageGenerationService(client *Client) *ImageGenerationService {
return &ImageGenerationService{
client: client,
}
}
func (s *ImageGenerationService) BatchMethod() string {
return "POST"
}
func (s *ImageGenerationService) BatchURL() string {
return BatchEndpointV4ImagesGenerations
}
func (s *ImageGenerationService) BatchBody() any {
return s.buildBody()
}
// SetModel sets the model parameter
func (s *ImageGenerationService) SetModel(model string) *ImageGenerationService {
s.model = model
return s
}
// SetPrompt sets the prompt parameter
func (s *ImageGenerationService) SetPrompt(prompt string) *ImageGenerationService {
s.prompt = prompt
return s
}
func (s *ImageGenerationService) SetSize(size string) *ImageGenerationService {
s.size = size
return s
}
// SetUserID sets the userID parameter
func (s *ImageGenerationService) SetUserID(userID string) *ImageGenerationService {
s.userID = userID
return s
}
func (s *ImageGenerationService) buildBody() M {
body := M{
"model": s.model,
"prompt": s.prompt,
}
if s.userID != "" {
body["user_id"] = s.userID
}
if s.size != "" {
body["size"] = s.size
}
return body
}
func (s *ImageGenerationService) Do(ctx context.Context) (res ImageGenerationResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
body := s.buildBody()
if resp, err = s.client.request(ctx).
SetBody(body).
SetResult(&res).
SetError(&apiError).
Post("images/generations"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}

View File

@ -0,0 +1,21 @@
package zhipu
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestImageGenerationService(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.ImageGeneration("cogview-3")
s.SetPrompt("一只可爱的小猫")
res, err := s.Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res.Data)
t.Log(res.Data[0].URL)
}

View File

@ -0,0 +1,299 @@
package zhipu
import (
"context"
"strconv"
"github.com/go-resty/resty/v2"
)
const (
KnowledgeEmbeddingIDEmbedding2 = 3
)
// KnowledgeCreateService creates a new knowledge
type KnowledgeCreateService struct {
client *Client
embeddingID int
name string
description *string
}
// KnowledgeCreateResponse is the response of the KnowledgeCreateService
type KnowledgeCreateResponse = IDItem
// NewKnowledgeCreateService creates a new KnowledgeCreateService
func NewKnowledgeCreateService(client *Client) *KnowledgeCreateService {
return &KnowledgeCreateService{
client: client,
}
}
// SetEmbeddingID sets the embedding id of the knowledge
func (s *KnowledgeCreateService) SetEmbeddingID(embeddingID int) *KnowledgeCreateService {
s.embeddingID = embeddingID
return s
}
// SetName sets the name of the knowledge
func (s *KnowledgeCreateService) SetName(name string) *KnowledgeCreateService {
s.name = name
return s
}
// SetDescription sets the description of the knowledge
func (s *KnowledgeCreateService) SetDescription(description string) *KnowledgeCreateService {
s.description = &description
return s
}
// Do creates the knowledge
func (s *KnowledgeCreateService) Do(ctx context.Context) (res KnowledgeCreateResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
body := M{
"name": s.name,
"embedding_id": s.embeddingID,
}
if s.description != nil {
body["description"] = *s.description
}
if resp, err = s.client.request(ctx).
SetBody(body).
SetResult(&res).
SetError(&apiError).
Post("knowledge"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// KnowledgeEditService edits a knowledge
type KnowledgeEditService struct {
client *Client
knowledgeID string
embeddingID *int
name *string
description *string
}
// NewKnowledgeEditService creates a new KnowledgeEditService
func NewKnowledgeEditService(client *Client) *KnowledgeEditService {
return &KnowledgeEditService{
client: client,
}
}
// SetKnowledgeID sets the knowledge id
func (s *KnowledgeEditService) SetKnowledgeID(knowledgeID string) *KnowledgeEditService {
s.knowledgeID = knowledgeID
return s
}
// SetName sets the name of the knowledge
func (s *KnowledgeEditService) SetName(name string) *KnowledgeEditService {
s.name = &name
return s
}
// SetEmbeddingID sets the embedding id of the knowledge
func (s *KnowledgeEditService) SetEmbeddingID(embeddingID int) *KnowledgeEditService {
s.embeddingID = &embeddingID
return s
}
// SetDescription sets the description of the knowledge
func (s *KnowledgeEditService) SetDescription(description string) *KnowledgeEditService {
s.description = &description
return s
}
// Do edits the knowledge
func (s *KnowledgeEditService) Do(ctx context.Context) (err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
body := M{}
if s.name != nil {
body["name"] = *s.name
}
if s.description != nil {
body["description"] = *s.description
}
if s.embeddingID != nil {
body["embedding_id"] = *s.embeddingID
}
if resp, err = s.client.request(ctx).
SetPathParam("knowledge_id", s.knowledgeID).
SetBody(body).
SetError(&apiError).
Put("knowledge/{knowledge_id}"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// KnowledgeListService lists the knowledge
type KnowledgeListService struct {
client *Client
page *int
size *int
}
// KnowledgeItem is an item in the knowledge list
type KnowledgeItem struct {
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Icon string `json:"icon"`
Background string `json:"background"`
EmbeddingID int `json:"embedding_id"`
CustomIdentifier string `json:"custom_identifier"`
WordNum int64 `json:"word_num"`
Length int64 `json:"length"`
DocumentSize int64 `json:"document_size"`
}
// KnowledgeListResponse is the response of the KnowledgeListService
type KnowledgeListResponse struct {
List []KnowledgeItem `json:"list"`
Total int `json:"total"`
}
// NewKnowledgeListService creates a new KnowledgeListService
func NewKnowledgeListService(client *Client) *KnowledgeListService {
return &KnowledgeListService{client: client}
}
// SetPage sets the page of the knowledge list
func (s *KnowledgeListService) SetPage(page int) *KnowledgeListService {
s.page = &page
return s
}
// SetSize sets the size of the knowledge list
func (s *KnowledgeListService) SetSize(size int) *KnowledgeListService {
s.size = &size
return s
}
// Do lists the knowledge
func (s *KnowledgeListService) Do(ctx context.Context) (res KnowledgeListResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
req := s.client.request(ctx)
if s.page != nil {
req.SetQueryParam("page", strconv.Itoa(*s.page))
}
if s.size != nil {
req.SetQueryParam("size", strconv.Itoa(*s.size))
}
if resp, err = req.
SetResult(&res).
SetError(&apiError).
Get("knowledge"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// KnowledgeDeleteService deletes a knowledge
type KnowledgeDeleteService struct {
client *Client
knowledgeID string
}
// NewKnowledgeDeleteService creates a new KnowledgeDeleteService
func NewKnowledgeDeleteService(client *Client) *KnowledgeDeleteService {
return &KnowledgeDeleteService{
client: client,
}
}
// SetKnowledgeID sets the knowledge id
func (s *KnowledgeDeleteService) SetKnowledgeID(knowledgeID string) *KnowledgeDeleteService {
s.knowledgeID = knowledgeID
return s
}
// Do deletes the knowledge
func (s *KnowledgeDeleteService) Do(ctx context.Context) (err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetPathParam("knowledge_id", s.knowledgeID).
SetError(&apiError).
Delete("knowledge/{knowledge_id}"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}
// KnowledgeCapacityService query the capacity of the knowledge
type KnowledgeCapacityService struct {
client *Client
}
// KnowledgeCapacityItem is an item in the knowledge capacity
type KnowledgeCapacityItem struct {
WordNum int64 `json:"word_num"`
Length int64 `json:"length"`
}
// KnowledgeCapacityResponse is the response of the KnowledgeCapacityService
type KnowledgeCapacityResponse struct {
Used KnowledgeCapacityItem `json:"used"`
Total KnowledgeCapacityItem `json:"total"`
}
// SetKnowledgeID sets the knowledge id
func NewKnowledgeCapacityService(client *Client) *KnowledgeCapacityService {
return &KnowledgeCapacityService{client: client}
}
// Do query the capacity of the knowledge
func (s *KnowledgeCapacityService) Do(ctx context.Context) (res KnowledgeCapacityResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
if resp, err = s.client.request(ctx).
SetResult(&res).
SetError(&apiError).
Get("knowledge/capacity"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}

View File

@ -0,0 +1,50 @@
package zhipu
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestKnowledgeCapacity(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.KnowledgeCapacity()
res, err := s.Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res.Total.Length)
require.NotEmpty(t, res.Total.WordNum)
}
func TestKnowledgeServiceAll(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.KnowledgeCreate()
s.SetName("test")
s.SetDescription("test description")
s.SetEmbeddingID(KnowledgeEmbeddingIDEmbedding2)
res, err := s.Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res.ID)
s2 := client.KnowledgeList()
res2, err := s2.Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res2.List)
require.Equal(t, res.ID, res2.List[0].ID)
s3 := client.KnowledgeEdit(res.ID)
s3.SetDescription("test description 2")
s3.SetName("test 2")
s3.SetEmbeddingID(KnowledgeEmbeddingIDEmbedding2)
err = s3.Do(context.Background())
require.NoError(t, err)
s4 := client.KnowledgeDelete(res.ID)
err = s4.Do(context.Background())
require.NoError(t, err)
}

View File

@ -0,0 +1,54 @@
package zhipu
import (
"bytes"
"encoding/json"
)
// StringOr is a struct that can be either a string or a value of type T.
type StringOr[T any] struct {
String *string
Value *T
}
var (
_ json.Marshaler = StringOr[float64]{}
_ json.Unmarshaler = &StringOr[float64]{}
)
// SetString sets the string value of the struct.
func (f *StringOr[T]) SetString(v string) {
f.String = &v
f.Value = nil
}
// SetValue sets the value of the struct.
func (f *StringOr[T]) SetValue(v T) {
f.String = nil
f.Value = &v
}
func (f StringOr[T]) MarshalJSON() ([]byte, error) {
if f.Value != nil {
return json.Marshal(f.Value)
}
return json.Marshal(f.String)
}
func (f *StringOr[T]) UnmarshalJSON(data []byte) error {
if len(data) == 0 {
return nil
}
if bytes.Equal(data, []byte("null")) {
return nil
}
if data[0] == '"' {
f.String = new(string)
f.Value = nil
return json.Unmarshal(data, f.String)
} else {
f.Value = new(T)
f.String = nil
return json.Unmarshal(data, f.Value)
}
}

View File

@ -0,0 +1,37 @@
package zhipu
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestStringOr(t *testing.T) {
data := struct {
Item *StringOr[int] `json:"item,omitempty"`
}{}
data.Item = &StringOr[int]{}
data.Item.SetString("test")
b, err := json.Marshal(data)
require.NoError(t, err)
require.Equal(t, `{"item":"test"}`, string(b))
data.Item.SetValue(1)
b, err = json.Marshal(data)
require.NoError(t, err)
require.Equal(t, `{"item":1}`, string(b))
err = json.Unmarshal([]byte(`{"item":"test2"}`), &data)
require.NoError(t, err)
require.NotNil(t, data.Item.String)
require.Nil(t, data.Item.Value)
require.Equal(t, "test2", *data.Item.String)
err = json.Unmarshal([]byte(`{"item":2}`), &data)
require.NoError(t, err)
require.Nil(t, data.Item.String)
require.NotNil(t, data.Item.Value)
require.Equal(t, 2, *data.Item.Value)
}

View File

@ -0,0 +1,3 @@
{"messages": [{"role": "system", "content": "你是一位乐于助人知识渊博的全能AI助手。"},{"role": "user", "content": "请介绍一下AlphaGo的基本原理。"}, {"role": "assistant", "content": "AlphaGo利用深度神经网络与树搜索算法相结合,通过对弈数据进行训练,逐步提升游戏水平,能够对围棋局面进行判断并选择最佳落子位置。"}]}
{"messages": [{"role": "system", "content": "你是一位乐于助人知识渊博的全能AI助手。"},{"role": "user", "content": "请列举一些健康饮食的重要原则。"}, {"role": "assistant", "content": "健康饮食的重要原则包括:膳食应营养均衡,合理摄入各种营养素;多吃新鲜蔬果,适量饮水;少吃糖、油、盐,忌吃过量快餐;养成规律生活作息习惯,适量运动。"}]}
{"messages": [{"role": "system", "content": "你是一位乐于助人知识渊博的全能AI助手。"},{"role": "user", "content": "请给出一些改善睡眠质量的建议。"}, {"role": "assistant", "content": "改善睡眠质量的建议包括:保证睡眠时间,避免睡眠不足;睡前放松身心,远离刺激;养成规律作息时间,不要频繁熬夜;适量运动,但不要运动过度;睡前可以喝一杯热牛奶等温和饮料。"}]}

View File

@ -0,0 +1 @@
The quick brown fox jumps over the lazy dog

22
llm/zhipu/zhipu/util.go Normal file
View File

@ -0,0 +1,22 @@
package zhipu
// URLItem is a struct that contains a URL.
type URLItem struct {
URL string `json:"url,omitempty"`
}
// IDItem is a struct that contains an ID.
type IDItem struct {
ID string `json:"id,omitempty"`
}
// Ptr returns a pointer to the value passed in.
// Example:
//
// web_search_enable = zhipu.Ptr(false)
func Ptr[T any](v T) *T {
return &v
}
// M is a shorthand for map[string]any.
type M = map[string]any

View File

@ -0,0 +1,3 @@
package zhipu
// nothing to test

View File

@ -0,0 +1,125 @@
package zhipu
import (
"context"
"github.com/go-resty/resty/v2"
)
const (
VideoGenerationTaskStatusProcessing = "PROCESSING"
VideoGenerationTaskStatusSuccess = "SUCCESS"
VideoGenerationTaskStatusFail = "FAIL"
)
// VideoGenerationService creates a new video generation
type VideoGenerationService struct {
client *Client
model string
prompt string
userID string
imageURL string
requestID string
}
var (
_ BatchSupport = &VideoGenerationService{}
)
// VideoGenerationResponse is the response of the VideoGenerationService
type VideoGenerationResponse struct {
RequestID string `json:"request_id"`
ID string `json:"id"`
Model string `json:"model"`
TaskStatus string `json:"task_status"`
}
func NewVideoGenerationService(client *Client) *VideoGenerationService {
return &VideoGenerationService{
client: client,
}
}
func (s *VideoGenerationService) BatchMethod() string {
return "POST"
}
func (s *VideoGenerationService) BatchURL() string {
return BatchEndpointV4VideosGenerations
}
func (s *VideoGenerationService) BatchBody() any {
return s.buildBody()
}
// SetModel sets the model parameter
func (s *VideoGenerationService) SetModel(model string) *VideoGenerationService {
s.model = model
return s
}
// SetPrompt sets the prompt parameter
func (s *VideoGenerationService) SetPrompt(prompt string) *VideoGenerationService {
s.prompt = prompt
return s
}
// SetUserID sets the userID parameter
func (s *VideoGenerationService) SetUserID(userID string) *VideoGenerationService {
s.userID = userID
return s
}
// SetImageURL sets the imageURL parameter
func (s *VideoGenerationService) SetImageURL(imageURL string) *VideoGenerationService {
s.imageURL = imageURL
return s
}
// SetRequestID sets the requestID parameter
func (s *VideoGenerationService) SetRequestID(requestID string) *VideoGenerationService {
s.requestID = requestID
return s
}
func (s *VideoGenerationService) buildBody() M {
body := M{
"model": s.model,
"prompt": s.prompt,
}
if s.userID != "" {
body["user_id"] = s.userID
}
if s.imageURL != "" {
body["image_url"] = s.imageURL
}
if s.requestID != "" {
body["request_id"] = s.requestID
}
return body
}
func (s *VideoGenerationService) Do(ctx context.Context) (res VideoGenerationResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)
body := s.buildBody()
if resp, err = s.client.request(ctx).
SetBody(body).
SetResult(&res).
SetError(&apiError).
Post("videos/generations"); err != nil {
return
}
if resp.IsError() {
err = apiError
return
}
return
}

View File

@ -0,0 +1,38 @@
package zhipu
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestVideoGeneration(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)
s := client.VideoGeneration("cogvideox")
s.SetPrompt("一只可爱的小猫")
res, err := s.Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res.TaskStatus)
require.NotEmpty(t, res.ID)
t.Log(res.ID)
for {
res, err := client.AsyncResult(res.ID).Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res.TaskStatus)
if res.TaskStatus == VideoGenerationTaskStatusSuccess {
require.NotEmpty(t, res.VideoResult)
t.Log(res.VideoResult[0].URL)
t.Log(res.VideoResult[0].CoverImageURL)
}
if res.TaskStatus != VideoGenerationTaskStatusProcessing {
break
}
time.Sleep(time.Second * 5)
}
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

9
tests/ask.js Normal file
View File

@ -0,0 +1,9 @@
import {glm, gpt} from './lib/ai'
import console from './lib/console'
import chat from './mod/chat'
function main(...args) {
let ask = args[0]
if (!ask) ask = console.input('请输入问题:')
return chat.fast(ask)
}

View File

@ -2,51 +2,48 @@ package tests
import (
"apigo.cc/ai/ai"
_ "apigo.cc/ai/openai"
_ "apigo.cc/ai/zhipu"
"apigo.cc/ai/ai/js"
"apigo.cc/ai/ai/llm"
"fmt"
"github.com/ssgo/config"
"github.com/ssgo/u"
"testing"
)
var agentConf = map[string]ai.AgentConfig{}
func init() {
_ = config.LoadConfig("agent", &agentConf)
}
func TestAgent(t *testing.T) {
//testChat(t, "openai", "https://api.keya.pw/v1", keys.Openai)
ai.Init()
//testChat(t, "gpt")
// 尚未支持
//testCode(t, "openai", "https://api.keya.pw/v1", keys.Openai)
//testCode(t, "gpt")
// keya 消耗过大,计费不合理
//testAskWithImage(t, "openai", "https://api.keya.pw/v1", keys.Openai, "4032.jpg")
//testAskWithImage(t, "gpt", "4032.jpg")
// keya 不支持
//testMakeImage(t, "openai", "https://api.keya.pw/v1", keys.Openai, "冬天大雪纷飞,一个男人身穿军绿色棉大衣,戴着红色围巾和绿色帽子走在铺面大雪的小镇路上", "")
//testMakeImage(t, "gpt", "冬天大雪纷飞,一个男人身穿军绿色棉大衣,戴着红色围巾和绿色帽子走在铺面大雪的小镇路上", "")
testChat(t, "zhipu")
//testCode(t, "zhipu", "", keys.Zhipu)
//testSearch(t, "zhipu", "", keys.Zhipu)
//testChat(t, "glm")
//testCode(t, "glm", "", keys.Zhipu)
//testSearch(t, "glm", "", keys.Zhipu)
// 测试图片识别
//testAskWithImage(t, "zhipu", "", keys.Zhipu, "4032.jpg")
//testAskWithImage(t, "glm", "", keys.Zhipu, "4032.jpg")
// 视频似乎尚不支持 glm-4v-plus
//testAskWithVideo(t, "zhipu", "", keys.Zhipu, "glm-4v", "1080.mp4")
//testAskWithVideo(t, "glm", "", keys.Zhipu, "glm-4v", "1080.mp4")
//testMakeImage(t, "zhipu", "", keys.Zhipu, "冬天大雪纷飞,一个男人身穿军绿色棉大衣,戴着红色围巾和绿色帽子走在铺面大雪的小镇路上", "")
//testMakeVideo(t, "zhipu", "", keys.Zhipu, "大雪纷飞,男人蹦蹦跳跳", "https://aigc-files.bigmodel.cn/api/cogview/20240904133130c4b7121019724aa3_0.png")
//testMakeImage(t, "glm", "", keys.Zhipu, "冬天大雪纷飞,一个男人身穿军绿色棉大衣,戴着红色围巾和绿色帽子走在铺面大雪的小镇路上", "")
//testMakeVideo(t, "glm", "", keys.Zhipu, "大雪纷飞,男人蹦蹦跳跳", "https://aigc-files.bigmodel.cn/api/cogview/20240904133130c4b7121019724aa3_0.png")
testJS(t)
//testFile(t)
}
func testChat(t *testing.T, agent string) {
ag := ai.GetAgent(agent)
func testChat(t *testing.T, llmName string) {
lm := llm.Get(llmName)
if ag == nil {
if lm == nil {
t.Fatal("agent is nil")
}
r, usage, err := ag.FastAsk(ai.Messages().User().Text("你是什么模型,请给出具体名称、版本号").Make(), func(text string) {
r, usage, err := lm.FastAsk(llm.Messages().User().Text("你是什么模型,请给出具体名称、版本号").Make(), func(text string) {
fmt.Print(u.BCyan(text))
fmt.Print(" ")
})
@ -60,14 +57,14 @@ func testChat(t *testing.T, agent string) {
fmt.Println("usage:", u.JsonP(usage))
}
func testCode(t *testing.T, agent string) {
ag := ai.GetAgent(agent)
func testCode(t *testing.T, llmName string) {
lm := llm.Get(llmName)
if ag == nil {
if lm == nil {
t.Fatal("agent is nil")
}
r, usage, err := ag.CodeInterpreterAsk(ai.Messages().User().Text("计算[5,10,20,700,99,310,978,100]的平均值和方差。").Make(), func(text string) {
r, usage, err := lm.CodeInterpreterAsk(llm.Messages().User().Text("计算[5,10,20,700,99,310,978,100]的平均值和方差。").Make(), func(text string) {
fmt.Print(u.BCyan(text))
fmt.Print(" ")
})
@ -81,14 +78,14 @@ func testCode(t *testing.T, agent string) {
fmt.Println("usage:", u.JsonP(usage))
}
func testSearch(t *testing.T, agent string) {
ag := ai.GetAgent(agent)
func testSearch(t *testing.T, llmName string) {
lm := llm.Get(llmName)
if ag == nil {
if lm == nil {
t.Fatal("agent is nil")
}
r, usage, err := ag.WebSearchAsk(ai.Messages().User().Text("今天上海的天气怎么样?").Make(), func(text string) {
r, usage, err := lm.WebSearchAsk(llm.Messages().User().Text("今天上海的天气怎么样?").Make(), func(text string) {
fmt.Print(u.BCyan(text))
fmt.Print(" ")
})
@ -102,10 +99,10 @@ func testSearch(t *testing.T, agent string) {
fmt.Println("usage:", u.JsonP(usage))
}
func testAskWithImage(t *testing.T, agent, imageFile string) {
ag := ai.GetAgent(agent)
func testAskWithImage(t *testing.T, llmName, imageFile string) {
lm := llm.Get(llmName)
if ag == nil {
if lm == nil {
t.Fatal("agent is nil")
}
@ -115,7 +112,7 @@ func testAskWithImage(t *testing.T, agent, imageFile string) {
3正在用什么软件播放什么歌谁演唱的歌曲的大意是
4后面的浏览器中正在浏览什么内容猜测一下我浏览这个网页是想干嘛
`
r, usage, err := ag.MultiAsk(ai.Messages().User().Text(ask).Image("data:image/jpeg;base64,"+u.Base64(u.ReadFileBytesN(imageFile))).Make(), func(text string) {
r, usage, err := lm.MultiAsk(llm.Messages().User().Text(ask).Image("data:image/jpeg;base64,"+u.Base64(u.ReadFileBytesN(imageFile))).Make(), func(text string) {
fmt.Print(u.BCyan(text))
fmt.Print(" ")
})
@ -129,10 +126,10 @@ func testAskWithImage(t *testing.T, agent, imageFile string) {
fmt.Println("usage:", u.JsonP(usage))
}
func testAskWithVideo(t *testing.T, agent, videoFile string) {
ag := ai.GetAgent(agent)
func testAskWithVideo(t *testing.T, llmName, videoFile string) {
lm := llm.Get(llmName)
if ag == nil {
if lm == nil {
t.Fatal("agent is nil")
}
@ -141,7 +138,7 @@ func testAskWithVideo(t *testing.T, agent, videoFile string) {
4后面的浏览器中正在浏览什么内容猜测一下我浏览这个网页是想干嘛
`
r, usage, err := ag.MultiAsk(ai.Messages().User().Text(ask).Video("data:video/mp4,"+u.Base64(u.ReadFileBytesN(videoFile))).Make(), func(text string) {
r, usage, err := lm.MultiAsk(llm.Messages().User().Text(ask).Video("data:video/mp4,"+u.Base64(u.ReadFileBytesN(videoFile))).Make(), func(text string) {
fmt.Print(u.BCyan(text))
fmt.Print(" ")
})
@ -155,14 +152,16 @@ func testAskWithVideo(t *testing.T, agent, videoFile string) {
fmt.Println("usage:", u.JsonP(usage))
}
func testMakeImage(t *testing.T, agent, prompt, refImage string) {
ag := ai.GetAgent(agent)
func testMakeImage(t *testing.T, llmName, prompt, refImage string) {
lm := llm.Get(llmName)
if ag == nil {
if lm == nil {
t.Fatal("agent is nil")
}
r, err := ag.FastMakeImage(prompt, "1024x1024", "")
r, err := lm.FastMakeImage(prompt, llm.GCConfig{
Size: "1024x1024",
})
if err != nil {
t.Fatal("发生错误", err.Error())
@ -174,14 +173,17 @@ func testMakeImage(t *testing.T, agent, prompt, refImage string) {
}
}
func testMakeVideo(t *testing.T, agent, prompt, refImage string) {
ag := ai.GetAgent(agent)
func testMakeVideo(t *testing.T, llmName, prompt, refImage string) {
lm := llm.Get(llmName)
if ag == nil {
if lm == nil {
t.Fatal("agent is nil")
}
r, covers, err := ag.FastMakeVideo(prompt, "1280x720", refImage)
r, covers, err := lm.FastMakeVideo(prompt, llm.GCConfig{
Size: "1280x720",
Ref: refImage,
})
if err != nil {
t.Fatal("发生错误", err.Error())
@ -192,3 +194,31 @@ func testMakeVideo(t *testing.T, agent, prompt, refImage string) {
fmt.Println("result:", i, v, covers[i])
}
}
func testJS(t *testing.T) {
r1, err := js.RunFile("test.js", "1+2=4吗")
if err != nil {
t.Fatal("发生错误", err.Error())
}
r := js.ChatResult{}
u.Convert(r1, &r)
fmt.Println()
fmt.Println("result:", r.Result)
fmt.Println("usage:", u.JsonP(r.TokenUsage))
}
func testFile(t *testing.T) {
r1, err := js.Run(`
import fs from './lib/file'
import out from './lib/console'
let r = fs.read('test.js')
return r
`, "test.js")
if err != nil {
t.Fatal("发生错误", err.Error())
}
fmt.Println()
fmt.Println("result:", r1)
}

View File

@ -1,6 +1,7 @@
agent:
openai:
llm:
gpt1:
apiKey: ...
# endpoint: ...
# endpoint: https://.... # api base url
# llm: openai # registered llm name, if not specified, it will use the key name as llm name
zhipu:
apiKey: ...

View File

@ -1,24 +0,0 @@
module tests
go 1.22
require (
apigo.cc/ai/ai v0.0.0
apigo.cc/ai/openai v0.0.1
apigo.cc/ai/zhipu v0.0.1
github.com/ssgo/config v1.7.7
github.com/ssgo/u v1.7.7
)
require (
apigo.cc/ai/agent v0.0.1 // indirect
github.com/go-resty/resty/v2 v2.14.0 // indirect
github.com/golang-jwt/jwt/v5 v5.2.1 // indirect
github.com/sashabaranov/go-openai v1.29.1 // indirect
github.com/ssgo/log v1.7.7 // indirect
github.com/ssgo/standard v1.7.7 // indirect
golang.org/x/net v0.29.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
replace apigo.cc/ai/ai v0.0.0 => ../

14
tests/mod/chat.js Normal file
View File

@ -0,0 +1,14 @@
import {glm} from '../lib/ai'
import console from '../lib/console'
function fast(...args) {
if(!args[0]) throw new Error('no ask')
let r = glm.fastAsk(args[0], r => {
console.print(r)
})
console.println()
return r
}
// exports.fast = fast
module.exports = {fast}

6
tests/test.js Normal file
View File

@ -0,0 +1,6 @@
import {glm, gpt} from './lib/ai'
import chat from './mod/chat'
function main(...args) {
return chat.fast(args[0])
}