support js
This commit is contained in:
parent
342f2d9c09
commit
b4cddad489
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,3 +1,4 @@
|
||||
.*
|
||||
go.sum
|
||||
env.yml
|
||||
/tests/lib
|
||||
|
204
README.md
Normal file
204
README.md
Normal file
@ -0,0 +1,204 @@
|
||||
# 低代码AI能力工具包
|
||||
|
||||
## 命令行工具
|
||||
|
||||
### Install
|
||||
|
||||
```shell
|
||||
go install apigo.cc/ai/ai/ai@latest
|
||||
```
|
||||
|
||||
### Usage
|
||||
|
||||
```shell
|
||||
ai -h | --help show usage
|
||||
ai -e | --export export ai.ts file to ./lib for develop
|
||||
ai test | test.js run test.js, if not specified, run ./ai.js
|
||||
```
|
||||
|
||||
### Sample
|
||||
|
||||
#### test.js
|
||||
|
||||
```javascript
|
||||
import {glm} from './lib/ai'
|
||||
import console from './lib/console'
|
||||
|
||||
function main(...args) {
|
||||
if(!args[0]) throw new Error('no ask')
|
||||
let r = glm.fastAsk(args[0], r => {
|
||||
console.print(r)
|
||||
})
|
||||
console.println()
|
||||
return r
|
||||
}
|
||||
```
|
||||
|
||||
#### run sample
|
||||
|
||||
```shell
|
||||
ai test "你好"
|
||||
```
|
||||
|
||||
### Module
|
||||
|
||||
#### mod/glm.js
|
||||
|
||||
```javascript
|
||||
import {glm} from './lib/ai'
|
||||
import console from './lib/console'
|
||||
|
||||
function chat(...args) {
|
||||
if(!args[0]) throw new Error('no prompt')
|
||||
return glm.fastAsk(args[0])
|
||||
}
|
||||
|
||||
function draw(...args) {
|
||||
if(!args[0]) throw new Error('no ask')
|
||||
return glm.makeImage(args[0], {size:'1024x1024'})
|
||||
}
|
||||
|
||||
module.exports = {chat, draw}
|
||||
```
|
||||
|
||||
#### test.js
|
||||
|
||||
```javascript
|
||||
import glm from './mod/glm'
|
||||
|
||||
function main(...args) {
|
||||
console.println(glm.chat(args[0]))
|
||||
}
|
||||
```
|
||||
|
||||
#### run sample
|
||||
|
||||
```shell
|
||||
ai test "你好"
|
||||
```
|
||||
|
||||
### Configure
|
||||
|
||||
#### llm.yml
|
||||
|
||||
```yaml
|
||||
openai:
|
||||
apiKey: ...
|
||||
zhipu:
|
||||
apiKey: ...
|
||||
```
|
||||
|
||||
#### or use env.yml
|
||||
|
||||
#### llm.yml
|
||||
|
||||
```yaml
|
||||
llm:
|
||||
openai:
|
||||
apiKey: ...
|
||||
zhipu:
|
||||
apiKey: ...
|
||||
```
|
||||
|
||||
#### encrypt apiKey
|
||||
|
||||
install sskey
|
||||
|
||||
```shell
|
||||
go install github.com/ssgo/tool/sskey@latest
|
||||
sskey -e 'your apiKey'
|
||||
```
|
||||
|
||||
copy url base64 format encrypted apiKey into llm.yml or env.yml
|
||||
|
||||
#### config with special endpoint
|
||||
|
||||
```yaml
|
||||
llm:
|
||||
openai:
|
||||
apiKey: ...
|
||||
endpoint: https://api.openai.com/v1
|
||||
```
|
||||
|
||||
#### config multi api
|
||||
|
||||
```yaml
|
||||
llm:
|
||||
glm:
|
||||
apiKey: ...
|
||||
llm: zhipu
|
||||
glm2:
|
||||
apiKey: ...
|
||||
endpoint: https://......
|
||||
llm: zhipu
|
||||
```
|
||||
|
||||
|
||||
## 调用 JavaScript API
|
||||
|
||||
### Install
|
||||
|
||||
```shell
|
||||
go get apigo.cc/ai/ai
|
||||
```
|
||||
|
||||
### Usage
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"apigo.cc/ai/ai/js"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func main() {
|
||||
result, err := js.Run(`return ai.glm.fastAsk(args[0])`, "", "你好")
|
||||
// js.RunFile
|
||||
// js.StartFromFile
|
||||
// js.StartFromCode
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
} else if result != nil {
|
||||
fmt.Println(result)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## 调用 Go API
|
||||
|
||||
### Install
|
||||
|
||||
```shell
|
||||
go get apigo.cc/ai/ai
|
||||
```
|
||||
|
||||
### Usage
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"apigo.cc/ai/ai"
|
||||
"apigo.cc/ai/ai/llm"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ai.Init()
|
||||
glm := llm.Get("zhipu")
|
||||
|
||||
r, usage, err := glm.FastAsk(llm.Messages().User().Text("你是什么模型").Make(), func(text string) {
|
||||
fmt.Print(text)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
} else {
|
||||
fmt.Println()
|
||||
fmt.Println("result:", r)
|
||||
fmt.Println("usage:", usage)
|
||||
}
|
||||
}
|
||||
```
|
89
agent.go
89
agent.go
@ -1,89 +0,0 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"apigo.cc/ai/agent"
|
||||
"github.com/ssgo/config"
|
||||
"github.com/ssgo/u"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var confAes = u.NewAes([]byte("?GQ$0K0GgLdO=f+~L68PLm$uhKr4'=tV"), []byte("VFs7@sK61cj^f?HZ"))
|
||||
var keysIsSet = false
|
||||
|
||||
func SetSSKey(key, iv []byte) {
|
||||
if !keysIsSet {
|
||||
confAes = u.NewAes(key, iv)
|
||||
keysIsSet = true
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
TypeText = agent.TypeText
|
||||
TypeImage = agent.TypeImage
|
||||
TypeVideo = agent.TypeVideo
|
||||
RoleSystem = agent.RoleSystem
|
||||
RoleUser = agent.RoleUser
|
||||
RoleAssistant = agent.RoleAssistant
|
||||
RoleTool = agent.RoleTool
|
||||
ToolCodeInterpreter = agent.ToolCodeInterpreter
|
||||
ToolWebSearch = agent.ToolWebSearch
|
||||
)
|
||||
|
||||
type APIConfig struct {
|
||||
Endpoint string
|
||||
ApiKey string
|
||||
DefaultChatModelConfig ChatModelConfig
|
||||
}
|
||||
|
||||
type AgentConfig struct {
|
||||
ApiKey string
|
||||
Endpoint string
|
||||
Agent string
|
||||
ChatConfig ChatModelConfig
|
||||
}
|
||||
|
||||
var agentConfigs map[string]*AgentConfig
|
||||
var agentConfigsLock = sync.RWMutex{}
|
||||
|
||||
func GetAgent(name string) agent.Agent {
|
||||
ag := agent.GetAgent(name)
|
||||
if ag != nil {
|
||||
return ag
|
||||
}
|
||||
|
||||
var agConf *AgentConfig
|
||||
if agentConfigs == nil {
|
||||
agConfs := make(map[string]*AgentConfig)
|
||||
config.LoadConfig("agent", &agConfs)
|
||||
agConf = agConfs[name]
|
||||
if agConf != nil {
|
||||
agentConfigsLock.Lock()
|
||||
agentConfigs = agConfs
|
||||
agentConfigsLock.Unlock()
|
||||
}
|
||||
} else {
|
||||
agentConfigsLock.RLock()
|
||||
agConf = agentConfigs[name]
|
||||
agentConfigsLock.RUnlock()
|
||||
}
|
||||
|
||||
if agConf == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if agConf.Agent == "" {
|
||||
agConf.Agent = name
|
||||
}
|
||||
|
||||
return agent.CreateAgent(name, agConf.Agent, agent.APIConfig{
|
||||
Endpoint: agConf.Endpoint,
|
||||
ApiKey: confAes.DecryptUrlBase64ToString(agConf.ApiKey),
|
||||
DefaultChatModelConfig: agent.ChatModelConfig{
|
||||
Model: agConf.ChatConfig.Model,
|
||||
MaxTokens: agConf.ChatConfig.MaxTokens,
|
||||
Temperature: agConf.ChatConfig.Temperature,
|
||||
TopP: agConf.ChatConfig.TopP,
|
||||
Tools: agConf.ChatConfig.Tools,
|
||||
},
|
||||
})
|
||||
}
|
61
ai.go
Normal file
61
ai.go
Normal file
@ -0,0 +1,61 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"apigo.cc/ai/ai/llm"
|
||||
_ "apigo.cc/ai/ai/llm/openai"
|
||||
_ "apigo.cc/ai/ai/llm/zhipu"
|
||||
"github.com/ssgo/config"
|
||||
"github.com/ssgo/u"
|
||||
"os"
|
||||
)
|
||||
|
||||
var confAes = u.NewAes([]byte("?GQ$0K0GgLdO=f+~L68PLm$uhKr4'=tV"), []byte("VFs7@sK61cj^f?HZ"))
|
||||
var keysIsSet = false
|
||||
var isInit = false
|
||||
|
||||
func SetSSKey(key, iv []byte) {
|
||||
if !keysIsSet {
|
||||
confAes = u.NewAes(key, iv)
|
||||
keysIsSet = true
|
||||
}
|
||||
}
|
||||
|
||||
func Init() {
|
||||
InitFrom("")
|
||||
}
|
||||
|
||||
func InitFrom(filePath string) {
|
||||
if !isInit {
|
||||
isInit = true
|
||||
list := map[string]*struct {
|
||||
Endpoint string
|
||||
ApiKey string
|
||||
DefaultChatModelConfig llm.ChatConfig
|
||||
Llm string
|
||||
}{}
|
||||
savedPath := ""
|
||||
if filePath != "" {
|
||||
curPath, _ := os.Getwd()
|
||||
if curPath != filePath {
|
||||
savedPath = curPath
|
||||
_ = os.Chdir(filePath)
|
||||
config.ResetConfigEnv()
|
||||
}
|
||||
}
|
||||
_ = config.LoadConfig("llm", &list)
|
||||
if savedPath != "" {
|
||||
_ = os.Chdir(savedPath)
|
||||
}
|
||||
for name, llmConf := range list {
|
||||
if llmConf.Llm == "" {
|
||||
llmConf.Llm = name
|
||||
}
|
||||
llmConf.ApiKey = confAes.DecryptUrlBase64ToString(llmConf.ApiKey)
|
||||
llm.Create(name, llmConf.Llm, llm.Config{
|
||||
Endpoint: llmConf.Endpoint,
|
||||
ApiKey: llmConf.ApiKey,
|
||||
ChatConfig: llmConf.DefaultChatModelConfig,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
59
ai/ai.go
Normal file
59
ai/ai.go
Normal file
@ -0,0 +1,59 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"apigo.cc/ai/ai/js"
|
||||
"fmt"
|
||||
"github.com/ssgo/u"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if len(os.Args) > 1 && (os.Args[1] == "-e" || os.Args[1] == "export") {
|
||||
imports, err := js.ExportForDev()
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
} else {
|
||||
fmt.Println(`exported to ./lib
|
||||
example:
|
||||
`)
|
||||
fmt.Println(u.Cyan(imports + `
|
||||
// import customMod from './customMod'
|
||||
|
||||
function main(...args) {
|
||||
// TODO
|
||||
return null
|
||||
}
|
||||
`))
|
||||
}
|
||||
return
|
||||
} else {
|
||||
jsFile := "ai.js"
|
||||
if len(os.Args) > 1 {
|
||||
jsFile = os.Args[1]
|
||||
if !strings.HasSuffix(jsFile, ".js") && !u.FileExists(jsFile) {
|
||||
jsFile = jsFile + ".js"
|
||||
}
|
||||
}
|
||||
if u.FileExists(jsFile) {
|
||||
args := make([]any, len(os.Args)-2)
|
||||
for i := 2; i < len(os.Args); i++ {
|
||||
args[i-2] = os.Args[i]
|
||||
}
|
||||
result, err := js.RunFile(jsFile, args...)
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
} else if result != nil {
|
||||
fmt.Println(u.JsonP(result))
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println(`Usage:
|
||||
ai -h | --help show usage
|
||||
ai -e | --export export ai.ts file for develop
|
||||
ai test | test.js run test.js, if not specified, run ai.js
|
||||
`)
|
||||
return
|
||||
}
|
24
go.mod
24
go.mod
@ -3,9 +3,29 @@ module apigo.cc/ai/ai
|
||||
go 1.22
|
||||
|
||||
require (
|
||||
apigo.cc/ai/agent v0.0.1
|
||||
github.com/dop251/goja v0.0.0-20240828124009-016eb7256539
|
||||
github.com/go-resty/resty/v2 v2.15.0
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1
|
||||
github.com/sashabaranov/go-openai v1.29.2
|
||||
github.com/ssgo/config v1.7.7
|
||||
github.com/ssgo/log v1.7.7
|
||||
github.com/ssgo/u v1.7.7
|
||||
github.com/stretchr/testify v1.9.0
|
||||
)
|
||||
|
||||
require gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dlclark/regexp2 v1.11.4 // indirect
|
||||
github.com/dop251/goja_nodejs v0.0.0-20240728170619-29b559befffc // indirect
|
||||
github.com/go-sourcemap/sourcemap v2.1.4+incompatible // indirect
|
||||
github.com/google/pprof v0.0.0-20240910150728-a0b0bb1d4134 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/ssgo/standard v1.7.7 // indirect
|
||||
golang.org/x/net v0.29.0 // indirect
|
||||
golang.org/x/text v0.18.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
replace (
|
||||
github.com/ssgo/config v1.7.7 => ../../ssgo/config
|
||||
)
|
||||
|
291
js/ai.go
Normal file
291
js/ai.go
Normal file
@ -0,0 +1,291 @@
|
||||
package js
|
||||
|
||||
import (
|
||||
"apigo.cc/ai/ai/llm"
|
||||
"github.com/dop251/goja"
|
||||
"github.com/ssgo/u"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ChatResult struct {
|
||||
llm.TokenUsage
|
||||
Result string
|
||||
Error string
|
||||
}
|
||||
|
||||
type AIGCResult struct {
|
||||
Result string
|
||||
Preview string
|
||||
Results []string
|
||||
Previews []string
|
||||
Error string
|
||||
}
|
||||
|
||||
func requireAI(lm llm.LLM) map[string]any {
|
||||
return map[string]any{
|
||||
"ask": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||||
conf, cb := getAskArgs(args.This, vm, args.Arguments)
|
||||
result, usage, err := lm.Ask(makeChatMessages(args.Arguments), conf, cb)
|
||||
return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)})
|
||||
},
|
||||
"fastAsk": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||||
_, cb := getAskArgs(args.This, vm, args.Arguments)
|
||||
result, usage, err := lm.FastAsk(makeChatMessages(args.Arguments), cb)
|
||||
return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)})
|
||||
},
|
||||
"longAsk": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||||
_, cb := getAskArgs(args.This, vm, args.Arguments)
|
||||
result, usage, err := lm.LongAsk(makeChatMessages(args.Arguments), cb)
|
||||
return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)})
|
||||
},
|
||||
"batterAsk": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||||
_, cb := getAskArgs(args.This, vm, args.Arguments)
|
||||
result, usage, err := lm.BatterAsk(makeChatMessages(args.Arguments), cb)
|
||||
return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)})
|
||||
},
|
||||
"bestAsk": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||||
_, cb := getAskArgs(args.This, vm, args.Arguments)
|
||||
result, usage, err := lm.BestAsk(makeChatMessages(args.Arguments), cb)
|
||||
return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)})
|
||||
},
|
||||
|
||||
"multiAsk": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||||
_, cb := getAskArgs(args.This, vm, args.Arguments)
|
||||
result, usage, err := lm.MultiAsk(makeChatMessages(args.Arguments), cb)
|
||||
return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)})
|
||||
},
|
||||
"bestMultiAsk": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||||
_, cb := getAskArgs(args.This, vm, args.Arguments)
|
||||
result, usage, err := lm.BestMultiAsk(makeChatMessages(args.Arguments), cb)
|
||||
return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)})
|
||||
},
|
||||
|
||||
"codeInterpreterAsk": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||||
_, cb := getAskArgs(args.This, vm, args.Arguments)
|
||||
result, usage, err := lm.CodeInterpreterAsk(makeChatMessages(args.Arguments), cb)
|
||||
return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)})
|
||||
},
|
||||
"webSearchAsk": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||||
_, cb := getAskArgs(args.This, vm, args.Arguments)
|
||||
result, usage, err := lm.WebSearchAsk(makeChatMessages(args.Arguments), cb)
|
||||
return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)})
|
||||
},
|
||||
|
||||
"makeImage": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||||
prompt, conf := getAIGCArgs(args.Arguments)
|
||||
results, err := lm.MakeImage(prompt, conf)
|
||||
return makeAIGCResult(vm, results, nil, err)
|
||||
},
|
||||
"fastMakeImage": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||||
prompt, conf := getAIGCArgs(args.Arguments)
|
||||
results, err := lm.FastMakeImage(prompt, conf)
|
||||
return makeAIGCResult(vm, results, nil, err)
|
||||
},
|
||||
"bestMakeImage": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||||
prompt, conf := getAIGCArgs(args.Arguments)
|
||||
results, err := lm.BestMakeImage(prompt, conf)
|
||||
return makeAIGCResult(vm, results, nil, err)
|
||||
},
|
||||
|
||||
"makeVideo": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||||
prompt, conf := getAIGCArgs(args.Arguments)
|
||||
results, previews, err := lm.MakeVideo(prompt, conf)
|
||||
return makeAIGCResult(vm, results, previews, err)
|
||||
},
|
||||
"fastMakeVideo": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||||
prompt, conf := getAIGCArgs(args.Arguments)
|
||||
results, previews, err := lm.FastMakeVideo(prompt, conf)
|
||||
return makeAIGCResult(vm, results, previews, err)
|
||||
},
|
||||
"bestMakeVideo": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||||
prompt, conf := getAIGCArgs(args.Arguments)
|
||||
results, previews, err := lm.BestMakeVideo(prompt, conf)
|
||||
return makeAIGCResult(vm, results, previews, err)
|
||||
},
|
||||
|
||||
"support": lm.Support(),
|
||||
}
|
||||
}
|
||||
|
||||
func getErrorStr(err error) string {
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func makeAIGCResult(vm *goja.Runtime, results []string, previews []string, err error) goja.Value {
|
||||
result := ""
|
||||
preview := ""
|
||||
if len(results) > 0 {
|
||||
result = results[0]
|
||||
} else {
|
||||
results = make([]string, 0)
|
||||
}
|
||||
if len(previews) > 0 {
|
||||
preview = previews[0]
|
||||
} else {
|
||||
previews = make([]string, 0)
|
||||
}
|
||||
return vm.ToValue(AIGCResult{
|
||||
Result: result,
|
||||
Preview: preview,
|
||||
Results: results,
|
||||
Previews: previews,
|
||||
Error: getErrorStr(err),
|
||||
})
|
||||
}
|
||||
|
||||
func getAIGCArgs(args []goja.Value) (string, llm.GCConfig) {
|
||||
prompt := ""
|
||||
var config llm.GCConfig
|
||||
if len(args) > 0 {
|
||||
prompt = u.String(args[0].Export())
|
||||
if len(args) > 1 {
|
||||
u.Convert(args[1].Export(), &config)
|
||||
}
|
||||
}
|
||||
return prompt, config
|
||||
}
|
||||
|
||||
func getAskArgs(thisArg goja.Value, vm *goja.Runtime, args []goja.Value) (llm.ChatConfig, func(string)) {
|
||||
var chatConfig llm.ChatConfig
|
||||
var callback func(answer string)
|
||||
if len(args) > 0 {
|
||||
for i := 1; i < len(args); i++ {
|
||||
if cb, ok := goja.AssertFunction(args[i]); ok {
|
||||
callback = func(answer string) {
|
||||
_, _ = cb(thisArg, vm.ToValue(answer))
|
||||
}
|
||||
} else {
|
||||
switch args[i].ExportType().Kind() {
|
||||
case reflect.Map, reflect.Struct:
|
||||
u.Convert(args[i].Export(), &chatConfig)
|
||||
default:
|
||||
chatConfig.Model = u.String(args[i].Export())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return chatConfig, callback
|
||||
}
|
||||
|
||||
func makeChatMessages(args []goja.Value) []llm.ChatMessage {
|
||||
out := make([]llm.ChatMessage, 0)
|
||||
if len(args) > 0 {
|
||||
v := args[0].Export()
|
||||
vv := reflect.ValueOf(v)
|
||||
t := args[0].ExportType()
|
||||
lastRoleIsUser := false
|
||||
switch t.Kind() {
|
||||
// 数组,根据成员类型处理
|
||||
// 字符串:
|
||||
// 含有媒体:单条多模态消息
|
||||
// 无媒体:多条文本消息
|
||||
// 数组:多条消息(第一个成员不是 role 则自动生成)
|
||||
// 对象:多条消息(无 role 则自动生成)(支持 content 或 contents)
|
||||
// 结构:转换为 llm.ChatMessage
|
||||
// 对象:单条消息(支持 content 或 contents)
|
||||
// 结构:转换为 llm.ChatMessage
|
||||
// 字符串:单条文本消息
|
||||
case reflect.Slice:
|
||||
hasSub := false
|
||||
hasMulti := false
|
||||
for i := 0; i < vv.Len(); i++ {
|
||||
if vv.Index(i).Kind() == reflect.Slice || vv.Index(i).Kind() == reflect.Map || vv.Index(i).Kind() == reflect.Struct {
|
||||
hasSub = true
|
||||
break
|
||||
}
|
||||
if vv.Index(i).Kind() == reflect.String {
|
||||
str := vv.Index(i).String()
|
||||
if strings.HasPrefix(str, "data:") || strings.HasPrefix(str, "https://") || strings.HasPrefix(str, "http://") {
|
||||
hasMulti = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if hasSub || !hasMulti {
|
||||
// 有子对象或纯文本数组
|
||||
var defaultRole string
|
||||
for i := 0; i < vv.Len(); i++ {
|
||||
lastRoleIsUser = !lastRoleIsUser
|
||||
if lastRoleIsUser {
|
||||
defaultRole = llm.RoleUser
|
||||
} else {
|
||||
defaultRole = llm.RoleAssistant
|
||||
}
|
||||
vv2 := vv.Index(i)
|
||||
switch vv2.Kind() {
|
||||
case reflect.Slice:
|
||||
out = append(out, makeChatMessageFromSlice(vv2, defaultRole))
|
||||
case reflect.Map:
|
||||
out = append(out, makeChatMessageFromSlice(vv2, defaultRole))
|
||||
case reflect.Struct:
|
||||
item := llm.ChatMessage{}
|
||||
u.Convert(vv2.Interface(), &item)
|
||||
out = append(out, item)
|
||||
default:
|
||||
out = append(out, llm.ChatMessage{Role: llm.RoleUser, Contents: []llm.ChatMessageContent{makeChatMessageContent(u.String(vv2.Interface()))}})
|
||||
}
|
||||
lastRoleIsUser = out[len(out)-1].Role != llm.RoleUser
|
||||
}
|
||||
} else {
|
||||
// 单条多模态消息
|
||||
out = append(out, makeChatMessageFromSlice(vv, llm.RoleUser))
|
||||
}
|
||||
case reflect.Map:
|
||||
out = append(out, makeChatMessageFromMap(vv, llm.RoleUser))
|
||||
case reflect.Struct:
|
||||
item := llm.ChatMessage{}
|
||||
u.Convert(v, &item)
|
||||
out = append(out, item)
|
||||
default:
|
||||
out = append(out, llm.ChatMessage{Role: llm.RoleUser, Contents: []llm.ChatMessageContent{makeChatMessageContent(u.String(v))}})
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func makeChatMessageFromSlice(vv reflect.Value, defaultRole string) llm.ChatMessage {
|
||||
role := u.String(vv.Index(0).Interface())
|
||||
j := 0
|
||||
if role == llm.RoleUser || role == llm.RoleAssistant || role == llm.RoleSystem || role == llm.RoleTool {
|
||||
j = 1
|
||||
} else {
|
||||
role = defaultRole
|
||||
}
|
||||
contents := make([]llm.ChatMessageContent, 0)
|
||||
for ; j < vv.Len(); j++ {
|
||||
contents = append(contents, makeChatMessageContent(u.String(vv.Index(j).Interface())))
|
||||
}
|
||||
return llm.ChatMessage{Role: role, Contents: contents}
|
||||
}
|
||||
|
||||
func makeChatMessageFromMap(vv reflect.Value, defaultRole string) llm.ChatMessage {
|
||||
role := u.String(vv.MapIndex(reflect.ValueOf("role")).Interface())
|
||||
if role == "" {
|
||||
role = defaultRole
|
||||
}
|
||||
contents := make([]llm.ChatMessageContent, 0)
|
||||
content := u.String(vv.MapIndex(reflect.ValueOf("content")).Interface())
|
||||
if content != "" {
|
||||
contents = append(contents, makeChatMessageContent(content))
|
||||
} else {
|
||||
contentsV := vv.MapIndex(reflect.ValueOf("contents"))
|
||||
if contentsV.IsValid() && contentsV.Kind() == reflect.Slice {
|
||||
for i := 0; i < contentsV.Len(); i++ {
|
||||
contents = append(contents, makeChatMessageContent(u.String(contentsV.Index(i).Interface())))
|
||||
}
|
||||
}
|
||||
}
|
||||
return llm.ChatMessage{Role: role, Contents: contents}
|
||||
}
|
||||
|
||||
func makeChatMessageContent(contnet string) llm.ChatMessageContent {
|
||||
if strings.HasPrefix(contnet, "data:image/") || ((strings.HasPrefix(contnet, "https://") || strings.HasPrefix(contnet, "http://")) && (strings.HasSuffix(contnet, ".png") || strings.HasSuffix(contnet, ".jpg") || strings.HasSuffix(contnet, ".jpeg") || strings.HasSuffix(contnet, ".gif") || strings.HasSuffix(contnet, ".svg"))) {
|
||||
return llm.ChatMessageContent{Type: llm.TypeImage, Content: contnet}
|
||||
} else if strings.HasPrefix(contnet, "data:video/") || ((strings.HasPrefix(contnet, "https://") || strings.HasPrefix(contnet, "http://")) && (strings.HasSuffix(contnet, ".mp4") || strings.HasSuffix(contnet, ".mov") || strings.HasSuffix(contnet, ".m4v") || strings.HasSuffix(contnet, ".avi") || strings.HasSuffix(contnet, ".wmv"))) {
|
||||
return llm.ChatMessageContent{Type: llm.TypeVideo, Content: contnet}
|
||||
}
|
||||
return llm.ChatMessageContent{Type: llm.TypeText, Content: contnet}
|
||||
}
|
81
js/console.go
Normal file
81
js/console.go
Normal file
@ -0,0 +1,81 @@
|
||||
package js
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/dop251/goja"
|
||||
"github.com/ssgo/u"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func requireConsole() map[string]any {
|
||||
return map[string]any{
|
||||
"print": func(args goja.FunctionCall) goja.Value {
|
||||
consolePrint(args, "print", nil)
|
||||
return nil
|
||||
},
|
||||
"println": func(args goja.FunctionCall) goja.Value {
|
||||
consolePrint(args, "log", nil)
|
||||
return nil
|
||||
},
|
||||
"log": func(args goja.FunctionCall) goja.Value {
|
||||
consolePrint(args, "log", nil)
|
||||
return nil
|
||||
},
|
||||
"info": func(args goja.FunctionCall) goja.Value {
|
||||
consolePrint(args, "info", nil)
|
||||
return nil
|
||||
},
|
||||
"warn": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||||
consolePrint(args, "warn", vm)
|
||||
return nil
|
||||
},
|
||||
"error": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||||
consolePrint(args, "error", vm)
|
||||
return nil
|
||||
},
|
||||
"debug": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||||
consolePrint(args, "debug", vm)
|
||||
return nil
|
||||
},
|
||||
"input": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||||
consolePrint(args, "print", nil)
|
||||
line := ""
|
||||
_, _ = fmt.Scanln(&line)
|
||||
return vm.ToValue(line)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func consolePrint(args goja.FunctionCall, typ string, vm *goja.Runtime) {
|
||||
arr := make([]any, len(args.Arguments))
|
||||
textColor := u.TextNone
|
||||
switch typ {
|
||||
case "info":
|
||||
textColor = u.TextCyan
|
||||
case "warn":
|
||||
textColor = u.TextYellow
|
||||
case "error":
|
||||
textColor = u.TextRed
|
||||
}
|
||||
|
||||
for i, arg := range args.Arguments {
|
||||
if textColor != u.TextNone {
|
||||
arr[i] = u.Color(u.StringP(arg.Export()), textColor, u.BgNone)
|
||||
} else {
|
||||
arr[i] = u.StringP(arg.Export())
|
||||
}
|
||||
}
|
||||
|
||||
if (typ == "warn" || typ == "error" || typ == "debug") && vm != nil {
|
||||
callStacks := make([]string, 0)
|
||||
for _, stack := range vm.CaptureCallStack(0, nil) {
|
||||
callStacks = append(callStacks, u.Color(" "+stack.Position().String(), textColor, u.BgNone))
|
||||
}
|
||||
fmt.Println(arr...)
|
||||
fmt.Println(strings.Join(callStacks, "\n"))
|
||||
} else if typ == "print" {
|
||||
fmt.Print(arr...)
|
||||
} else {
|
||||
fmt.Println(arr...)
|
||||
}
|
||||
}
|
48
js/file.go
Normal file
48
js/file.go
Normal file
@ -0,0 +1,48 @@
|
||||
package js
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/dop251/goja"
|
||||
"github.com/ssgo/u"
|
||||
)
|
||||
|
||||
func requireFile() map[string]any {
|
||||
return map[string]any{
|
||||
"read": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||||
if len(args.Arguments) < 1 {
|
||||
return vm.NewGoError(errors.New("arguments need 1 given " + u.String(len(args.Arguments))))
|
||||
}
|
||||
if r, err := u.ReadFile(u.String(args.Arguments[0].Export())); err == nil {
|
||||
return vm.ToValue(r)
|
||||
} else {
|
||||
return vm.NewGoError(err)
|
||||
}
|
||||
},
|
||||
"write": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||||
if len(args.Arguments) < 2 {
|
||||
return vm.NewGoError(errors.New("arguments need 2 given " + u.String(len(args.Arguments))))
|
||||
}
|
||||
if err := u.WriteFileBytes(u.String(args.Arguments[0].Export()), u.Bytes(args.Arguments[0].Export())); err == nil {
|
||||
return nil
|
||||
} else {
|
||||
return vm.NewGoError(err)
|
||||
}
|
||||
},
|
||||
"dir": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||||
if len(args.Arguments) < 1 {
|
||||
return vm.NewGoError(errors.New("arguments need 1 given " + u.String(len(args.Arguments))))
|
||||
}
|
||||
if r, err := u.ReadDir(u.String(args.Arguments[0].Export())); err == nil {
|
||||
return vm.ToValue(r)
|
||||
} else {
|
||||
return vm.NewGoError(err)
|
||||
}
|
||||
},
|
||||
"stat": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||||
if len(args.Arguments) < 1 {
|
||||
return vm.NewGoError(errors.New("arguments need 1 given " + u.String(len(args.Arguments))))
|
||||
}
|
||||
return vm.ToValue(u.GetFileInfo(u.String(args.Arguments[0].Export())))
|
||||
},
|
||||
}
|
||||
}
|
238
js/js.go
Normal file
238
js/js.go
Normal file
@ -0,0 +1,238 @@
|
||||
package js
|
||||
|
||||
import (
|
||||
"apigo.cc/ai/ai"
|
||||
"apigo.cc/ai/ai/llm"
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/dop251/goja"
|
||||
"github.com/dop251/goja_nodejs/require"
|
||||
"github.com/ssgo/u"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
//go:embed lib/ai.ts
|
||||
var aiTS string
|
||||
|
||||
//go:embed lib/console.ts
|
||||
var consoleTS string
|
||||
|
||||
//go:embed lib/file.ts
|
||||
var fileTS string
|
||||
|
||||
func RunFile(file string, args ...any) (any, error) {
|
||||
return Run(u.ReadFileN(file), file, args...)
|
||||
}
|
||||
|
||||
func Run(code string, refFile string, args ...any) (any, error) {
|
||||
var r any
|
||||
js, err := StartFromCode(code, refFile)
|
||||
if err == nil {
|
||||
r, err = js.Run(args...)
|
||||
}
|
||||
return r, err
|
||||
}
|
||||
|
||||
var importModMatcher = regexp.MustCompile(`(?im)^\s*import\s+(.+?)\s+from\s+['"](.+?)['"]`)
|
||||
var importLibMatcher = regexp.MustCompile(`(?im)^\s*(import)\s+(.+?)\s+from\s+['"][./\\\w:]+lib[/\\](.+?)(\.ts)?['"]`)
|
||||
var requireLibMatcher = regexp.MustCompile(`(?im)^\s*(const|let|var)\s+(.+?)\s*=\s*require\s*\(\s*['"][./\\\w:]+lib[/\\](.+?)(\.ts)?['"]\s*\)`)
|
||||
var checkMainMatcher = regexp.MustCompile(`(?im)^\s*function\s+main\s*\(`)
|
||||
|
||||
type JS struct {
|
||||
vm *goja.Runtime
|
||||
required map[string]bool
|
||||
file string
|
||||
srcCode string
|
||||
code string
|
||||
}
|
||||
|
||||
func (js *JS) requireMod(name string) error {
|
||||
var err error
|
||||
if name == "console" || name == "" {
|
||||
if !js.required["console"] {
|
||||
js.required["console"] = true
|
||||
err = js.vm.Set("console", requireConsole())
|
||||
}
|
||||
}
|
||||
if err == nil && (name == "file" || name == "") {
|
||||
if !js.required["file"] {
|
||||
js.required["file"] = true
|
||||
err = js.vm.Set("file", requireFile())
|
||||
}
|
||||
}
|
||||
if err == nil && (name == "ai" || name == "") {
|
||||
if !js.required["ai"] {
|
||||
js.required["ai"] = true
|
||||
aiList := make(map[string]any)
|
||||
for name, lm := range llm.List() {
|
||||
aiList[name] = requireAI(lm)
|
||||
}
|
||||
err = js.vm.Set("ai", aiList)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (js *JS) makeImport(matcher *regexp.Regexp, code string) (string, int, error) {
|
||||
var modErr error
|
||||
importCount := 0
|
||||
code = matcher.ReplaceAllStringFunc(code, func(str string) string {
|
||||
if m := matcher.FindStringSubmatch(str); m != nil && len(m) > 3 {
|
||||
optName := m[1]
|
||||
if optName == "import" {
|
||||
optName = "let"
|
||||
}
|
||||
varName := m[2]
|
||||
modName := m[3]
|
||||
importCount++
|
||||
if modErr == nil {
|
||||
if err := js.requireMod(modName); err != nil {
|
||||
modErr = err
|
||||
}
|
||||
}
|
||||
if varName != modName {
|
||||
return fmt.Sprintf("%s %s = %s", optName, varName, modName)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
})
|
||||
return code, importCount, modErr
|
||||
}
|
||||
|
||||
func StartFromFile(file string) (*JS, error) {
|
||||
return StartFromCode(u.ReadFileN(file), file)
|
||||
}
|
||||
|
||||
func StartFromCode(code, refFile string) (*JS, error) {
|
||||
if refFile == "" {
|
||||
refFile = "main.js"
|
||||
}
|
||||
|
||||
if absFile, err := filepath.Abs(refFile); err == nil {
|
||||
refFile = absFile
|
||||
}
|
||||
|
||||
ai.InitFrom(filepath.Dir(refFile))
|
||||
|
||||
js := &JS{
|
||||
vm: goja.New(),
|
||||
required: map[string]bool{},
|
||||
file: refFile,
|
||||
srcCode: code,
|
||||
code: code,
|
||||
}
|
||||
|
||||
// 按需加载引用
|
||||
var importCount int
|
||||
var modErr error
|
||||
js.code, importCount, modErr = js.makeImport(importLibMatcher, js.code)
|
||||
if modErr == nil {
|
||||
importCount1 := importCount
|
||||
js.code, importCount, modErr = js.makeImport(requireLibMatcher, js.code)
|
||||
importCount += importCount1
|
||||
}
|
||||
|
||||
// 将 import 转换为 require
|
||||
js.code = importModMatcher.ReplaceAllString(js.code, "let $1 = require('$2')")
|
||||
|
||||
// 如果没有import,默认import所有
|
||||
if modErr == nil && importCount == 0 {
|
||||
modErr = js.requireMod("")
|
||||
}
|
||||
if modErr != nil {
|
||||
return nil, modErr
|
||||
}
|
||||
|
||||
//fmt.Println(u.BCyan(js.code))
|
||||
|
||||
// 处理模块引用
|
||||
require.NewRegistryWithLoader(func(path string) ([]byte, error) {
|
||||
refPath := filepath.Join(filepath.Dir(js.file), path)
|
||||
if !strings.HasSuffix(refPath, ".js") && !u.FileExists(refPath) {
|
||||
refPath += ".js"
|
||||
}
|
||||
modCode, err := u.ReadFile(refPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
modCode, _, _ = js.makeImport(importLibMatcher, modCode)
|
||||
modCode, _, _ = js.makeImport(requireLibMatcher, modCode)
|
||||
return []byte(modCode), modErr
|
||||
}).Enable(js.vm)
|
||||
|
||||
// 初始化主函数
|
||||
if !checkMainMatcher.MatchString(js.code) {
|
||||
js.code = "function main(...args){" + js.code + "}"
|
||||
}
|
||||
if _, err := js.vm.RunScript("main", js.code); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return js, nil
|
||||
}
|
||||
|
||||
func (js *JS) Run(args ...any) (any, error) {
|
||||
// 解析参数
|
||||
for i, arg := range args {
|
||||
if str, ok := arg.(string); ok {
|
||||
var v interface{}
|
||||
if err := json.Unmarshal([]byte(str), &v); err == nil {
|
||||
args[i] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := js.vm.Set("__args", args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jsResult, err := js.vm.RunScript(js.file, "main(...__args)")
|
||||
|
||||
var result any
|
||||
if err == nil {
|
||||
if jsResult != nil && !jsResult.Equals(goja.Undefined()) {
|
||||
result = jsResult.Export()
|
||||
}
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
type Exports struct {
|
||||
LLMList []string
|
||||
}
|
||||
|
||||
func ExportForDev() (string, error) {
|
||||
ai.Init()
|
||||
if len(llm.List()) == 0 && !u.FileExists("env.yml") && !u.FileExists("env.json") && !u.FileExists("llm.yml") && !u.FileExists("llm.json") {
|
||||
return "", errors.New("no llm config found, please run `ai -e` on env.yml or llm.yml path")
|
||||
}
|
||||
exports := Exports{}
|
||||
for name, _ := range llm.List() {
|
||||
exports.LLMList = append(exports.LLMList, name)
|
||||
}
|
||||
|
||||
exportFile := filepath.Join("lib", "ai.ts")
|
||||
var tpl *template.Template
|
||||
var err error
|
||||
if tpl, err = template.New(exportFile).Parse(aiTS); err == nil {
|
||||
buf := bytes.NewBuffer(make([]byte, 0))
|
||||
if err = tpl.Execute(buf, exports); err == nil {
|
||||
err = u.WriteFileBytes(exportFile, buf.Bytes())
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
_ = u.WriteFile(filepath.Join("lib", "console.ts"), consoleTS)
|
||||
_ = u.WriteFile(filepath.Join("lib", "file.ts"), fileTS)
|
||||
|
||||
return `import {` + strings.Join(exports.LLMList, ", ") + `} from './lib/ai'
|
||||
import console from './lib/console'
|
||||
import file from './lib/file'`, nil
|
||||
}
|
81
js/lib/ai.ts
Normal file
81
js/lib/ai.ts
Normal file
@ -0,0 +1,81 @@
|
||||
// just for develop
|
||||
|
||||
{{range .LLMList}}
|
||||
let {{.}}: LLM
|
||||
{{end}}
|
||||
|
||||
export default {
|
||||
{{range .LLMList}}
|
||||
{{.}},
|
||||
{{end}}
|
||||
}
|
||||
|
||||
interface ChatModelConfig {
|
||||
model: string
|
||||
ratio: number
|
||||
maxTokens: number
|
||||
temperature: number
|
||||
topP: number
|
||||
tools: Object
|
||||
}
|
||||
|
||||
interface ChatResult {
|
||||
result: string
|
||||
askTokens: number
|
||||
answerTokens: number
|
||||
totalTokens: number
|
||||
error: string
|
||||
}
|
||||
|
||||
interface GCResult {
|
||||
result: string
|
||||
preview: string
|
||||
results: Array<string>
|
||||
previews: Array<string>
|
||||
error: string
|
||||
}
|
||||
|
||||
interface Support {
|
||||
ask: boolean
|
||||
askWithImage: boolean
|
||||
askWithVideo: boolean
|
||||
askWithCodeInterpreter: boolean
|
||||
askWithWebSearch: boolean
|
||||
makeImage: boolean
|
||||
makeVideo: boolean
|
||||
models: Array<string>
|
||||
}
|
||||
|
||||
interface LLM {
|
||||
ask(messages: any, config?: ChatModelConfig, callback?: (answer: string) => void): ChatResult
|
||||
|
||||
fastAsk(messages: any, callback?: (answer: string) => void): ChatResult
|
||||
|
||||
longAsk(messages: any, callback?: (answer: string) => void): ChatResult
|
||||
|
||||
batterAsk(messages: any, callback?: (answer: string) => void): ChatResult
|
||||
|
||||
bestAsk(messages: any, callback?: (answer: string) => void): ChatResult
|
||||
|
||||
multiAsk(messages: any, callback?: (answer: string) => void): ChatResult
|
||||
|
||||
bestMultiAsk(messages: any, callback?: (answer: string) => void): ChatResult
|
||||
|
||||
codeInterpreterAsk(messages: any, callback?: (answer: string) => void): ChatResult
|
||||
|
||||
webSearchAsk(messages: any, callback?: (answer: string) => void): ChatResult
|
||||
|
||||
makeImage(model: string, prompt: string, size?: string, refImage?: string): GCResult
|
||||
|
||||
fastMakeImage(prompt: string, size?: string, refImage?: string): GCResult
|
||||
|
||||
bestMakeImage(prompt: string, size?: string, refImage?: string): GCResult
|
||||
|
||||
makeVideo(arg2: string, arg3: string, arg4: string, arg5: string): GCResult
|
||||
|
||||
fastMakeVideo(prompt: string, size?: string, refImage?: string): GCResult
|
||||
|
||||
bestMakeVideo(prompt: string, size?: string, refImage?: string): GCResult
|
||||
|
||||
support: Support
|
||||
}
|
37
js/lib/console.ts
Normal file
37
js/lib/console.ts
Normal file
@ -0,0 +1,37 @@
|
||||
// just for develop
|
||||
|
||||
export default {
|
||||
print,
|
||||
println,
|
||||
log,
|
||||
debug,
|
||||
info,
|
||||
warn,
|
||||
error,
|
||||
input,
|
||||
}
|
||||
|
||||
function print(...data: any[]): void {
|
||||
}
|
||||
|
||||
function println(...data: any[]): void {
|
||||
}
|
||||
|
||||
function log(...data: any[]): void {
|
||||
}
|
||||
|
||||
function debug(...data: any[]): void {
|
||||
}
|
||||
|
||||
function info(...data: any[]): void {
|
||||
}
|
||||
|
||||
function warn(...data: any[]): void {
|
||||
}
|
||||
|
||||
function error(...data: any[]): void {
|
||||
}
|
||||
|
||||
function input(...data: any[]): string {
|
||||
return ''
|
||||
}
|
32
js/lib/file.ts
Normal file
32
js/lib/file.ts
Normal file
@ -0,0 +1,32 @@
|
||||
// just for develop
|
||||
|
||||
export default {
|
||||
read,
|
||||
write,
|
||||
dir,
|
||||
stat
|
||||
}
|
||||
|
||||
function read(filename: string): string {
|
||||
return ''
|
||||
}
|
||||
|
||||
function write(filename: string, data: any): void {
|
||||
}
|
||||
|
||||
|
||||
function dir(filename: string): Array<FileInfo> {
|
||||
return null
|
||||
}
|
||||
|
||||
function stat(filename: string): FileInfo {
|
||||
return null
|
||||
}
|
||||
|
||||
interface FileInfo {
|
||||
Name: string
|
||||
FullName: string
|
||||
IsDir: boolean
|
||||
Size: number
|
||||
ModTime: number
|
||||
}
|
@ -1,15 +1,68 @@
|
||||
package ai
|
||||
package llm
|
||||
|
||||
import "apigo.cc/ai/agent"
|
||||
type ChatMessage struct {
|
||||
Role string
|
||||
Contents []ChatMessageContent
|
||||
}
|
||||
|
||||
type ChatMessage = agent.ChatMessage
|
||||
type ChatMessageContent = agent.ChatMessageContent
|
||||
type ChatModelConfig struct {
|
||||
Model string
|
||||
MaxTokens int
|
||||
Temperature float64
|
||||
TopP float64
|
||||
Tools map[string]any
|
||||
type ChatMessageContent struct {
|
||||
Type string // text, image, audio, video
|
||||
Content string
|
||||
}
|
||||
|
||||
type ChatConfig struct {
|
||||
defaultConfig *ChatConfig
|
||||
Model string
|
||||
Ratio float64
|
||||
MaxTokens int
|
||||
Temperature float64
|
||||
TopP float64
|
||||
Tools map[string]any
|
||||
}
|
||||
|
||||
func (chatConfig *ChatConfig) SetDefault(config *ChatConfig) {
|
||||
chatConfig.defaultConfig = config
|
||||
}
|
||||
|
||||
func (chatConfig *ChatConfig) GetModel() string {
|
||||
if chatConfig.Model == "" && chatConfig.defaultConfig != nil {
|
||||
return chatConfig.defaultConfig.Model
|
||||
}
|
||||
return chatConfig.Model
|
||||
}
|
||||
|
||||
func (chatConfig *ChatConfig) GetMaxTokens() int {
|
||||
if chatConfig.MaxTokens == 0 && chatConfig.defaultConfig != nil {
|
||||
return chatConfig.defaultConfig.MaxTokens
|
||||
}
|
||||
return chatConfig.MaxTokens
|
||||
}
|
||||
|
||||
func (chatConfig *ChatConfig) GetTemperature() float64 {
|
||||
if chatConfig.Temperature == 0 && chatConfig.defaultConfig != nil {
|
||||
return chatConfig.defaultConfig.Temperature
|
||||
}
|
||||
return chatConfig.Temperature
|
||||
}
|
||||
|
||||
func (chatConfig *ChatConfig) GetTopP() float64 {
|
||||
if chatConfig.TopP == 0 && chatConfig.defaultConfig != nil {
|
||||
return chatConfig.defaultConfig.TopP
|
||||
}
|
||||
return chatConfig.TopP
|
||||
}
|
||||
|
||||
func (chatConfig *ChatConfig) GetTools() map[string]any {
|
||||
if chatConfig.Tools == nil && chatConfig.defaultConfig != nil {
|
||||
return chatConfig.defaultConfig.Tools
|
||||
}
|
||||
return chatConfig.Tools
|
||||
}
|
||||
|
||||
type TokenUsage struct {
|
||||
AskTokens int64
|
||||
AnswerTokens int64
|
||||
TotalTokens int64
|
||||
}
|
||||
|
||||
type MessagesMaker struct {
|
33
llm/gc.go
Normal file
33
llm/gc.go
Normal file
@ -0,0 +1,33 @@
|
||||
package llm
|
||||
|
||||
type GCConfig struct {
|
||||
defaultConfig *GCConfig
|
||||
Model string
|
||||
Size string
|
||||
Ref string
|
||||
}
|
||||
|
||||
func (gcConfig *GCConfig) SetDefault(config *GCConfig) {
|
||||
gcConfig.defaultConfig = config
|
||||
}
|
||||
|
||||
func (gcConfig *GCConfig) GetModel() string {
|
||||
if gcConfig.Model == "" && gcConfig.defaultConfig != nil {
|
||||
return gcConfig.defaultConfig.Model
|
||||
}
|
||||
return gcConfig.Model
|
||||
}
|
||||
|
||||
func (gcConfig *GCConfig) GetSize() string {
|
||||
if gcConfig.Size == "" && gcConfig.defaultConfig != nil {
|
||||
return gcConfig.defaultConfig.Size
|
||||
}
|
||||
return gcConfig.Size
|
||||
}
|
||||
|
||||
func (gcConfig *GCConfig) GetRef() string {
|
||||
if gcConfig.Ref == "" && gcConfig.defaultConfig != nil {
|
||||
return gcConfig.defaultConfig.Ref
|
||||
}
|
||||
return gcConfig.Ref
|
||||
}
|
96
llm/llm.go
Normal file
96
llm/llm.go
Normal file
@ -0,0 +1,96 @@
|
||||
package llm
|
||||
|
||||
import "sync"
|
||||
|
||||
const (
|
||||
TypeText = "text"
|
||||
TypeImage = "image"
|
||||
TypeVideo = "video"
|
||||
RoleSystem = "system"
|
||||
RoleUser = "user"
|
||||
RoleAssistant = "assistant"
|
||||
RoleTool = "tool"
|
||||
ToolCodeInterpreter = "codeInterpreter"
|
||||
ToolWebSearch = "webSearch"
|
||||
)
|
||||
|
||||
type Support struct {
|
||||
Ask bool
|
||||
AskWithImage bool
|
||||
AskWithVideo bool
|
||||
AskWithCodeInterpreter bool
|
||||
AskWithWebSearch bool
|
||||
MakeImage bool
|
||||
MakeVideo bool
|
||||
Models []string
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Endpoint string
|
||||
ApiKey string
|
||||
ChatConfig ChatConfig
|
||||
GCConfig GCConfig
|
||||
}
|
||||
|
||||
type LLM interface {
|
||||
Support() Support
|
||||
Ask(messages []ChatMessage, config ChatConfig, callback func(answer string)) (string, TokenUsage, error)
|
||||
FastAsk(messages []ChatMessage, callback func(answer string)) (string, TokenUsage, error)
|
||||
LongAsk(messages []ChatMessage, callback func(answer string)) (string, TokenUsage, error)
|
||||
BatterAsk(messages []ChatMessage, callback func(answer string)) (string, TokenUsage, error)
|
||||
BestAsk(messages []ChatMessage, callback func(answer string)) (string, TokenUsage, error)
|
||||
MultiAsk(messages []ChatMessage, callback func(answer string)) (string, TokenUsage, error)
|
||||
BestMultiAsk(messages []ChatMessage, callback func(answer string)) (string, TokenUsage, error)
|
||||
CodeInterpreterAsk(messages []ChatMessage, callback func(answer string)) (string, TokenUsage, error)
|
||||
WebSearchAsk(messages []ChatMessage, callback func(answer string)) (string, TokenUsage, error)
|
||||
MakeImage(prompt string, config GCConfig) ([]string, error)
|
||||
FastMakeImage(prompt string, config GCConfig) ([]string, error)
|
||||
BestMakeImage(prompt string, config GCConfig) ([]string, error)
|
||||
MakeVideo(prompt string, config GCConfig) ([]string, []string, error)
|
||||
FastMakeVideo(prompt string, config GCConfig) ([]string, []string, error)
|
||||
BestMakeVideo(prompt string, config GCConfig) ([]string, []string, error)
|
||||
}
|
||||
|
||||
var llmMakers = map[string]func(Config) LLM{}
|
||||
var llmMakersLock = sync.RWMutex{}
|
||||
|
||||
var llms = map[string]LLM{}
|
||||
var llmsLock = sync.RWMutex{}
|
||||
|
||||
func Register(llmId string, maker func(Config) LLM) {
|
||||
llmMakersLock.Lock()
|
||||
llmMakers[llmId] = maker
|
||||
llmMakersLock.Unlock()
|
||||
}
|
||||
|
||||
func Create(name, llmId string, config Config) LLM {
|
||||
llmMakersLock.RLock()
|
||||
maker := llmMakers[llmId]
|
||||
llmMakersLock.RUnlock()
|
||||
|
||||
if maker != nil {
|
||||
llm := maker(config)
|
||||
llmsLock.Lock()
|
||||
llms[name] = llm
|
||||
llmsLock.Unlock()
|
||||
return llm
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func Get(name string) LLM {
|
||||
llmsLock.RLock()
|
||||
llm := llms[name]
|
||||
llmsLock.RUnlock()
|
||||
return llm
|
||||
}
|
||||
|
||||
func List() map[string]LLM {
|
||||
list := map[string]LLM{}
|
||||
llmsLock.RLock()
|
||||
for name, llm := range llms {
|
||||
list[name] = llm
|
||||
}
|
||||
llmsLock.RUnlock()
|
||||
return list
|
||||
}
|
165
llm/openai/chat.go
Normal file
165
llm/openai/chat.go
Normal file
@ -0,0 +1,165 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"apigo.cc/ai/ai/llm"
|
||||
"context"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/ssgo/log"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (lm *LLM) FastAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
|
||||
return lm.Ask(messages, llm.ChatConfig{
|
||||
Model: ModelGPT_4o_mini_2024_07_18,
|
||||
}, callback)
|
||||
}
|
||||
|
||||
func (lm *LLM) LongAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
|
||||
return lm.Ask(messages, llm.ChatConfig{
|
||||
Model: ModelGPT_4_32k_0613,
|
||||
}, callback)
|
||||
}
|
||||
|
||||
func (lm *LLM) BatterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
|
||||
return lm.Ask(messages, llm.ChatConfig{
|
||||
Model: ModelGPT_4_turbo,
|
||||
}, callback)
|
||||
}
|
||||
|
||||
func (lm *LLM) BestAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
|
||||
return lm.Ask(messages, llm.ChatConfig{
|
||||
Model: ModelGPT_4o_2024_08_06,
|
||||
}, callback)
|
||||
}
|
||||
|
||||
func (lm *LLM) MultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
|
||||
return lm.Ask(messages, llm.ChatConfig{
|
||||
Model: ModelGPT_4o_mini_2024_07_18,
|
||||
}, callback)
|
||||
}
|
||||
|
||||
func (lm *LLM) BestMultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
|
||||
return lm.Ask(messages, llm.ChatConfig{
|
||||
Model: ModelGPT_4o_2024_08_06,
|
||||
}, callback)
|
||||
}
|
||||
|
||||
func (lm *LLM) CodeInterpreterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
|
||||
return lm.Ask(messages, llm.ChatConfig{
|
||||
Model: ModelGPT_4o,
|
||||
Tools: map[string]any{llm.ToolCodeInterpreter: nil},
|
||||
}, callback)
|
||||
}
|
||||
|
||||
func (lm *LLM) WebSearchAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
|
||||
return lm.Ask(messages, llm.ChatConfig{
|
||||
Model: ModelGPT_4o_mini_2024_07_18,
|
||||
Tools: map[string]any{llm.ToolWebSearch: nil},
|
||||
}, callback)
|
||||
}
|
||||
|
||||
func (lm *LLM) Ask(messages []llm.ChatMessage, config llm.ChatConfig, callback func(answer string)) (string, llm.TokenUsage, error) {
|
||||
openaiConf := openai.DefaultConfig(lm.config.ApiKey)
|
||||
if lm.config.Endpoint != "" {
|
||||
openaiConf.BaseURL = lm.config.Endpoint
|
||||
}
|
||||
|
||||
config.SetDefault(&lm.config.ChatConfig)
|
||||
|
||||
agentMessages := make([]openai.ChatCompletionMessage, len(messages))
|
||||
for i, msg := range messages {
|
||||
var contents []openai.ChatMessagePart
|
||||
if msg.Contents != nil {
|
||||
contents = make([]openai.ChatMessagePart, len(msg.Contents))
|
||||
for j, inPart := range msg.Contents {
|
||||
part := openai.ChatMessagePart{}
|
||||
part.Type = TypeMap[inPart.Type]
|
||||
switch inPart.Type {
|
||||
case llm.TypeText:
|
||||
part.Text = inPart.Content
|
||||
case llm.TypeImage:
|
||||
part.ImageURL = &openai.ChatMessageImageURL{
|
||||
URL: inPart.Content,
|
||||
Detail: openai.ImageURLDetailAuto,
|
||||
}
|
||||
}
|
||||
contents[j] = part
|
||||
}
|
||||
}
|
||||
agentMessages[i] = openai.ChatCompletionMessage{
|
||||
Role: RoleMap[msg.Role],
|
||||
MultiContent: contents,
|
||||
}
|
||||
}
|
||||
|
||||
opt := openai.ChatCompletionRequest{
|
||||
Model: config.GetModel(),
|
||||
Messages: agentMessages,
|
||||
MaxTokens: config.GetMaxTokens(),
|
||||
Temperature: float32(config.GetTemperature()),
|
||||
TopP: float32(config.GetTopP()),
|
||||
StreamOptions: &openai.StreamOptions{
|
||||
IncludeUsage: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name := range config.GetTools() {
|
||||
switch name {
|
||||
case llm.ToolCodeInterpreter:
|
||||
opt.Tools = append(opt.Tools, openai.Tool{Type: "code_interpreter"})
|
||||
case llm.ToolWebSearch:
|
||||
}
|
||||
}
|
||||
|
||||
c := openai.NewClientWithConfig(openaiConf)
|
||||
if callback != nil {
|
||||
opt.Stream = true
|
||||
r, err := c.CreateChatCompletionStream(context.Background(), opt)
|
||||
if err == nil {
|
||||
results := make([]string, 0)
|
||||
usage := llm.TokenUsage{}
|
||||
for {
|
||||
if r2, err := r.Recv(); err == nil {
|
||||
if r2.Choices != nil {
|
||||
for _, ch := range r2.Choices {
|
||||
text := ch.Delta.Content
|
||||
callback(text)
|
||||
results = append(results, text)
|
||||
}
|
||||
}
|
||||
if r2.Usage != nil {
|
||||
usage.AskTokens += int64(r2.Usage.PromptTokens)
|
||||
usage.AnswerTokens += int64(r2.Usage.CompletionTokens)
|
||||
usage.TotalTokens += int64(r2.Usage.TotalTokens)
|
||||
}
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
_ = r.Close()
|
||||
return strings.Join(results, ""), usage, nil
|
||||
} else {
|
||||
log.DefaultLogger.Error(err.Error())
|
||||
return "", llm.TokenUsage{}, err
|
||||
}
|
||||
} else {
|
||||
r, err := c.CreateChatCompletion(context.Background(), opt)
|
||||
|
||||
if err == nil {
|
||||
results := make([]string, 0)
|
||||
if r.Choices != nil {
|
||||
for _, ch := range r.Choices {
|
||||
results = append(results, ch.Message.Content)
|
||||
}
|
||||
}
|
||||
return strings.Join(results, ""), llm.TokenUsage{
|
||||
AskTokens: int64(r.Usage.PromptTokens),
|
||||
AnswerTokens: int64(r.Usage.CompletionTokens),
|
||||
TotalTokens: int64(r.Usage.TotalTokens),
|
||||
}, nil
|
||||
} else {
|
||||
//fmt.Println(u.BMagenta(err.Error()), u.BMagenta(u.JsonP(r)))
|
||||
return "", llm.TokenUsage{}, err
|
||||
}
|
||||
}
|
||||
}
|
81
llm/openai/config.go
Normal file
81
llm/openai/config.go
Normal file
@ -0,0 +1,81 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"apigo.cc/ai/ai/llm"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type LLM struct {
|
||||
config llm.Config
|
||||
}
|
||||
|
||||
var TypeMap = map[string]openai.ChatMessagePartType{
|
||||
llm.TypeText: openai.ChatMessagePartTypeText,
|
||||
llm.TypeImage: openai.ChatMessagePartTypeImageURL,
|
||||
//llm.TypeVideo: "video_url",
|
||||
}
|
||||
var RoleMap = map[string]string{
|
||||
llm.RoleSystem: openai.ChatMessageRoleSystem,
|
||||
llm.RoleUser: openai.ChatMessageRoleUser,
|
||||
llm.RoleAssistant: openai.ChatMessageRoleAssistant,
|
||||
llm.RoleTool: openai.ChatMessageRoleTool,
|
||||
}
|
||||
|
||||
const (
|
||||
ModelGPT_4_32k_0613 = "gpt-4-32k-0613"
|
||||
ModelGPT_4_32k_0314 = "gpt-4-32k-0314"
|
||||
ModelGPT_4_32k = "gpt-4-32k"
|
||||
ModelGPT_4_0613 = "gpt-4-0613"
|
||||
ModelGPT_4_0314 = "gpt-4-0314"
|
||||
ModelGPT_4o = "gpt-4o"
|
||||
ModelGPT_4o_2024_05_13 = "gpt-4o-2024-05-13"
|
||||
ModelGPT_4o_2024_08_06 = "gpt-4o-2024-08-06"
|
||||
ModelGPT_4o_mini = "gpt-4o-mini"
|
||||
ModelGPT_4o_mini_2024_07_18 = "gpt-4o-mini-2024-07-18"
|
||||
ModelGPT_4_turbo = "gpt-4-turbo"
|
||||
ModelGPT_4_turbo_2024_04_09 = "gpt-4-turbo-2024-04-09"
|
||||
ModelGPT_4_0125_preview = "gpt-4-0125-preview"
|
||||
ModelGPT_4_1106_preview = "gpt-4-1106-preview"
|
||||
ModelGPT_4_turbo_preview = "gpt-4-turbo-preview"
|
||||
ModelGPT_4_vision_preview = "gpt-4-vision-preview"
|
||||
ModelGPT_4 = "gpt-4"
|
||||
ModelGPT_3_5_turbo_0125 = "gpt-3.5-turbo-0125"
|
||||
ModelGPT_3_5_turbo_1106 = "gpt-3.5-turbo-1106"
|
||||
ModelGPT_3_5_turbo_0613 = "gpt-3.5-turbo-0613"
|
||||
ModelGPT_3_5_turbo_0301 = "gpt-3.5-turbo-0301"
|
||||
ModelGPT_3_5_turbo_16k = "gpt-3.5-turbo-16k"
|
||||
ModelGPT_3_5_turbo_16k_0613 = "gpt-3.5-turbo-16k-0613"
|
||||
ModelGPT_3_5_turbo = "gpt-3.5-turbo"
|
||||
ModelGPT_3_5_turbo_instruct = "gpt-3.5-turbo-instruct"
|
||||
ModelDavinci_002 = "davinci-002"
|
||||
ModelCurie = "curie"
|
||||
ModelCurie_002 = "curie-002"
|
||||
ModelAda_002 = "ada-002"
|
||||
ModelBabbage_002 = "babbage-002"
|
||||
ModelCode_davinci_002 = "code-davinci-002"
|
||||
ModelCode_cushman_001 = "code-cushman-001"
|
||||
ModelCode_davinci_001 = "code-davinci-001"
|
||||
ModelDallE2Std = "dall-e-2"
|
||||
ModelDallE2HD = "dall-e-2-hd"
|
||||
ModelDallE3Std = "dall-e-3"
|
||||
ModelDallE3HD = "dall-e-3-hd"
|
||||
)
|
||||
|
||||
func (ag *LLM) Support() llm.Support {
|
||||
return llm.Support{
|
||||
Ask: true,
|
||||
AskWithImage: true,
|
||||
AskWithVideo: false,
|
||||
AskWithCodeInterpreter: true,
|
||||
AskWithWebSearch: false,
|
||||
MakeImage: true,
|
||||
MakeVideo: false,
|
||||
Models: []string{ModelGPT_4_32k_0613, ModelGPT_4_32k_0314, ModelGPT_4_32k, ModelGPT_4_0613, ModelGPT_4_0314, ModelGPT_4o, ModelGPT_4o_2024_05_13, ModelGPT_4o_2024_08_06, ModelGPT_4o_mini, ModelGPT_4o_mini_2024_07_18, ModelGPT_4_turbo, ModelGPT_4_turbo_2024_04_09, ModelGPT_4_0125_preview, ModelGPT_4_1106_preview, ModelGPT_4_turbo_preview, ModelGPT_4_vision_preview, ModelGPT_4, ModelGPT_3_5_turbo_0125, ModelGPT_3_5_turbo_1106, ModelGPT_3_5_turbo_0613, ModelGPT_3_5_turbo_0301, ModelGPT_3_5_turbo_16k, ModelGPT_3_5_turbo_16k_0613, ModelGPT_3_5_turbo, ModelGPT_3_5_turbo_instruct, ModelDavinci_002, ModelCurie, ModelCurie_002, ModelAda_002, ModelBabbage_002, ModelCode_davinci_002, ModelCode_cushman_001, ModelCode_davinci_001, ModelDallE2Std, ModelDallE2HD, ModelDallE3Std, ModelDallE3HD},
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
llm.Register("openai", func(config llm.Config) llm.LLM {
|
||||
return &LLM{config: config}
|
||||
})
|
||||
}
|
73
llm/openai/gc.go
Normal file
73
llm/openai/gc.go
Normal file
@ -0,0 +1,73 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"apigo.cc/ai/ai/llm"
|
||||
"context"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// func (lm *LLM) FastMakeImage(prompt, size, refImage string) ([]string, error) {
|
||||
// return lm.MakeImage(ModelDallE3Std, prompt, size, refImage)
|
||||
// }
|
||||
//
|
||||
// func (lm *LLM) BestMakeImage(prompt, size, refImage string) ([]string, error) {
|
||||
// return lm.MakeImage(ModelDallE3HD, prompt, size, refImage)
|
||||
// }
|
||||
//
|
||||
// func (lm *LLM) MakeImage(model, prompt, size, refImage string) ([]string, error) {
|
||||
func (lm *LLM) FastMakeImage(prompt string, config llm.GCConfig) ([]string, error) {
|
||||
config.Model = ModelDallE3Std
|
||||
return lm.MakeImage(prompt, config)
|
||||
}
|
||||
|
||||
func (lm *LLM) BestMakeImage(prompt string, config llm.GCConfig) ([]string, error) {
|
||||
config.Model = ModelDallE3HD
|
||||
return lm.MakeImage(prompt, config)
|
||||
}
|
||||
|
||||
func (lm *LLM) MakeImage(prompt string, config llm.GCConfig) ([]string, error) {
|
||||
openaiConf := openai.DefaultConfig(lm.config.ApiKey)
|
||||
if lm.config.Endpoint != "" {
|
||||
openaiConf.BaseURL = lm.config.Endpoint
|
||||
}
|
||||
c := openai.NewClientWithConfig(openaiConf)
|
||||
style := openai.CreateImageStyleVivid
|
||||
if (!strings.Contains(prompt, "vivid") || !strings.Contains(prompt, "生动的")) && (strings.Contains(prompt, "natural") || strings.Contains(prompt, "自然的")) {
|
||||
style = openai.CreateImageStyleNatural
|
||||
}
|
||||
quality := openai.CreateImageQualityStandard
|
||||
if strings.HasSuffix(config.Model, "-hd") {
|
||||
quality = openai.CreateImageQualityHD
|
||||
config.Model = config.Model[0 : len(config.Model)-3]
|
||||
}
|
||||
r, err := c.CreateImage(context.Background(), openai.ImageRequest{
|
||||
Prompt: prompt,
|
||||
Model: config.Model,
|
||||
Quality: quality,
|
||||
Size: config.Size,
|
||||
Style: style,
|
||||
ResponseFormat: openai.CreateImageResponseFormatURL,
|
||||
})
|
||||
if err == nil {
|
||||
results := make([]string, 0)
|
||||
for _, item := range r.Data {
|
||||
results = append(results, item.URL)
|
||||
}
|
||||
return results, nil
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (lm *LLM) FastMakeVideo(prompt string, config llm.GCConfig) ([]string, []string, error) {
|
||||
return lm.MakeVideo(prompt, config)
|
||||
}
|
||||
|
||||
func (lm *LLM) BestMakeVideo(prompt string, config llm.GCConfig) ([]string, []string, error) {
|
||||
return lm.MakeVideo(prompt, config)
|
||||
}
|
||||
|
||||
func (lm *LLM) MakeVideo(prompt string, config llm.GCConfig) ([]string, []string, error) {
|
||||
return nil, nil, nil
|
||||
}
|
137
llm/zhipu/chat.go
Normal file
137
llm/zhipu/chat.go
Normal file
@ -0,0 +1,137 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"apigo.cc/ai/ai/llm"
|
||||
"apigo.cc/ai/ai/llm/zhipu/zhipu"
|
||||
"context"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (lm *LLM) FastAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
|
||||
return lm.Ask(messages, llm.ChatConfig{
|
||||
Model: ModelGLM4Flash,
|
||||
}, callback)
|
||||
}
|
||||
|
||||
func (lm *LLM) LongAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
|
||||
return lm.Ask(messages, llm.ChatConfig{
|
||||
Model: ModelGLM4Long,
|
||||
}, callback)
|
||||
}
|
||||
|
||||
func (lm *LLM) BatterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
|
||||
return lm.Ask(messages, llm.ChatConfig{
|
||||
Model: ModelGLM4Plus,
|
||||
}, callback)
|
||||
}
|
||||
|
||||
func (lm *LLM) BestAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
|
||||
return lm.Ask(messages, llm.ChatConfig{
|
||||
Model: ModelGLM40520,
|
||||
}, callback)
|
||||
}
|
||||
|
||||
func (lm *LLM) MultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
|
||||
return lm.Ask(messages, llm.ChatConfig{
|
||||
Model: ModelGLM4VPlus,
|
||||
}, callback)
|
||||
}
|
||||
|
||||
func (lm *LLM) BestMultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
|
||||
return lm.Ask(messages, llm.ChatConfig{
|
||||
Model: ModelGLM4V,
|
||||
}, callback)
|
||||
}
|
||||
|
||||
func (lm *LLM) CodeInterpreterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
|
||||
return lm.Ask(messages, llm.ChatConfig{
|
||||
Model: ModelGLM4AllTools,
|
||||
Tools: map[string]any{llm.ToolCodeInterpreter: nil},
|
||||
}, callback)
|
||||
}
|
||||
|
||||
func (lm *LLM) WebSearchAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) {
|
||||
return lm.Ask(messages, llm.ChatConfig{
|
||||
Model: ModelGLM4AllTools,
|
||||
Tools: map[string]any{llm.ToolWebSearch: nil},
|
||||
}, callback)
|
||||
}
|
||||
|
||||
func (lm *LLM) Ask(messages []llm.ChatMessage, config llm.ChatConfig, callback func(answer string)) (string, llm.TokenUsage, error) {
|
||||
config.SetDefault(&lm.config.ChatConfig)
|
||||
c, err := zhipu.NewClient(zhipu.WithAPIKey(lm.config.ApiKey), zhipu.WithBaseURL(lm.config.Endpoint))
|
||||
if err != nil {
|
||||
return "", llm.TokenUsage{}, err
|
||||
}
|
||||
|
||||
cc := c.ChatCompletion(config.GetModel())
|
||||
for _, msg := range messages {
|
||||
var contents []zhipu.ChatCompletionMultiContent
|
||||
if msg.Contents != nil {
|
||||
contents = make([]zhipu.ChatCompletionMultiContent, len(msg.Contents))
|
||||
for j, inPart := range msg.Contents {
|
||||
part := zhipu.ChatCompletionMultiContent{}
|
||||
part.Type = NameMap[inPart.Type]
|
||||
switch inPart.Type {
|
||||
case llm.TypeText:
|
||||
part.Text = inPart.Content
|
||||
case llm.TypeImage:
|
||||
part.ImageURL = &zhipu.URLItem{URL: inPart.Content}
|
||||
case llm.TypeVideo:
|
||||
part.VideoURL = &zhipu.URLItem{URL: inPart.Content}
|
||||
}
|
||||
contents[j] = part
|
||||
}
|
||||
}
|
||||
cc.AddMessage(zhipu.ChatCompletionMultiMessage{
|
||||
Role: NameMap[msg.Role],
|
||||
Content: contents,
|
||||
})
|
||||
}
|
||||
|
||||
for name := range config.GetTools() {
|
||||
switch name {
|
||||
case llm.ToolCodeInterpreter:
|
||||
cc.AddTool(zhipu.ChatCompletionToolCodeInterpreter{})
|
||||
case llm.ToolWebSearch:
|
||||
cc.AddTool(zhipu.ChatCompletionToolWebBrowser{})
|
||||
}
|
||||
}
|
||||
|
||||
if config.GetMaxTokens() != 0 {
|
||||
cc.SetMaxTokens(config.GetMaxTokens())
|
||||
}
|
||||
if config.GetTemperature() != 0 {
|
||||
cc.SetTemperature(config.GetTemperature())
|
||||
}
|
||||
if config.GetTopP() != 0 {
|
||||
cc.SetTopP(config.GetTopP())
|
||||
}
|
||||
if callback != nil {
|
||||
cc.SetStreamHandler(func(r2 zhipu.ChatCompletionResponse) error {
|
||||
if r2.Choices != nil {
|
||||
for _, ch := range r2.Choices {
|
||||
text := ch.Delta.Content
|
||||
callback(text)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if r, err := cc.Do(context.Background()); err == nil {
|
||||
results := make([]string, 0)
|
||||
if r.Choices != nil {
|
||||
for _, ch := range r.Choices {
|
||||
results = append(results, ch.Message.Content)
|
||||
}
|
||||
}
|
||||
return strings.Join(results, ""), llm.TokenUsage{
|
||||
AskTokens: r.Usage.PromptTokens,
|
||||
AnswerTokens: r.Usage.CompletionTokens,
|
||||
TotalTokens: r.Usage.TotalTokens,
|
||||
}, nil
|
||||
} else {
|
||||
return "", llm.TokenUsage{}, err
|
||||
}
|
||||
}
|
60
llm/zhipu/config.go
Normal file
60
llm/zhipu/config.go
Normal file
@ -0,0 +1,60 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"apigo.cc/ai/ai/llm"
|
||||
"apigo.cc/ai/ai/llm/zhipu/zhipu"
|
||||
)
|
||||
|
||||
type LLM struct {
|
||||
config llm.Config
|
||||
}
|
||||
|
||||
var NameMap = map[string]string{
|
||||
llm.TypeText: zhipu.MultiContentTypeText,
|
||||
llm.TypeImage: zhipu.MultiContentTypeImageURL,
|
||||
llm.TypeVideo: zhipu.MultiContentTypeVideoURL,
|
||||
llm.RoleSystem: zhipu.RoleSystem,
|
||||
llm.RoleUser: zhipu.RoleUser,
|
||||
llm.RoleAssistant: zhipu.RoleAssistant,
|
||||
llm.RoleTool: zhipu.RoleTool,
|
||||
}
|
||||
|
||||
const (
|
||||
ModelGLM4Plus = "GLM-4-Plus"
|
||||
ModelGLM40520 = "GLM-4-0520"
|
||||
ModelGLM4Long = "GLM-4-Long"
|
||||
ModelGLM4AirX = "GLM-4-AirX"
|
||||
ModelGLM4Air = "GLM-4-Air"
|
||||
ModelGLM4Flash = "GLM-4-Flash"
|
||||
ModelGLM4AllTools = "GLM-4-AllTools"
|
||||
ModelGLM4 = "GLM-4"
|
||||
ModelGLM4VPlus = "GLM-4V-Plus"
|
||||
ModelGLM4V = "GLM-4V"
|
||||
ModelCogVideoX = "CogVideoX"
|
||||
ModelCogView3Plus = "CogView-3-Plus"
|
||||
ModelCogView3 = "CogView-3"
|
||||
ModelEmbedding3 = "Embedding-3"
|
||||
ModelEmbedding2 = "Embedding-2"
|
||||
ModelCharGLM3 = "CharGLM-3"
|
||||
ModelEmohaa = "Emohaa"
|
||||
ModelCodeGeeX4 = "CodeGeeX-4"
|
||||
)
|
||||
|
||||
func (lm *LLM) Support() llm.Support {
|
||||
return llm.Support{
|
||||
Ask: true,
|
||||
AskWithImage: true,
|
||||
AskWithVideo: true,
|
||||
AskWithCodeInterpreter: true,
|
||||
AskWithWebSearch: true,
|
||||
MakeImage: true,
|
||||
MakeVideo: true,
|
||||
Models: []string{ModelGLM4Plus, ModelGLM40520, ModelGLM4Long, ModelGLM4AirX, ModelGLM4Air, ModelGLM4Flash, ModelGLM4AllTools, ModelGLM4, ModelGLM4VPlus, ModelGLM4V, ModelCogVideoX, ModelCogView3Plus, ModelCogView3, ModelEmbedding3, ModelEmbedding2, ModelCharGLM3, ModelEmohaa, ModelCodeGeeX4},
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
llm.Register("zhipu", func(config llm.Config) llm.LLM {
|
||||
return &LLM{config: config}
|
||||
})
|
||||
}
|
88
llm/zhipu/gc.go
Normal file
88
llm/zhipu/gc.go
Normal file
@ -0,0 +1,88 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"apigo.cc/ai/ai/llm"
|
||||
"apigo.cc/ai/ai/llm/zhipu/zhipu"
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (lm *LLM) FastMakeImage(prompt string, config llm.GCConfig) ([]string, error) {
|
||||
config.Model = ModelCogView3Plus
|
||||
return lm.MakeImage(prompt, config)
|
||||
}
|
||||
|
||||
func (lm *LLM) BestMakeImage(prompt string, config llm.GCConfig) ([]string, error) {
|
||||
config.Model = ModelCogView3
|
||||
return lm.MakeImage(prompt, config)
|
||||
}
|
||||
|
||||
func (lm *LLM) MakeImage(prompt string, config llm.GCConfig) ([]string, error) {
|
||||
c, err := zhipu.NewClient(zhipu.WithAPIKey(lm.config.ApiKey), zhipu.WithBaseURL(lm.config.Endpoint))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cc := c.ImageGeneration(config.Model).SetPrompt(prompt)
|
||||
if config.Size != "" {
|
||||
cc.SetSize(config.Size)
|
||||
}
|
||||
|
||||
if r, err := cc.Do(context.Background()); err == nil {
|
||||
results := make([]string, 0)
|
||||
for _, item := range r.Data {
|
||||
results = append(results, item.URL)
|
||||
}
|
||||
return results, nil
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (lm *LLM) FastMakeVideo(prompt string, config llm.GCConfig) ([]string, []string, error) {
|
||||
config.Model = ModelCogVideoX
|
||||
return lm.MakeVideo(prompt, config)
|
||||
}
|
||||
|
||||
func (lm *LLM) BestMakeVideo(prompt string, config llm.GCConfig) ([]string, []string, error) {
|
||||
config.Model = ModelCogVideoX
|
||||
return lm.MakeVideo(prompt, config)
|
||||
}
|
||||
|
||||
func (lm *LLM) MakeVideo(prompt string, config llm.GCConfig) ([]string, []string, error) {
|
||||
c, err := zhipu.NewClient(zhipu.WithAPIKey(lm.config.ApiKey), zhipu.WithBaseURL(lm.config.Endpoint))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
cc := c.VideoGeneration(config.Model).SetPrompt(prompt)
|
||||
if config.Ref != "" {
|
||||
cc.SetImageURL(config.Ref)
|
||||
}
|
||||
|
||||
if resp, err := cc.Do(context.Background()); err == nil {
|
||||
for i := 0; i < 1200; i++ {
|
||||
r, err := c.AsyncResult(resp.ID).Do(context.Background())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if r.TaskStatus == zhipu.VideoGenerationTaskStatusSuccess {
|
||||
covers := make([]string, 0)
|
||||
results := make([]string, 0)
|
||||
for _, item := range r.VideoResult {
|
||||
results = append(results, item.URL)
|
||||
covers = append(covers, item.CoverImageURL)
|
||||
}
|
||||
return results, covers, nil
|
||||
}
|
||||
if r.TaskStatus == zhipu.VideoGenerationTaskStatusFail {
|
||||
return nil, nil, errors.New("fail on task " + resp.ID)
|
||||
}
|
||||
time.Sleep(3 * time.Second)
|
||||
}
|
||||
return nil, nil, errors.New("timeout on task " + resp.ID)
|
||||
} else {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
108
llm/zhipu/zhipu/CHANGELOG.md
Normal file
108
llm/zhipu/zhipu/CHANGELOG.md
Normal file
@ -0,0 +1,108 @@
|
||||
# Changelog
|
||||
All notable changes to this project will be documented in this file. See [conventional commits](https://www.conventionalcommits.org/) for commit guidelines.
|
||||
|
||||
- - -
|
||||
## v0.1.2 - 2024-08-15
|
||||
#### Bug Fixes
|
||||
- add FinishReasonStopSequence - (01b4201) - GUO YANKE
|
||||
#### Documentation
|
||||
- update README.md [skip ci] - (e48a88b) - GUO YANKE
|
||||
#### Features
|
||||
- add videos/generations - (7261999) - GUO YANKE
|
||||
#### Miscellaneous Chores
|
||||
- relaxing go version to 1.18 - (6acc17c) - GUO YANKE
|
||||
|
||||
- - -
|
||||
|
||||
## v0.1.1 - 2024-07-17
|
||||
#### Documentation
|
||||
- update README.md [skip ci] - (695432a) - GUO YANKE
|
||||
#### Features
|
||||
- add support for GLM-4-AllTools - (9627a36) - GUO YANKE
|
||||
|
||||
- - -
|
||||
|
||||
## v0.1.0 - 2024-06-28
|
||||
#### Bug Fixes
|
||||
- rename client function for batch list - (40ac05f) - GUO YANKE
|
||||
#### Documentation
|
||||
- update README.md [skip ci] - (6ce5754) - GUO YANKE
|
||||
#### Features
|
||||
- add knowledge capacity service - (4ce62b3) - GUO YANKE
|
||||
#### Refactoring
|
||||
- update batch service - (b92d438) - GUO YANKE
|
||||
- update chat completion service - (19dd77f) - GUO YANKE
|
||||
- update embedding service - (c1bbc2d) - GUO YANKE
|
||||
- update file services - (7ef4d87) - GUO YANKE
|
||||
- update fine tune services, using APIError - (15aed88) - GUO YANKE
|
||||
- update fine tune services - (664523b) - GUO YANKE
|
||||
- update image generation service - (a18e028) - GUO YANKE
|
||||
- update knowledge services - (c7bfb73) - GUO YANKE
|
||||
|
||||
- - -
|
||||
|
||||
## v0.0.6 - 2024-06-28
|
||||
#### Features
|
||||
- add batch support for result reader - (c062095) - GUO YANKE
|
||||
- add fine tune services - (f172f51) - GUO YANKE
|
||||
- add knowledge service - (09792b5) - GUO YANKE
|
||||
|
||||
- - -
|
||||
|
||||
## v0.0.5 - 2024-06-28
|
||||
#### Bug Fixes
|
||||
- api error parsing - (60a17f4) - GUO YANKE
|
||||
#### Features
|
||||
- add batch service - (389aec3) - GUO YANKE
|
||||
- add batch support for chat completions, image generations and embeddings - (c017ffd) - GUO YANKE
|
||||
- add file edit/get/delete service - (8a4d309) - GUO YANKE
|
||||
- add file create serivce - (6d2140b) - GUO YANKE
|
||||
|
||||
- - -
|
||||
|
||||
## v0.0.4 - 2024-06-26
|
||||
#### Bug Fixes
|
||||
- remove Client.R(), hide resty for future removal - (dc2a4ca) - GUO YANKE
|
||||
#### Features
|
||||
- add meta support for charglm - (fdd20e7) - GUO YANKE
|
||||
- add client option to custom http client - (c62d6a9) - GUO YANKE
|
||||
|
||||
- - -
|
||||
|
||||
## v0.0.3 - 2024-06-26
|
||||
#### Features
|
||||
- add image generation service - (9f3f54f) - GUO YANKE
|
||||
- add support for vision models - (2dcd82a) - GUO YANKE
|
||||
- add embedding service - (f57806a) - GUO YANKE
|
||||
|
||||
- - -
|
||||
|
||||
## v0.0.2 - 2024-06-26
|
||||
#### Bug Fixes
|
||||
- **(deps)** update golang-jwt/jwt to v5 - (2f76a57) - GUO YANKE
|
||||
#### Features
|
||||
- add constants for roles - (3d08a72) - GUO YANKE
|
||||
|
||||
- - -
|
||||
|
||||
## v0.0.1 - 2024-06-26
|
||||
#### Bug Fixes
|
||||
- add json tag "omitempty" to various types - (bf81097) - GUO YANKE
|
||||
#### Continuous Integration
|
||||
- add github action workflows for testing - (5a64987) - GUO YANKE
|
||||
#### Documentation
|
||||
- update README.md [skip ci] - (d504f57) - GUO YANKE
|
||||
#### Features
|
||||
- add chat completion in stream mode - (130fe1d) - GUO YANKE
|
||||
- add chat completion in non-stream mode - (2326e37) - GUO YANKE
|
||||
- support debug option while creating client - (0f104d8) - GUO YANKE
|
||||
- add APIError and APIErrorResponse - (1886d85) - GUO YANKE
|
||||
- add client struct - (710d8e8) - GUO YANKE
|
||||
#### Refactoring
|
||||
- change signature of Client#createJWT since there is no reason to fail - (f0d7887) - GUO YANKE
|
||||
#### Tests
|
||||
- add client_test.go - (a3fc217) - GUO YANKE
|
||||
|
||||
- - -
|
||||
|
||||
Changelog generated by [cocogitto](https://github.com/cocogitto/cocogitto).
|
21
llm/zhipu/zhipu/LICENSE
Normal file
21
llm/zhipu/zhipu/LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 Yanke G.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
280
llm/zhipu/zhipu/README.md
Normal file
280
llm/zhipu/zhipu/README.md
Normal file
@ -0,0 +1,280 @@
|
||||
# zhipu
|
||||
|
||||
[![Go Reference](https://pkg.go.dev/badge/github.com/yankeguo/zhipu.svg)](https://pkg.go.dev/github.com/yankeguo/zhipu)
|
||||
[![go](https://github.com/yankeguo/zhipu/actions/workflows/go.yml/badge.svg)](https://github.com/yankeguo/zhipu/actions/workflows/go.yml)
|
||||
|
||||
[中文文档](README.zh.md)
|
||||
|
||||
A 3rd-Party Golang Client Library for Zhipu AI Platform
|
||||
|
||||
## Usage
|
||||
|
||||
### Install the package
|
||||
|
||||
```bash
|
||||
go get -u github.com/yankeguo/zhipu
|
||||
```
|
||||
|
||||
### Create a client
|
||||
|
||||
```go
|
||||
// this will use environment variables ZHIPUAI_API_KEY
|
||||
client, err := zhipu.NewClient()
|
||||
// or you can specify the API key
|
||||
client, err = zhipu.NewClient(zhipu.WithAPIKey("your api key"))
|
||||
```
|
||||
|
||||
### Use the client
|
||||
|
||||
**ChatCompletion**
|
||||
|
||||
```go
|
||||
service := client.ChatCompletion("glm-4-flash").
|
||||
AddMessage(zhipu.ChatCompletionMessage{
|
||||
Role: "user",
|
||||
Content: "你好",
|
||||
})
|
||||
|
||||
res, err := service.Do(context.Background())
|
||||
|
||||
if err != nil {
|
||||
zhipu.GetAPIErrorCode(err) // get the API error code
|
||||
} else {
|
||||
println(res.Choices[0].Message.Content)
|
||||
}
|
||||
```
|
||||
|
||||
**ChatCompletion (Stream)**
|
||||
|
||||
```go
|
||||
service := client.ChatCompletion("glm-4-flash").
|
||||
AddMessage(zhipu.ChatCompletionMessage{
|
||||
Role: "user",
|
||||
Content: "你好",
|
||||
}).SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
|
||||
println(chunk.Choices[0].Delta.Content)
|
||||
return nil
|
||||
})
|
||||
|
||||
res, err := service.Do(context.Background())
|
||||
|
||||
if err != nil {
|
||||
zhipu.GetAPIErrorCode(err) // get the API error code
|
||||
} else {
|
||||
// this package will combine the stream chunks and build a final result mimicking the non-streaming API
|
||||
println(res.Choices[0].Message.Content)
|
||||
}
|
||||
```
|
||||
|
||||
**ChatCompletion (Stream with GLM-4-AllTools)**
|
||||
|
||||
```go
|
||||
// CodeInterpreter
|
||||
s := client.ChatCompletion("GLM-4-AllTools")
|
||||
s.AddMessage(zhipu.ChatCompletionMultiMessage{
|
||||
Role: "user",
|
||||
Content: []zhipu.ChatCompletionMultiContent{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "计算[5,10,20,700,99,310,978,100]的平均值和方差。",
|
||||
},
|
||||
},
|
||||
})
|
||||
s.AddTool(zhipu.ChatCompletionToolCodeInterpreter{
|
||||
Sandbox: zhipu.Ptr(CodeInterpreterSandboxAuto),
|
||||
})
|
||||
s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
|
||||
for _, c := range chunk.Choices {
|
||||
for _, tc := range c.Delta.ToolCalls {
|
||||
if tc.Type == ToolTypeCodeInterpreter && tc.CodeInterpreter != nil {
|
||||
if tc.CodeInterpreter.Input != "" {
|
||||
// DO SOMETHING
|
||||
}
|
||||
if len(tc.CodeInterpreter.Outputs) > 0 {
|
||||
// DO SOMETHING
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// WebBrowser
|
||||
// CAUTION: NOT 'WebSearch'
|
||||
s := client.ChatCompletion("GLM-4-AllTools")
|
||||
s.AddMessage(zhipu.ChatCompletionMultiMessage{
|
||||
Role: "user",
|
||||
Content: []zhipu.ChatCompletionMultiContent{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "搜索下本周深圳天气如何",
|
||||
},
|
||||
},
|
||||
})
|
||||
s.AddTool(zhipu.ChatCompletionToolWebBrowser{})
|
||||
s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
|
||||
for _, c := range chunk.Choices {
|
||||
for _, tc := range c.Delta.ToolCalls {
|
||||
if tc.Type == ToolTypeWebBrowser && tc.WebBrowser != nil {
|
||||
if tc.WebBrowser.Input != "" {
|
||||
// DO SOMETHING
|
||||
}
|
||||
if len(tc.WebBrowser.Outputs) > 0 {
|
||||
// DO SOMETHING
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
s.Do(context.Background())
|
||||
|
||||
// DrawingTool
|
||||
s := client.ChatCompletion("GLM-4-AllTools")
|
||||
s.AddMessage(zhipu.ChatCompletionMultiMessage{
|
||||
Role: "user",
|
||||
Content: []zhipu.ChatCompletionMultiContent{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "画一个正弦函数图像",
|
||||
},
|
||||
},
|
||||
})
|
||||
s.AddTool(zhipu.ChatCompletionToolDrawingTool{})
|
||||
s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
|
||||
for _, c := range chunk.Choices {
|
||||
for _, tc := range c.Delta.ToolCalls {
|
||||
if tc.Type == ToolTypeDrawingTool && tc.DrawingTool != nil {
|
||||
if tc.DrawingTool.Input != "" {
|
||||
// DO SOMETHING
|
||||
}
|
||||
if len(tc.DrawingTool.Outputs) > 0 {
|
||||
// DO SOMETHING
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
s.Do(context.Background())
|
||||
```
|
||||
|
||||
**Embedding**
|
||||
|
||||
```go
|
||||
service := client.Embedding("embedding-v2").SetInput("你好呀")
|
||||
service.Do(context.Background())
|
||||
```
|
||||
|
||||
**Image Generation**
|
||||
|
||||
```go
|
||||
service := client.ImageGeneration("cogview-3").SetPrompt("一只可爱的小猫咪")
|
||||
service.Do(context.Background())
|
||||
```
|
||||
|
||||
**Video Generation**
|
||||
|
||||
```go
|
||||
service := client.VideoGeneration("cogvideox").SetPrompt("一只可爱的小猫咪")
|
||||
resp, err := service.Do(context.Background())
|
||||
|
||||
for {
|
||||
result, err := client.AsyncResult(resp.ID).Do(context.Background())
|
||||
|
||||
if result.TaskStatus == zhipu.VideoGenerationTaskStatusSuccess {
|
||||
_ = result.VideoResult[0].URL
|
||||
_ = result.VideoResult[0].CoverImageURL
|
||||
break
|
||||
}
|
||||
|
||||
if result.TaskStatus != zhipu.VideoGenerationTaskStatusProcessing {
|
||||
break
|
||||
}
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
```
|
||||
|
||||
**Upload File (Retrieval)**
|
||||
|
||||
```go
|
||||
service := client.FileCreate(zhipu.FilePurposeRetrieval)
|
||||
service.SetLocalFile(filepath.Join("testdata", "test-file.txt"))
|
||||
service.SetKnowledgeID("your-knowledge-id")
|
||||
|
||||
service.Do(context.Background())
|
||||
```
|
||||
|
||||
**Upload File (Fine-Tune)**
|
||||
|
||||
```go
|
||||
service := client.FileCreate(zhipu.FilePurposeFineTune)
|
||||
service.SetLocalFile(filepath.Join("testdata", "test-file.jsonl"))
|
||||
service.Do(context.Background())
|
||||
```
|
||||
|
||||
**Batch Create**
|
||||
|
||||
```go
|
||||
service := client.BatchCreate().
|
||||
SetInputFileID("fileid").
|
||||
SetCompletionWindow(zhipu.BatchCompletionWindow24h).
|
||||
SetEndpoint(BatchEndpointV4ChatCompletions)
|
||||
service.Do(context.Background())
|
||||
```
|
||||
|
||||
**Knowledge Base**
|
||||
|
||||
```go
|
||||
client.KnowledgeCreate("")
|
||||
client.KnowledgeEdit("")
|
||||
```
|
||||
|
||||
**Fine Tune**
|
||||
|
||||
```go
|
||||
client.FineTuneCreate("")
|
||||
```
|
||||
|
||||
### Batch Support
|
||||
|
||||
**Batch File Writer**
|
||||
|
||||
```go
|
||||
f, err := os.OpenFile("batch.jsonl", os.O_CREATE|os.O_WRONLY, 0644)
|
||||
|
||||
bw := zhipu.NewBatchFileWriter(f)
|
||||
|
||||
bw.Add("action_1", client.ChatCompletion("glm-4-flash").
|
||||
AddMessage(zhipu.ChatCompletionMessage{
|
||||
Role: "user",
|
||||
Content: "你好",
|
||||
}))
|
||||
bw.Add("action_2", client.Embedding("embedding-v2").SetInput("你好呀"))
|
||||
bw.Add("action_3", client.ImageGeneration("cogview-3").SetPrompt("一只可爱的小猫咪"))
|
||||
```
|
||||
|
||||
**Batch Result Reader**
|
||||
|
||||
```go
|
||||
br := zhipu.NewBatchResultReader[zhipu.ChatCompletionResponse](r)
|
||||
|
||||
for {
|
||||
var res zhipu.BatchResult[zhipu.ChatCompletionResponse]
|
||||
err := br.Read(&res)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Donation
|
||||
|
||||
Executing unit tests will actually call the ChatGLM API and consume my quota. Please donate and thank you for your support!
|
||||
|
||||
<img src="./wechat-donation.png" width="180"/>
|
||||
|
||||
## Credits
|
||||
|
||||
GUO YANKE, MIT License
|
278
llm/zhipu/zhipu/README.zh.md
Normal file
278
llm/zhipu/zhipu/README.zh.md
Normal file
@ -0,0 +1,278 @@
|
||||
# zhipu
|
||||
|
||||
[![Go Reference](https://pkg.go.dev/badge/github.com/yankeguo/zhipu.svg)](https://pkg.go.dev/github.com/yankeguo/zhipu)
|
||||
[![go](https://github.com/yankeguo/zhipu/actions/workflows/go.yml/badge.svg)](https://github.com/yankeguo/zhipu/actions/workflows/go.yml)
|
||||
|
||||
Zhipu AI 平台第三方 Golang 客户端库
|
||||
|
||||
## 用法
|
||||
|
||||
### 安装库
|
||||
|
||||
```bash
|
||||
go get -u github.com/yankeguo/zhipu
|
||||
```
|
||||
|
||||
### 创建客户端
|
||||
|
||||
```go
|
||||
// 默认使用环境变量 ZHIPUAI_API_KEY
|
||||
client, err := zhipu.NewClient()
|
||||
// 或者手动指定密钥
|
||||
client, err = zhipu.NewClient(zhipu.WithAPIKey("your api key"))
|
||||
```
|
||||
|
||||
### 使用客户端
|
||||
|
||||
**ChatCompletion(大语言模型)**
|
||||
|
||||
```go
|
||||
service := client.ChatCompletion("glm-4-flash").
|
||||
AddMessage(zhipu.ChatCompletionMessage{
|
||||
Role: "user",
|
||||
Content: "你好",
|
||||
})
|
||||
|
||||
res, err := service.Do(context.Background())
|
||||
|
||||
if err != nil {
|
||||
zhipu.GetAPIErrorCode(err) // get the API error code
|
||||
} else {
|
||||
println(res.Choices[0].Message.Content)
|
||||
}
|
||||
```
|
||||
|
||||
**ChatCompletion(流式调用大语言模型)**
|
||||
|
||||
```go
|
||||
service := client.ChatCompletion("glm-4-flash").
|
||||
AddMessage(zhipu.ChatCompletionMessage{
|
||||
Role: "user",
|
||||
Content: "你好",
|
||||
}).SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
|
||||
println(chunk.Choices[0].Delta.Content)
|
||||
return nil
|
||||
})
|
||||
|
||||
res, err := service.Do(context.Background())
|
||||
|
||||
if err != nil {
|
||||
zhipu.GetAPIErrorCode(err) // get the API error code
|
||||
} else {
|
||||
// this package will combine the stream chunks and build a final result mimicking the non-streaming API
|
||||
println(res.Choices[0].Message.Content)
|
||||
}
|
||||
```
|
||||
|
||||
**ChatCompletion(流式调用大语言工具模型GLM-4-AllTools)**
|
||||
|
||||
```go
|
||||
// CodeInterpreter
|
||||
s := client.ChatCompletion("GLM-4-AllTools")
|
||||
s.AddMessage(zhipu.ChatCompletionMultiMessage{
|
||||
Role: "user",
|
||||
Content: []zhipu.ChatCompletionMultiContent{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "计算[5,10,20,700,99,310,978,100]的平均值和方差。",
|
||||
},
|
||||
},
|
||||
})
|
||||
s.AddTool(zhipu.ChatCompletionToolCodeInterpreter{
|
||||
Sandbox: zhipu.Ptr(CodeInterpreterSandboxAuto),
|
||||
})
|
||||
s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
|
||||
for _, c := range chunk.Choices {
|
||||
for _, tc := range c.Delta.ToolCalls {
|
||||
if tc.Type == ToolTypeCodeInterpreter && tc.CodeInterpreter != nil {
|
||||
if tc.CodeInterpreter.Input != "" {
|
||||
// DO SOMETHING
|
||||
}
|
||||
if len(tc.CodeInterpreter.Outputs) > 0 {
|
||||
// DO SOMETHING
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// WebBrowser
|
||||
// CAUTION: NOT 'WebSearch'
|
||||
s := client.ChatCompletion("GLM-4-AllTools")
|
||||
s.AddMessage(zhipu.ChatCompletionMultiMessage{
|
||||
Role: "user",
|
||||
Content: []zhipu.ChatCompletionMultiContent{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "搜索下本周深圳天气如何",
|
||||
},
|
||||
},
|
||||
})
|
||||
s.AddTool(zhipu.ChatCompletionToolWebBrowser{})
|
||||
s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
|
||||
for _, c := range chunk.Choices {
|
||||
for _, tc := range c.Delta.ToolCalls {
|
||||
if tc.Type == ToolTypeWebBrowser && tc.WebBrowser != nil {
|
||||
if tc.WebBrowser.Input != "" {
|
||||
// DO SOMETHING
|
||||
}
|
||||
if len(tc.WebBrowser.Outputs) > 0 {
|
||||
// DO SOMETHING
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
s.Do(context.Background())
|
||||
|
||||
// DrawingTool
|
||||
s := client.ChatCompletion("GLM-4-AllTools")
|
||||
s.AddMessage(zhipu.ChatCompletionMultiMessage{
|
||||
Role: "user",
|
||||
Content: []zhipu.ChatCompletionMultiContent{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "画一个正弦函数图像",
|
||||
},
|
||||
},
|
||||
})
|
||||
s.AddTool(zhipu.ChatCompletionToolDrawingTool{})
|
||||
s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error {
|
||||
for _, c := range chunk.Choices {
|
||||
for _, tc := range c.Delta.ToolCalls {
|
||||
if tc.Type == ToolTypeDrawingTool && tc.DrawingTool != nil {
|
||||
if tc.DrawingTool.Input != "" {
|
||||
// DO SOMETHING
|
||||
}
|
||||
if len(tc.DrawingTool.Outputs) > 0 {
|
||||
// DO SOMETHING
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
s.Do(context.Background())
|
||||
```
|
||||
|
||||
**Embedding**
|
||||
|
||||
```go
|
||||
service := client.Embedding("embedding-v2").SetInput("你好呀")
|
||||
service.Do(context.Background())
|
||||
```
|
||||
|
||||
**ImageGeneration(图像生成)**
|
||||
|
||||
```go
|
||||
service := client.ImageGeneration("cogview-3").SetPrompt("一只可爱的小猫咪")
|
||||
service.Do(context.Background())
|
||||
```
|
||||
|
||||
**VideoGeneration(视频生成)**
|
||||
|
||||
```go
|
||||
service := client.VideoGeneration("cogvideox").SetPrompt("一只可爱的小猫咪")
|
||||
resp, err := service.Do(context.Background())
|
||||
|
||||
for {
|
||||
result, err := client.AsyncResult(resp.ID).Do(context.Background())
|
||||
|
||||
if result.TaskStatus == zhipu.VideoGenerationTaskStatusSuccess {
|
||||
_ = result.VideoResult[0].URL
|
||||
_ = result.VideoResult[0].CoverImageURL
|
||||
break
|
||||
}
|
||||
|
||||
if result.TaskStatus != zhipu.VideoGenerationTaskStatusProcessing {
|
||||
break
|
||||
}
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
```
|
||||
|
||||
**UploadFile(上传文件用于取回)**
|
||||
|
||||
```go
|
||||
service := client.FileCreate(zhipu.FilePurposeRetrieval)
|
||||
service.SetLocalFile(filepath.Join("testdata", "test-file.txt"))
|
||||
service.SetKnowledgeID("your-knowledge-id")
|
||||
|
||||
service.Do(context.Background())
|
||||
```
|
||||
|
||||
**UploadFile(上传文件用于微调)**
|
||||
|
||||
```go
|
||||
service := client.FileCreate(zhipu.FilePurposeFineTune)
|
||||
service.SetLocalFile(filepath.Join("testdata", "test-file.jsonl"))
|
||||
service.Do(context.Background())
|
||||
```
|
||||
|
||||
**BatchCreate(创建批量任务)**
|
||||
|
||||
```go
|
||||
service := client.BatchCreate().
|
||||
SetInputFileID("fileid").
|
||||
SetCompletionWindow(zhipu.BatchCompletionWindow24h).
|
||||
SetEndpoint(BatchEndpointV4ChatCompletions)
|
||||
service.Do(context.Background())
|
||||
```
|
||||
|
||||
**KnowledgeBase(知识库)**
|
||||
|
||||
```go
|
||||
client.KnowledgeCreate("")
|
||||
client.KnowledgeEdit("")
|
||||
```
|
||||
|
||||
**FineTune(微调)**
|
||||
|
||||
```go
|
||||
client.FineTuneCreate("")
|
||||
```
|
||||
|
||||
### 批量任务辅助工具
|
||||
|
||||
**批量任务文件创建**
|
||||
|
||||
```go
|
||||
f, err := os.OpenFile("batch.jsonl", os.O_CREATE|os.O_WRONLY, 0644)
|
||||
|
||||
bw := zhipu.NewBatchFileWriter(f)
|
||||
|
||||
bw.Add("action_1", client.ChatCompletion("glm-4-flash").
|
||||
AddMessage(zhipu.ChatCompletionMessage{
|
||||
Role: "user",
|
||||
Content: "你好",
|
||||
}))
|
||||
bw.Add("action_2", client.Embedding("embedding-v2").SetInput("你好呀"))
|
||||
bw.Add("action_3", client.ImageGeneration("cogview-3").SetPrompt("一只可爱的小猫咪"))
|
||||
```
|
||||
|
||||
**批量任务结果解析**
|
||||
|
||||
```go
|
||||
br := zhipu.NewBatchResultReader[zhipu.ChatCompletionResponse](r)
|
||||
|
||||
for {
|
||||
var res zhipu.BatchResult[zhipu.ChatCompletionResponse]
|
||||
err := br.Read(&res)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 赞助
|
||||
|
||||
执行单元测试会真实调用GLM接口,消耗我充值的额度,开发不易,请微信扫码捐赠,感谢您的支持!
|
||||
|
||||
<img src="./wechat-donation.png" width="180"/>
|
||||
|
||||
## 许可证
|
||||
|
||||
GUO YANKE, MIT License
|
63
llm/zhipu/zhipu/async_result.go
Normal file
63
llm/zhipu/zhipu/async_result.go
Normal file
@ -0,0 +1,63 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
// AsyncResultService creates a new async result get service
|
||||
type AsyncResultService struct {
|
||||
client *Client
|
||||
|
||||
id string
|
||||
}
|
||||
|
||||
// AsyncResultVideo is the video result of the AsyncResultService
|
||||
type AsyncResultVideo struct {
|
||||
URL string `json:"url"`
|
||||
CoverImageURL string `json:"cover_image_url"`
|
||||
}
|
||||
|
||||
// AsyncResultResponse is the response of the AsyncResultService
|
||||
type AsyncResultResponse struct {
|
||||
Model string `json:"model"`
|
||||
TaskStatus string `json:"task_status"`
|
||||
RequestID string `json:"request_id"`
|
||||
ID string `json:"id"`
|
||||
VideoResult []AsyncResultVideo `json:"video_result"`
|
||||
}
|
||||
|
||||
// NewAsyncResultService creates a new async result get service
|
||||
func NewAsyncResultService(client *Client) *AsyncResultService {
|
||||
return &AsyncResultService{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// SetID sets the id parameter
|
||||
func (s *AsyncResultService) SetID(id string) *AsyncResultService {
|
||||
s.id = id
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *AsyncResultService) Do(ctx context.Context) (res AsyncResultResponse, err error) {
|
||||
var (
|
||||
resp *resty.Response
|
||||
apiError APIErrorResponse
|
||||
)
|
||||
|
||||
if resp, err = s.client.request(ctx).
|
||||
SetResult(&res).
|
||||
SetError(&apiError).
|
||||
Get("async-result/" + s.id); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.IsError() {
|
||||
err = apiError
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
1
llm/zhipu/zhipu/async_result_test.go
Normal file
1
llm/zhipu/zhipu/async_result_test.go
Normal file
@ -0,0 +1 @@
|
||||
package zhipu
|
258
llm/zhipu/zhipu/batch.go
Normal file
258
llm/zhipu/zhipu/batch.go
Normal file
@ -0,0 +1,258 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
BatchEndpointV4ChatCompletions = "/v4/chat/completions"
|
||||
BatchEndpointV4ImagesGenerations = "/v4/images/generations"
|
||||
BatchEndpointV4Embeddings = "/v4/embeddings"
|
||||
BatchEndpointV4VideosGenerations = "/v4/videos/generations"
|
||||
|
||||
BatchCompletionWindow24h = "24h"
|
||||
)
|
||||
|
||||
// BatchRequestCounts represents the counts of the batch requests.
|
||||
type BatchRequestCounts struct {
|
||||
Total int64 `json:"total"`
|
||||
Completed int64 `json:"completed"`
|
||||
Failed int64 `json:"failed"`
|
||||
}
|
||||
|
||||
// BatchItem represents a batch item.
|
||||
type BatchItem struct {
|
||||
ID string `json:"id"`
|
||||
Object any `json:"object"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
InputFileID string `json:"input_file_id"`
|
||||
CompletionWindow string `json:"completion_window"`
|
||||
Status string `json:"status"`
|
||||
OutputFileID string `json:"output_file_id"`
|
||||
ErrorFileID string `json:"error_file_id"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
InProgressAt int64 `json:"in_progress_at"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
FinalizingAt int64 `json:"finalizing_at"`
|
||||
CompletedAt int64 `json:"completed_at"`
|
||||
FailedAt int64 `json:"failed_at"`
|
||||
ExpiredAt int64 `json:"expired_at"`
|
||||
CancellingAt int64 `json:"cancelling_at"`
|
||||
CancelledAt int64 `json:"cancelled_at"`
|
||||
RequestCounts BatchRequestCounts `json:"request_counts"`
|
||||
Metadata json.RawMessage `json:"metadata"`
|
||||
}
|
||||
|
||||
// BatchCreateService is a service to create a batch.
|
||||
type BatchCreateService struct {
|
||||
client *Client
|
||||
|
||||
inputFileID string
|
||||
endpoint string
|
||||
completionWindow string
|
||||
metadata any
|
||||
}
|
||||
|
||||
// NewBatchCreateService creates a new BatchCreateService.
|
||||
func NewBatchCreateService(client *Client) *BatchCreateService {
|
||||
return &BatchCreateService{client: client}
|
||||
}
|
||||
|
||||
// SetInputFileID sets the input file id for the batch.
|
||||
func (s *BatchCreateService) SetInputFileID(inputFileID string) *BatchCreateService {
|
||||
s.inputFileID = inputFileID
|
||||
return s
|
||||
}
|
||||
|
||||
// SetEndpoint sets the endpoint for the batch.
|
||||
func (s *BatchCreateService) SetEndpoint(endpoint string) *BatchCreateService {
|
||||
s.endpoint = endpoint
|
||||
return s
|
||||
}
|
||||
|
||||
// SetCompletionWindow sets the completion window for the batch.
|
||||
func (s *BatchCreateService) SetCompletionWindow(window string) *BatchCreateService {
|
||||
s.completionWindow = window
|
||||
return s
|
||||
}
|
||||
|
||||
// SetMetadata sets the metadata for the batch.
|
||||
func (s *BatchCreateService) SetMetadata(metadata any) *BatchCreateService {
|
||||
s.metadata = metadata
|
||||
return s
|
||||
}
|
||||
|
||||
// Do executes the batch create service.
|
||||
func (s *BatchCreateService) Do(ctx context.Context) (res BatchItem, err error) {
|
||||
var (
|
||||
resp *resty.Response
|
||||
apiError APIErrorResponse
|
||||
)
|
||||
|
||||
if resp, err = s.client.request(ctx).
|
||||
SetBody(M{
|
||||
"input_file_id": s.inputFileID,
|
||||
"endpoint": s.endpoint,
|
||||
"completion_window": s.completionWindow,
|
||||
"metadata": s.metadata,
|
||||
}).
|
||||
SetResult(&res).
|
||||
SetError(&apiError).
|
||||
Post("batches"); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.IsError() {
|
||||
err = apiError
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// BatchGetService is a service to get a batch.
|
||||
type BatchGetService struct {
|
||||
client *Client
|
||||
batchID string
|
||||
}
|
||||
|
||||
// BatchGetResponse represents the response of the batch get service.
|
||||
type BatchGetResponse = BatchItem
|
||||
|
||||
// NewBatchGetService creates a new BatchGetService.
|
||||
func NewBatchGetService(client *Client) *BatchGetService {
|
||||
return &BatchGetService{client: client}
|
||||
}
|
||||
|
||||
// SetBatchID sets the batch id for the batch get service.
|
||||
func (s *BatchGetService) SetBatchID(batchID string) *BatchGetService {
|
||||
s.batchID = batchID
|
||||
return s
|
||||
}
|
||||
|
||||
// Do executes the batch get service.
|
||||
func (s *BatchGetService) Do(ctx context.Context) (res BatchGetResponse, err error) {
|
||||
var (
|
||||
resp *resty.Response
|
||||
apiError APIErrorResponse
|
||||
)
|
||||
|
||||
if resp, err = s.client.request(ctx).
|
||||
SetPathParam("batch_id", s.batchID).
|
||||
SetResult(&res).
|
||||
SetError(&apiError).
|
||||
Get("batches/{batch_id}"); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.IsError() {
|
||||
err = apiError
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// BatchCancelService is a service to cancel a batch.
|
||||
type BatchCancelService struct {
|
||||
client *Client
|
||||
batchID string
|
||||
}
|
||||
|
||||
// NewBatchCancelService creates a new BatchCancelService.
|
||||
func NewBatchCancelService(client *Client) *BatchCancelService {
|
||||
return &BatchCancelService{client: client}
|
||||
}
|
||||
|
||||
// SetBatchID sets the batch id for the batch cancel service.
|
||||
func (s *BatchCancelService) SetBatchID(batchID string) *BatchCancelService {
|
||||
s.batchID = batchID
|
||||
return s
|
||||
}
|
||||
|
||||
// Do executes the batch cancel service.
|
||||
func (s *BatchCancelService) Do(ctx context.Context) (err error) {
|
||||
var (
|
||||
resp *resty.Response
|
||||
apiError APIErrorResponse
|
||||
)
|
||||
|
||||
if resp, err = s.client.request(ctx).
|
||||
SetPathParam("batch_id", s.batchID).
|
||||
SetBody(M{}).
|
||||
SetError(&apiError).
|
||||
Post("batches/{batch_id}/cancel"); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.IsError() {
|
||||
err = apiError
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// BatchListService is a service to list batches.
|
||||
type BatchListService struct {
|
||||
client *Client
|
||||
|
||||
after *string
|
||||
limit *int
|
||||
}
|
||||
|
||||
// BatchListResponse represents the response of the batch list service.
|
||||
type BatchListResponse struct {
|
||||
Object string `json:"object"`
|
||||
Data []BatchItem `json:"data"`
|
||||
FirstID string `json:"first_id"`
|
||||
LastID string `json:"last_id"`
|
||||
HasMore bool `json:"has_more"`
|
||||
}
|
||||
|
||||
// NewBatchListService creates a new BatchListService.
|
||||
func NewBatchListService(client *Client) *BatchListService {
|
||||
return &BatchListService{client: client}
|
||||
}
|
||||
|
||||
// SetAfter sets the after cursor for the batch list service.
|
||||
func (s *BatchListService) SetAfter(after string) *BatchListService {
|
||||
s.after = &after
|
||||
return s
|
||||
}
|
||||
|
||||
// SetLimit sets the limit for the batch list service.
|
||||
func (s *BatchListService) SetLimit(limit int) *BatchListService {
|
||||
s.limit = &limit
|
||||
return s
|
||||
}
|
||||
|
||||
// Do executes the batch list service.
|
||||
func (s *BatchListService) Do(ctx context.Context) (res BatchListResponse, err error) {
|
||||
var (
|
||||
resp *resty.Response
|
||||
apiError APIErrorResponse
|
||||
)
|
||||
|
||||
req := s.client.request(ctx)
|
||||
if s.after != nil {
|
||||
req.SetQueryParam("after", *s.after)
|
||||
}
|
||||
if s.limit != nil {
|
||||
req.SetQueryParam("limit", strconv.Itoa(*s.limit))
|
||||
}
|
||||
|
||||
if resp, err = req.
|
||||
SetResult(&res).
|
||||
SetError(&apiError).
|
||||
Get("batches"); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.IsError() {
|
||||
err = apiError
|
||||
}
|
||||
|
||||
return
|
||||
}
|
63
llm/zhipu/zhipu/batch_support.go
Normal file
63
llm/zhipu/zhipu/batch_support.go
Normal file
@ -0,0 +1,63 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
)
|
||||
|
||||
// BatchSupport is the interface for services with batch support.
|
||||
type BatchSupport interface {
|
||||
BatchMethod() string
|
||||
BatchURL() string
|
||||
BatchBody() any
|
||||
}
|
||||
|
||||
// BatchFileWriter is a writer for batch files.
|
||||
type BatchFileWriter struct {
|
||||
w io.Writer
|
||||
je *json.Encoder
|
||||
}
|
||||
|
||||
// NewBatchFileWriter creates a new BatchFileWriter.
|
||||
func NewBatchFileWriter(w io.Writer) *BatchFileWriter {
|
||||
return &BatchFileWriter{w: w, je: json.NewEncoder(w)}
|
||||
}
|
||||
|
||||
// Write writes a batch file.
|
||||
func (b *BatchFileWriter) Write(customID string, s BatchSupport) error {
|
||||
return b.je.Encode(M{
|
||||
"custom_id": customID,
|
||||
"method": s.BatchMethod(),
|
||||
"url": s.BatchURL(),
|
||||
"body": s.BatchBody(),
|
||||
})
|
||||
}
|
||||
|
||||
// BatchResultResponse is the response of a batch result.
|
||||
type BatchResultResponse[T any] struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
Body T `json:"body"`
|
||||
}
|
||||
|
||||
// BatchResult is the result of a batch.
|
||||
type BatchResult[T any] struct {
|
||||
ID string `json:"id"`
|
||||
CustomID string `json:"custom_id"`
|
||||
Response BatchResultResponse[T] `json:"response"`
|
||||
}
|
||||
|
||||
// BatchResultReader reads batch results.
|
||||
type BatchResultReader[T any] struct {
|
||||
r io.Reader
|
||||
jd *json.Decoder
|
||||
}
|
||||
|
||||
// NewBatchResultReader creates a new BatchResultReader.
|
||||
func NewBatchResultReader[T any](r io.Reader) *BatchResultReader[T] {
|
||||
return &BatchResultReader[T]{r: r, jd: json.NewDecoder(r)}
|
||||
}
|
||||
|
||||
// Read reads a batch result.
|
||||
func (r *BatchResultReader[T]) Read(out *BatchResult[T]) error {
|
||||
return r.jd.Decode(out)
|
||||
}
|
73
llm/zhipu/zhipu/batch_support_test.go
Normal file
73
llm/zhipu/zhipu/batch_support_test.go
Normal file
@ -0,0 +1,73 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBatchFileWriter(t *testing.T) {
|
||||
client, err := NewClient()
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
w := NewBatchFileWriter(buf)
|
||||
err = w.Write("batch-1", client.ChatCompletion("a").AddMessage(ChatCompletionMessage{
|
||||
Role: "user", Content: "hello",
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
err = w.Write("batch-2", client.Embedding("c").SetInput("whoa"))
|
||||
require.NoError(t, err)
|
||||
err = w.Write("batch-3", client.ImageGeneration("d").SetPrompt("whoa"))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, `{"body":{"messages":[{"role":"user","content":"hello"}],"model":"a"},"custom_id":"batch-1","method":"POST","url":"/v4/chat/completions"}
|
||||
{"body":{"input":"whoa","model":"c"},"custom_id":"batch-2","method":"POST","url":"/v4/embeddings"}
|
||||
{"body":{"model":"d","prompt":"whoa"},"custom_id":"batch-3","method":"POST","url":"/v4/images/generations"}
|
||||
`, buf.String())
|
||||
}
|
||||
|
||||
func TestBatchResultReader(t *testing.T) {
|
||||
result := `
|
||||
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":26,"prompt_tokens":89,"total_tokens":115},"model":"glm-4","id":"8668357533850320547","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"订单处理慢\"\n}\n'''"}}],"request_id":"615-request-1"}},"custom_id":"request-1","id":"batch_1791490810192076800"}
|
||||
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":22,"prompt_tokens":94,"total_tokens":116},"model":"glm-4","id":"8668368425887509080","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"产品缺陷\"\n}\n'''"}}],"request_id":"616-request-2"}},"custom_id":"request-2","id":"batch_1791490810192076800"}
|
||||
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":25,"prompt_tokens":86,"total_tokens":111},"model":"glm-4","id":"8668355815863214980","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"正面\",\n \"特定问题标注\": \"性价比\"\n}\n'''"}}],"request_id":"617-request-3"}},"custom_id":"request-3","id":"batch_1791490810192076800"}
|
||||
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":28,"prompt_tokens":89,"total_tokens":117},"model":"glm-4","id":"8668355815863214981","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"说明文档不清晰\"\n}\n'''"}}],"request_id":"618-request-4"}},"custom_id":"request-4","id":"batch_1791490810192076800"}
|
||||
|
||||
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":26,"prompt_tokens":88,"total_tokens":114},"model":"glm-4","id":"8668357533850320546","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"中性\",\n \"特定问题标注\": \"价格问题\"\n}\n'''"}}],"request_id":"619-request-5"}},"custom_id":"request-5","id":"batch_1791490810192076800"}
|
||||
|
||||
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":26,"prompt_tokens":90,"total_tokens":116},"model":"glm-4","id":"8668356159460662846","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"配送延迟\"\n}\n'''"}}],"request_id":"620-request-6"}},"custom_id":"request-6","id":"batch_1791490810192076800"}
|
||||
|
||||
|
||||
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":27,"prompt_tokens":88,"total_tokens":115},"model":"glm-4","id":"8668357671289274638","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"产品描述不符\"\n}\n'''"}}],"request_id":"621-request-7"}},"custom_id":"request-7","id":"batch_1791490810192076800"}
|
||||
{"response":{"status_code":200,"body":{"created":1715959702,"usage":{"completion_tokens":26,"prompt_tokens":87,"total_tokens":113},"model":"glm-4","id":"8668355644064514872","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"正面\",\n \"特定问题标注\": \"客服态度\"\n}\n'''"}}],"request_id":"622-request-8"}},"custom_id":"request-8","id":"batch_1791490810192076800"}
|
||||
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":29,"prompt_tokens":90,"total_tokens":119},"model":"glm-4","id":"8668357671289274639","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"包装问题, 产品损坏\"\n}\n'''"}}],"request_id":"623-request-9"}},"custom_id":"request-9","id":"batch_1791490810192076800"}
|
||||
{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":27,"prompt_tokens":87,"total_tokens":114},"model":"glm-4","id":"8668355644064514871","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"正面\",\n \"特定问题标注\": \"产品描述不符\"\n}\n'''"}}],"request_id":"624-request-10"}},"custom_id":"request-10","id":"batch_1791490810192076800"}
|
||||
`
|
||||
|
||||
brr := NewBatchResultReader[ChatCompletionResponse](bytes.NewReader([]byte(result)))
|
||||
|
||||
var count int
|
||||
|
||||
for {
|
||||
var res BatchResult[ChatCompletionResponse]
|
||||
|
||||
err := brr.Read(&res)
|
||||
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
require.Equal(t, 10, count)
|
||||
require.NoError(t, err)
|
||||
break
|
||||
}
|
||||
|
||||
require.Equal(t, 200, res.Response.StatusCode)
|
||||
|
||||
count++
|
||||
}
|
||||
}
|
59
llm/zhipu/zhipu/batch_test.go
Normal file
59
llm/zhipu/zhipu/batch_test.go
Normal file
@ -0,0 +1,59 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBatchServiceAll(t *testing.T) {
|
||||
client, err := NewClient()
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
bfw := NewBatchFileWriter(buf)
|
||||
err = bfw.Write("batch_1", client.ChatCompletion("glm-4-flash").AddMessage(ChatCompletionMessage{
|
||||
Role: RoleUser, Content: "你好呀",
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
err = bfw.Write("batch_2", client.ChatCompletion("glm-4-flash").AddMessage(ChatCompletionMessage{
|
||||
Role: RoleUser, Content: "你叫什么名字",
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
|
||||
res, err := client.FileCreate(FilePurposeBatch).SetFile(bytes.NewReader(buf.Bytes()), "batch.jsonl").Do(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
fileID := res.FileCreateFineTuneResponse.ID
|
||||
require.NotEmpty(t, fileID)
|
||||
|
||||
res1, err := client.BatchCreate().
|
||||
SetInputFileID(fileID).
|
||||
SetCompletionWindow(BatchCompletionWindow24h).
|
||||
SetEndpoint(BatchEndpointV4ChatCompletions).Do(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, res1.ID)
|
||||
|
||||
res2, err := client.BatchGet(res1.ID).Do(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, res2.ID, res1.ID)
|
||||
|
||||
res3, err := client.BatchList().Do(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, res3.Data)
|
||||
|
||||
err = client.BatchCancel(res1.ID).Do(context.Background())
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestBatchListService(t *testing.T) {
|
||||
client, err := NewClient()
|
||||
require.NoError(t, err)
|
||||
|
||||
res, err := client.BatchList().Do(context.Background())
|
||||
require.NoError(t, err)
|
||||
t.Log(res)
|
||||
}
|
577
llm/zhipu/zhipu/chat_completion.go
Normal file
577
llm/zhipu/zhipu/chat_completion.go
Normal file
@ -0,0 +1,577 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
RoleSystem = "system"
|
||||
RoleUser = "user"
|
||||
RoleAssistant = "assistant"
|
||||
RoleTool = "tool"
|
||||
|
||||
ToolChoiceAuto = "auto"
|
||||
|
||||
FinishReasonStop = "stop"
|
||||
FinishReasonStopSequence = "stop_sequence"
|
||||
FinishReasonToolCalls = "tool_calls"
|
||||
FinishReasonLength = "length"
|
||||
FinishReasonSensitive = "sensitive"
|
||||
FinishReasonNetworkError = "network_error"
|
||||
|
||||
ToolTypeFunction = "function"
|
||||
ToolTypeWebSearch = "web_search"
|
||||
ToolTypeRetrieval = "retrieval"
|
||||
|
||||
MultiContentTypeText = "text"
|
||||
MultiContentTypeImageURL = "image_url"
|
||||
MultiContentTypeVideoURL = "video_url"
|
||||
|
||||
// New in GLM-4-AllTools
|
||||
ToolTypeCodeInterpreter = "code_interpreter"
|
||||
ToolTypeDrawingTool = "drawing_tool"
|
||||
ToolTypeWebBrowser = "web_browser"
|
||||
|
||||
CodeInterpreterSandboxNone = "none"
|
||||
CodeInterpreterSandboxAuto = "auto"
|
||||
|
||||
ChatCompletionStatusFailed = "failed"
|
||||
ChatCompletionStatusCompleted = "completed"
|
||||
ChatCompletionStatusRequiresAction = "requires_action"
|
||||
)
|
||||
|
||||
// ChatCompletionTool is the interface for chat completion tool
|
||||
type ChatCompletionTool interface {
|
||||
isChatCompletionTool()
|
||||
}
|
||||
|
||||
// ChatCompletionToolFunction is the function for chat completion
|
||||
type ChatCompletionToolFunction struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters any `json:"parameters"`
|
||||
}
|
||||
|
||||
func (ChatCompletionToolFunction) isChatCompletionTool() {}
|
||||
|
||||
// ChatCompletionToolRetrieval is the retrieval for chat completion
|
||||
type ChatCompletionToolRetrieval struct {
|
||||
KnowledgeID string `json:"knowledge_id"`
|
||||
PromptTemplate string `json:"prompt_template,omitempty"`
|
||||
}
|
||||
|
||||
func (ChatCompletionToolRetrieval) isChatCompletionTool() {}
|
||||
|
||||
// ChatCompletionToolWebSearch is the web search for chat completion
|
||||
type ChatCompletionToolWebSearch struct {
|
||||
Enable *bool `json:"enable,omitempty"`
|
||||
SearchQuery string `json:"search_query,omitempty"`
|
||||
SearchResult bool `json:"search_result,omitempty"`
|
||||
}
|
||||
|
||||
func (ChatCompletionToolWebSearch) isChatCompletionTool() {}
|
||||
|
||||
// ChatCompletionToolCodeInterpreter is the code interpreter for chat completion
|
||||
// only in GLM-4-AllTools
|
||||
type ChatCompletionToolCodeInterpreter struct {
|
||||
Sandbox *string `json:"sandbox,omitempty"`
|
||||
}
|
||||
|
||||
func (ChatCompletionToolCodeInterpreter) isChatCompletionTool() {}
|
||||
|
||||
// ChatCompletionToolDrawingTool is the drawing tool for chat completion
|
||||
// only in GLM-4-AllTools
|
||||
type ChatCompletionToolDrawingTool struct {
|
||||
// no fields
|
||||
}
|
||||
|
||||
func (ChatCompletionToolDrawingTool) isChatCompletionTool() {}
|
||||
|
||||
// ChatCompletionToolWebBrowser is the web browser for chat completion
|
||||
type ChatCompletionToolWebBrowser struct {
|
||||
// no fields
|
||||
}
|
||||
|
||||
func (ChatCompletionToolWebBrowser) isChatCompletionTool() {}
|
||||
|
||||
// ChatCompletionUsage is the usage for chat completion
|
||||
type ChatCompletionUsage struct {
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// ChatCompletionWebSearch is the web search result for chat completion
|
||||
type ChatCompletionWebSearch struct {
|
||||
Icon string `json:"icon"`
|
||||
Title string `json:"title"`
|
||||
Link string `json:"link"`
|
||||
Media string `json:"media"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// ChatCompletionToolCallFunction is the function for chat completion tool call
|
||||
type ChatCompletionToolCallFunction struct {
|
||||
Name string `json:"name"`
|
||||
Arguments json.RawMessage `json:"arguments"`
|
||||
}
|
||||
|
||||
// ChatCompletionToolCallCodeInterpreterOutput is the output for chat completion tool call code interpreter
|
||||
type ChatCompletionToolCallCodeInterpreterOutput struct {
|
||||
Type string `json:"type"`
|
||||
Logs string `json:"logs"`
|
||||
File string `json:"file"`
|
||||
}
|
||||
|
||||
// ChatCompletionToolCallCodeInterpreter is the code interpreter for chat completion tool call
|
||||
type ChatCompletionToolCallCodeInterpreter struct {
|
||||
Input string `json:"input"`
|
||||
Outputs []ChatCompletionToolCallCodeInterpreterOutput `json:"outputs"`
|
||||
}
|
||||
|
||||
// ChatCompletionToolCallDrawingToolOutput is the output for chat completion tool call drawing tool
|
||||
type ChatCompletionToolCallDrawingToolOutput struct {
|
||||
Image string `json:"image"`
|
||||
}
|
||||
|
||||
// ChatCompletionToolCallDrawingTool is the drawing tool for chat completion tool call
|
||||
type ChatCompletionToolCallDrawingTool struct {
|
||||
Input string `json:"input"`
|
||||
Outputs []ChatCompletionToolCallDrawingToolOutput `json:"outputs"`
|
||||
}
|
||||
|
||||
// ChatCompletionToolCallWebBrowserOutput is the output for chat completion tool call web browser
|
||||
type ChatCompletionToolCallWebBrowserOutput struct {
|
||||
Title string `json:"title"`
|
||||
Link string `json:"link"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// ChatCompletionToolCallWebBrowser is the web browser for chat completion tool call
|
||||
type ChatCompletionToolCallWebBrowser struct {
|
||||
Input string `json:"input"`
|
||||
Outputs []ChatCompletionToolCallWebBrowserOutput `json:"outputs"`
|
||||
}
|
||||
|
||||
// ChatCompletionToolCall is the tool call for chat completion
|
||||
type ChatCompletionToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function *ChatCompletionToolCallFunction `json:"function,omitempty"`
|
||||
CodeInterpreter *ChatCompletionToolCallCodeInterpreter `json:"code_interpreter,omitempty"`
|
||||
DrawingTool *ChatCompletionToolCallDrawingTool `json:"drawing_tool,omitempty"`
|
||||
WebBrowser *ChatCompletionToolCallWebBrowser `json:"web_browser,omitempty"`
|
||||
}
|
||||
|
||||
type ChatCompletionMessageType interface {
|
||||
isChatCompletionMessageType()
|
||||
}
|
||||
|
||||
// ChatCompletionMessage is the message for chat completion
|
||||
type ChatCompletionMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content,omitempty"`
|
||||
ToolCalls []ChatCompletionToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
func (ChatCompletionMessage) isChatCompletionMessageType() {}
|
||||
|
||||
type ChatCompletionMultiContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
ImageURL *URLItem `json:"image_url,omitempty"`
|
||||
VideoURL *URLItem `json:"video_url,omitempty"`
|
||||
}
|
||||
|
||||
// ChatCompletionMultiMessage is the multi message for chat completion
|
||||
type ChatCompletionMultiMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content []ChatCompletionMultiContent `json:"content"`
|
||||
}
|
||||
|
||||
func (ChatCompletionMultiMessage) isChatCompletionMessageType() {}
|
||||
|
||||
// ChatCompletionMeta is the meta for chat completion
|
||||
type ChatCompletionMeta struct {
|
||||
UserInfo string `json:"user_info"`
|
||||
BotInfo string `json:"bot_info"`
|
||||
UserName string `json:"user_name"`
|
||||
BotName string `json:"bot_name"`
|
||||
}
|
||||
|
||||
// ChatCompletionChoice is the choice for chat completion
|
||||
type ChatCompletionChoice struct {
|
||||
Index int `json:"index"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Delta ChatCompletionMessage `json:"delta"` // stream mode
|
||||
Message ChatCompletionMessage `json:"message"` // non-stream mode
|
||||
}
|
||||
|
||||
// ChatCompletionResponse is the response for chat completion
|
||||
type ChatCompletionResponse struct {
|
||||
ID string `json:"id"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []ChatCompletionChoice `json:"choices"`
|
||||
Usage ChatCompletionUsage `json:"usage"`
|
||||
WebSearch []ChatCompletionWebSearch `json:"web_search"`
|
||||
// Status is the status of the chat completion, only in GLM-4-AllTools
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// ChatCompletionStreamHandler is the handler for chat completion stream
|
||||
type ChatCompletionStreamHandler func(chunk ChatCompletionResponse) error
|
||||
|
||||
var (
|
||||
chatCompletionStreamPrefix = []byte("data:")
|
||||
chatCompletionStreamDone = []byte("[DONE]")
|
||||
)
|
||||
|
||||
// chatCompletionReduceResponse reduce the chunk to the response
|
||||
func chatCompletionReduceResponse(out *ChatCompletionResponse, chunk ChatCompletionResponse) {
|
||||
if len(out.Choices) == 0 {
|
||||
out.Choices = append(out.Choices, ChatCompletionChoice{})
|
||||
}
|
||||
|
||||
// basic
|
||||
out.ID = chunk.ID
|
||||
out.Created = chunk.Created
|
||||
out.Model = chunk.Model
|
||||
|
||||
// choices
|
||||
if len(chunk.Choices) != 0 {
|
||||
oc := &out.Choices[0]
|
||||
cc := chunk.Choices[0]
|
||||
|
||||
oc.Index = cc.Index
|
||||
if cc.Delta.Role != "" {
|
||||
oc.Message.Role = cc.Delta.Role
|
||||
}
|
||||
oc.Message.Content += cc.Delta.Content
|
||||
oc.Message.ToolCalls = append(oc.Message.ToolCalls, cc.Delta.ToolCalls...)
|
||||
if cc.FinishReason != "" {
|
||||
oc.FinishReason = cc.FinishReason
|
||||
}
|
||||
}
|
||||
|
||||
// usage
|
||||
if chunk.Usage.CompletionTokens != 0 {
|
||||
out.Usage.CompletionTokens = chunk.Usage.CompletionTokens
|
||||
}
|
||||
if chunk.Usage.PromptTokens != 0 {
|
||||
out.Usage.PromptTokens = chunk.Usage.PromptTokens
|
||||
}
|
||||
if chunk.Usage.TotalTokens != 0 {
|
||||
out.Usage.TotalTokens = chunk.Usage.TotalTokens
|
||||
}
|
||||
|
||||
// web search
|
||||
out.WebSearch = append(out.WebSearch, chunk.WebSearch...)
|
||||
}
|
||||
|
||||
// chatCompletionDecodeStream decode the sse stream of chat completion
|
||||
func chatCompletionDecodeStream(r io.Reader, fn func(chunk ChatCompletionResponse) error) (err error) {
|
||||
br := bufio.NewReader(r)
|
||||
|
||||
for {
|
||||
var line []byte
|
||||
|
||||
if line, err = br.ReadBytes('\n'); err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
line = bytes.TrimSpace(line)
|
||||
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if !bytes.HasPrefix(line, chatCompletionStreamPrefix) {
|
||||
continue
|
||||
}
|
||||
|
||||
data := bytes.TrimSpace(line[len(chatCompletionStreamPrefix):])
|
||||
|
||||
if bytes.Equal(data, chatCompletionStreamDone) {
|
||||
break
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var chunk ChatCompletionResponse
|
||||
if err = json.Unmarshal(data, &chunk); err != nil {
|
||||
return
|
||||
}
|
||||
if err = fn(chunk); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// ChatCompletionStreamService is the service for chat completion stream
|
||||
type ChatCompletionService struct {
|
||||
client *Client
|
||||
|
||||
model string
|
||||
requestID *string
|
||||
doSample *bool
|
||||
temperature *float64
|
||||
topP *float64
|
||||
maxTokens *int
|
||||
stop []string
|
||||
toolChoice *string
|
||||
userID *string
|
||||
meta *ChatCompletionMeta
|
||||
|
||||
messages []any
|
||||
tools []any
|
||||
|
||||
streamHandler ChatCompletionStreamHandler
|
||||
}
|
||||
|
||||
var (
|
||||
_ BatchSupport = &ChatCompletionService{}
|
||||
)
|
||||
|
||||
// NewChatCompletionService creates a new ChatCompletionService.
|
||||
func NewChatCompletionService(client *Client) *ChatCompletionService {
|
||||
return &ChatCompletionService{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ChatCompletionService) BatchMethod() string {
|
||||
return "POST"
|
||||
}
|
||||
|
||||
func (s *ChatCompletionService) BatchURL() string {
|
||||
return BatchEndpointV4ChatCompletions
|
||||
}
|
||||
|
||||
func (s *ChatCompletionService) BatchBody() any {
|
||||
return s.buildBody()
|
||||
}
|
||||
|
||||
// SetModel set the model of the chat completion
|
||||
func (s *ChatCompletionService) SetModel(model string) *ChatCompletionService {
|
||||
s.model = model
|
||||
return s
|
||||
}
|
||||
|
||||
// SetMeta set the meta of the chat completion, optional
|
||||
func (s *ChatCompletionService) SetMeta(meta ChatCompletionMeta) *ChatCompletionService {
|
||||
s.meta = &meta
|
||||
return s
|
||||
}
|
||||
|
||||
// SetRequestID set the request id of the chat completion, optional
|
||||
func (s *ChatCompletionService) SetRequestID(requestID string) *ChatCompletionService {
|
||||
s.requestID = &requestID
|
||||
return s
|
||||
}
|
||||
|
||||
// SetTemperature set the temperature of the chat completion, optional
|
||||
func (s *ChatCompletionService) SetDoSample(doSample bool) *ChatCompletionService {
|
||||
s.doSample = &doSample
|
||||
return s
|
||||
}
|
||||
|
||||
// SetTemperature set the temperature of the chat completion, optional
|
||||
func (s *ChatCompletionService) SetTemperature(temperature float64) *ChatCompletionService {
|
||||
s.temperature = &temperature
|
||||
return s
|
||||
}
|
||||
|
||||
// SetTopP set the top p of the chat completion, optional
|
||||
func (s *ChatCompletionService) SetTopP(topP float64) *ChatCompletionService {
|
||||
s.topP = &topP
|
||||
return s
|
||||
}
|
||||
|
||||
// SetMaxTokens set the max tokens of the chat completion, optional
|
||||
func (s *ChatCompletionService) SetMaxTokens(maxTokens int) *ChatCompletionService {
|
||||
s.maxTokens = &maxTokens
|
||||
return s
|
||||
}
|
||||
|
||||
// SetStop set the stop of the chat completion, optional
|
||||
func (s *ChatCompletionService) SetStop(stop ...string) *ChatCompletionService {
|
||||
s.stop = stop
|
||||
return s
|
||||
}
|
||||
|
||||
// SetToolChoice set the tool choice of the chat completion, optional
|
||||
func (s *ChatCompletionService) SetToolChoice(toolChoice string) *ChatCompletionService {
|
||||
s.toolChoice = &toolChoice
|
||||
return s
|
||||
}
|
||||
|
||||
// SetUserID set the user id of the chat completion, optional
|
||||
func (s *ChatCompletionService) SetUserID(userID string) *ChatCompletionService {
|
||||
s.userID = &userID
|
||||
return s
|
||||
}
|
||||
|
||||
// SetStreamHandler set the stream handler of the chat completion, optional
|
||||
// this will enable the stream mode
|
||||
func (s *ChatCompletionService) SetStreamHandler(handler ChatCompletionStreamHandler) *ChatCompletionService {
|
||||
s.streamHandler = handler
|
||||
return s
|
||||
}
|
||||
|
||||
// AddMessage add the message to the chat completion
|
||||
func (s *ChatCompletionService) AddMessage(messages ...ChatCompletionMessageType) *ChatCompletionService {
|
||||
for _, message := range messages {
|
||||
s.messages = append(s.messages, message)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// AddFunction add the function to the chat completion
|
||||
func (s *ChatCompletionService) AddTool(tools ...ChatCompletionTool) *ChatCompletionService {
|
||||
for _, tool := range tools {
|
||||
switch tool := tool.(type) {
|
||||
case ChatCompletionToolFunction:
|
||||
s.tools = append(s.tools, map[string]any{
|
||||
"type": ToolTypeFunction,
|
||||
ToolTypeFunction: tool,
|
||||
})
|
||||
case ChatCompletionToolRetrieval:
|
||||
s.tools = append(s.tools, map[string]any{
|
||||
"type": ToolTypeRetrieval,
|
||||
ToolTypeRetrieval: tool,
|
||||
})
|
||||
case ChatCompletionToolWebSearch:
|
||||
s.tools = append(s.tools, map[string]any{
|
||||
"type": ToolTypeWebSearch,
|
||||
ToolTypeWebSearch: tool,
|
||||
})
|
||||
case ChatCompletionToolCodeInterpreter:
|
||||
s.tools = append(s.tools, map[string]any{
|
||||
"type": ToolTypeCodeInterpreter,
|
||||
ToolTypeCodeInterpreter: tool,
|
||||
})
|
||||
case ChatCompletionToolDrawingTool:
|
||||
s.tools = append(s.tools, map[string]any{
|
||||
"type": ToolTypeDrawingTool,
|
||||
ToolTypeDrawingTool: tool,
|
||||
})
|
||||
case ChatCompletionToolWebBrowser:
|
||||
s.tools = append(s.tools, map[string]any{
|
||||
"type": ToolTypeWebBrowser,
|
||||
ToolTypeWebBrowser: tool,
|
||||
})
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *ChatCompletionService) buildBody() M {
|
||||
body := map[string]any{
|
||||
"model": s.model,
|
||||
"messages": s.messages,
|
||||
}
|
||||
if s.requestID != nil {
|
||||
body["request_id"] = *s.requestID
|
||||
}
|
||||
if s.doSample != nil {
|
||||
body["do_sample"] = *s.doSample
|
||||
}
|
||||
if s.temperature != nil {
|
||||
body["temperature"] = *s.temperature
|
||||
}
|
||||
if s.topP != nil {
|
||||
body["top_p"] = *s.topP
|
||||
}
|
||||
if s.maxTokens != nil {
|
||||
body["max_tokens"] = *s.maxTokens
|
||||
}
|
||||
if len(s.stop) != 0 {
|
||||
body["stop"] = s.stop
|
||||
}
|
||||
if len(s.tools) != 0 {
|
||||
body["tools"] = s.tools
|
||||
}
|
||||
if s.toolChoice != nil {
|
||||
body["tool_choice"] = *s.toolChoice
|
||||
}
|
||||
if s.userID != nil {
|
||||
body["user_id"] = *s.userID
|
||||
}
|
||||
if s.meta != nil {
|
||||
body["meta"] = s.meta
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// Do send the request of the chat completion and return the response
|
||||
func (s *ChatCompletionService) Do(ctx context.Context) (res ChatCompletionResponse, err error) {
|
||||
body := s.buildBody()
|
||||
|
||||
streamHandler := s.streamHandler
|
||||
|
||||
if streamHandler == nil {
|
||||
var (
|
||||
resp *resty.Response
|
||||
apiError APIErrorResponse
|
||||
)
|
||||
//fmt.Println(u.BMagenta(u.JsonP(body)), 111)
|
||||
if resp, err = s.client.request(ctx).SetBody(body).SetResult(&res).SetError(&apiError).Post("chat/completions"); err != nil {
|
||||
//fmt.Println(u.BRed(err.Error()), 2221)
|
||||
return
|
||||
}
|
||||
if resp.IsError() {
|
||||
err = apiError
|
||||
//fmt.Println(u.BRed(err.Error()), 2222)
|
||||
return
|
||||
}
|
||||
//fmt.Println(u.BGreen(u.JsonP(resp.Result())), resp.Status(), resp.Status(), 333)
|
||||
return
|
||||
}
|
||||
|
||||
// stream mode
|
||||
|
||||
body["stream"] = true
|
||||
|
||||
var resp *resty.Response
|
||||
|
||||
if resp, err = s.client.request(ctx).SetBody(body).SetDoNotParseResponse(true).Post("chat/completions"); err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.RawBody().Close()
|
||||
|
||||
if resp.IsError() {
|
||||
err = errors.New(resp.Status())
|
||||
return
|
||||
}
|
||||
|
||||
var choice ChatCompletionChoice
|
||||
|
||||
if err = chatCompletionDecodeStream(resp.RawBody(), func(chunk ChatCompletionResponse) error {
|
||||
// reduce the chunk to the response
|
||||
chatCompletionReduceResponse(&res, chunk)
|
||||
// invoke the stream handler
|
||||
return streamHandler(chunk)
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
res.Choices = append(res.Choices, choice)
|
||||
|
||||
return
|
||||
}
|
251
llm/zhipu/zhipu/chat_completion_test.go
Normal file
251
llm/zhipu/zhipu/chat_completion_test.go
Normal file
@ -0,0 +1,251 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestChatCompletionService(t *testing.T) {
|
||||
client, err := NewClient()
|
||||
require.NoError(t, err)
|
||||
|
||||
s := client.ChatCompletion("glm-4-flash")
|
||||
s.AddMessage(ChatCompletionMessage{
|
||||
Role: RoleUser,
|
||||
Content: "你好呀",
|
||||
})
|
||||
res, err := s.Do(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, res.Choices)
|
||||
choice := res.Choices[0]
|
||||
require.Equal(t, FinishReasonStop, choice.FinishReason)
|
||||
require.NotEmpty(t, choice.Message.Content)
|
||||
}
|
||||
|
||||
func TestChatCompletionServiceCharGLM(t *testing.T) {
|
||||
client, err := NewClient()
|
||||
require.NoError(t, err)
|
||||
|
||||
s := client.ChatCompletion("charglm-3")
|
||||
s.SetMeta(
|
||||
ChatCompletionMeta{
|
||||
UserName: "啵酱",
|
||||
UserInfo: "啵酱是小少爷",
|
||||
BotName: "塞巴斯酱",
|
||||
BotInfo: "塞巴斯酱是一个冷酷的恶魔管家",
|
||||
},
|
||||
).AddMessage(ChatCompletionMessage{
|
||||
Role: RoleUser,
|
||||
Content: "早上好",
|
||||
})
|
||||
res, err := s.Do(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, res.Choices)
|
||||
choice := res.Choices[0]
|
||||
require.Contains(t, []string{FinishReasonStop, FinishReasonStopSequence}, choice.FinishReason)
|
||||
require.NotEmpty(t, choice.Message.Content)
|
||||
}
|
||||
|
||||
func TestChatCompletionServiceAllToolsCodeInterpreter(t *testing.T) {
|
||||
client, err := NewClient()
|
||||
require.NoError(t, err)
|
||||
|
||||
s := client.ChatCompletion("GLM-4-AllTools")
|
||||
s.AddMessage(ChatCompletionMultiMessage{
|
||||
Role: "user",
|
||||
Content: []ChatCompletionMultiContent{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "计算[5,10,20,700,99,310,978,100]的平均值和方差。",
|
||||
},
|
||||
},
|
||||
})
|
||||
s.AddTool(ChatCompletionToolCodeInterpreter{
|
||||
Sandbox: Ptr(CodeInterpreterSandboxAuto),
|
||||
})
|
||||
|
||||
foundInterpreterInput := false
|
||||
foundInterpreterOutput := false
|
||||
|
||||
s.SetStreamHandler(func(chunk ChatCompletionResponse) error {
|
||||
for _, c := range chunk.Choices {
|
||||
for _, tc := range c.Delta.ToolCalls {
|
||||
if tc.Type == ToolTypeCodeInterpreter && tc.CodeInterpreter != nil {
|
||||
if tc.CodeInterpreter.Input != "" {
|
||||
foundInterpreterInput = true
|
||||
}
|
||||
if len(tc.CodeInterpreter.Outputs) > 0 {
|
||||
foundInterpreterOutput = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
buf, _ := json.MarshalIndent(chunk, "", " ")
|
||||
t.Log(string(buf))
|
||||
return nil
|
||||
})
|
||||
|
||||
res, err := s.Do(context.Background())
|
||||
require.True(t, foundInterpreterInput)
|
||||
require.True(t, foundInterpreterOutput)
|
||||
require.NotNil(t, res)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestChatCompletionServiceAllToolsDrawingTool(t *testing.T) {
|
||||
client, err := NewClient()
|
||||
require.NoError(t, err)
|
||||
|
||||
s := client.ChatCompletion("GLM-4-AllTools")
|
||||
s.AddMessage(ChatCompletionMultiMessage{
|
||||
Role: "user",
|
||||
Content: []ChatCompletionMultiContent{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "画一个正弦函数图像",
|
||||
},
|
||||
},
|
||||
})
|
||||
s.AddTool(ChatCompletionToolDrawingTool{})
|
||||
|
||||
foundInput := false
|
||||
foundOutput := false
|
||||
outputImage := ""
|
||||
|
||||
s.SetStreamHandler(func(chunk ChatCompletionResponse) error {
|
||||
for _, c := range chunk.Choices {
|
||||
for _, tc := range c.Delta.ToolCalls {
|
||||
if tc.Type == ToolTypeDrawingTool && tc.DrawingTool != nil {
|
||||
if tc.DrawingTool.Input != "" {
|
||||
foundInput = true
|
||||
}
|
||||
if len(tc.DrawingTool.Outputs) > 0 {
|
||||
foundOutput = true
|
||||
}
|
||||
for _, output := range tc.DrawingTool.Outputs {
|
||||
if output.Image != "" {
|
||||
outputImage = output.Image
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
buf, _ := json.MarshalIndent(chunk, "", " ")
|
||||
t.Log(string(buf))
|
||||
return nil
|
||||
})
|
||||
|
||||
res, err := s.Do(context.Background())
|
||||
require.True(t, foundInput)
|
||||
require.True(t, foundOutput)
|
||||
require.NotEmpty(t, outputImage)
|
||||
t.Log(outputImage)
|
||||
require.NotNil(t, res)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestChatCompletionServiceAllToolsWebBrowser(t *testing.T) {
|
||||
client, err := NewClient()
|
||||
require.NoError(t, err)
|
||||
|
||||
s := client.ChatCompletion("GLM-4-AllTools")
|
||||
s.AddMessage(ChatCompletionMultiMessage{
|
||||
Role: "user",
|
||||
Content: []ChatCompletionMultiContent{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "搜索下本周深圳天气如何",
|
||||
},
|
||||
},
|
||||
})
|
||||
s.AddTool(ChatCompletionToolWebBrowser{})
|
||||
|
||||
foundInput := false
|
||||
foundOutput := false
|
||||
outputContent := ""
|
||||
|
||||
s.SetStreamHandler(func(chunk ChatCompletionResponse) error {
|
||||
for _, c := range chunk.Choices {
|
||||
for _, tc := range c.Delta.ToolCalls {
|
||||
if tc.Type == ToolTypeWebBrowser && tc.WebBrowser != nil {
|
||||
if tc.WebBrowser.Input != "" {
|
||||
foundInput = true
|
||||
}
|
||||
if len(tc.WebBrowser.Outputs) > 0 {
|
||||
foundOutput = true
|
||||
}
|
||||
for _, output := range tc.WebBrowser.Outputs {
|
||||
if output.Content != "" {
|
||||
outputContent = output.Content
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
buf, _ := json.MarshalIndent(chunk, "", " ")
|
||||
t.Log(string(buf))
|
||||
return nil
|
||||
})
|
||||
|
||||
res, err := s.Do(context.Background())
|
||||
require.True(t, foundInput)
|
||||
require.True(t, foundOutput)
|
||||
require.NotEmpty(t, outputContent)
|
||||
t.Log(outputContent)
|
||||
require.NotNil(t, res)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestChatCompletionServiceStream(t *testing.T) {
|
||||
client, err := NewClient()
|
||||
require.NoError(t, err)
|
||||
|
||||
var content string
|
||||
|
||||
s := client.ChatCompletion("glm-4-flash").AddMessage(ChatCompletionMessage{
|
||||
Role: RoleUser,
|
||||
Content: "你好呀",
|
||||
}).SetStreamHandler(func(chunk ChatCompletionResponse) error {
|
||||
content += chunk.Choices[0].Delta.Content
|
||||
return nil
|
||||
})
|
||||
res, err := s.Do(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, res.Choices)
|
||||
choice := res.Choices[0]
|
||||
require.Equal(t, FinishReasonStop, choice.FinishReason)
|
||||
require.NotEmpty(t, choice.Message.Content)
|
||||
require.Equal(t, content, choice.Message.Content)
|
||||
}
|
||||
|
||||
func TestChatCompletionServiceVision(t *testing.T) {
|
||||
client, err := NewClient()
|
||||
require.NoError(t, err)
|
||||
|
||||
s := client.ChatCompletion("glm-4v")
|
||||
s.AddMessage(ChatCompletionMultiMessage{
|
||||
Role: RoleUser,
|
||||
Content: []ChatCompletionMultiContent{
|
||||
{
|
||||
Type: MultiContentTypeText,
|
||||
Text: "图里有什么",
|
||||
},
|
||||
{
|
||||
Type: MultiContentTypeImageURL,
|
||||
ImageURL: &URLItem{
|
||||
URL: "https://img1.baidu.com/it/u=1369931113,3388870256&fm=253&app=138&size=w931&n=0&f=JPEG&fmt=auto?sec=1703696400&t=f3028c7a1dca43a080aeb8239f09cc2f",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
res, err := s.Do(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, res.Choices)
|
||||
require.NotZero(t, res.Usage.CompletionTokens)
|
||||
choice := res.Choices[0]
|
||||
require.Equal(t, FinishReasonStop, choice.FinishReason)
|
||||
require.NotEmpty(t, choice.Message.Content)
|
||||
}
|
291
llm/zhipu/zhipu/client.go
Normal file
291
llm/zhipu/zhipu/client.go
Normal file
@ -0,0 +1,291 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
const (
|
||||
envAPIKey = "ZHIPUAI_API_KEY"
|
||||
envBaseURL = "ZHIPUAI_BASE_URL"
|
||||
envDebug = "ZHIPUAI_DEBUG"
|
||||
|
||||
defaultBaseURL = "https://open.bigmodel.cn/api/paas/v4"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrAPIKeyMissing is the error when the api key is missing
|
||||
ErrAPIKeyMissing = errors.New("zhipu: api key is missing")
|
||||
// ErrAPIKeyMalformed is the error when the api key is malformed
|
||||
ErrAPIKeyMalformed = errors.New("zhipu: api key is malformed")
|
||||
)
|
||||
|
||||
type clientOptions struct {
|
||||
baseURL string
|
||||
apiKey string
|
||||
client *http.Client
|
||||
debug *bool
|
||||
}
|
||||
|
||||
// ClientOption is a function that configures the client
|
||||
type ClientOption func(opts *clientOptions)
|
||||
|
||||
// WithAPIKey set the api key of the client
|
||||
func WithAPIKey(apiKey string) ClientOption {
|
||||
return func(opts *clientOptions) {
|
||||
opts.apiKey = apiKey
|
||||
}
|
||||
}
|
||||
|
||||
// WithBaseURL set the base url of the client
|
||||
func WithBaseURL(baseURL string) ClientOption {
|
||||
return func(opts *clientOptions) {
|
||||
opts.baseURL = baseURL
|
||||
}
|
||||
}
|
||||
|
||||
// WithHTTPClient set the http client of the client
|
||||
func WithHTTPClient(client *http.Client) ClientOption {
|
||||
return func(opts *clientOptions) {
|
||||
opts.client = client
|
||||
}
|
||||
}
|
||||
|
||||
// WithDebug set the debug mode of the client
|
||||
func WithDebug(debug bool) ClientOption {
|
||||
return func(opts *clientOptions) {
|
||||
opts.debug = new(bool)
|
||||
*opts.debug = debug
|
||||
}
|
||||
}
|
||||
|
||||
// Client is the client for zhipu ai platform
|
||||
type Client struct {
|
||||
client *resty.Client
|
||||
debug bool
|
||||
keyID string
|
||||
keySecret []byte
|
||||
}
|
||||
|
||||
func (c *Client) createJWT() string {
|
||||
timestamp := time.Now().UnixMilli()
|
||||
exp := timestamp + time.Hour.Milliseconds()*24*7
|
||||
|
||||
t := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"api_key": c.keyID,
|
||||
"timestamp": timestamp,
|
||||
"exp": exp,
|
||||
})
|
||||
t.Header = map[string]interface{}{
|
||||
"alg": "HS256",
|
||||
"sign_type": "SIGN",
|
||||
}
|
||||
|
||||
token, err := t.SignedString(c.keySecret)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
// request creates a new resty request with the jwt token and context
|
||||
func (c *Client) request(ctx context.Context) *resty.Request {
|
||||
return c.client.R().SetContext(ctx).SetHeader("Authorization", c.createJWT())
|
||||
}
|
||||
|
||||
// NewClient creates a new client
|
||||
// It will read the api key from the environment variable ZHIPUAI_API_KEY
|
||||
// It will read the base url from the environment variable ZHIPUAI_BASE_URL
|
||||
func NewClient(optFns ...ClientOption) (client *Client, err error) {
|
||||
var opts clientOptions
|
||||
for _, optFn := range optFns {
|
||||
optFn(&opts)
|
||||
}
|
||||
// base url
|
||||
if opts.baseURL == "" {
|
||||
opts.baseURL = strings.TrimSpace(os.Getenv(envBaseURL))
|
||||
}
|
||||
if opts.baseURL == "" {
|
||||
opts.baseURL = defaultBaseURL
|
||||
}
|
||||
// api key
|
||||
if opts.apiKey == "" {
|
||||
opts.apiKey = strings.TrimSpace(os.Getenv(envAPIKey))
|
||||
}
|
||||
if opts.apiKey == "" {
|
||||
err = ErrAPIKeyMissing
|
||||
return
|
||||
}
|
||||
// debug
|
||||
if opts.debug == nil {
|
||||
if debugStr := strings.TrimSpace(os.Getenv(envDebug)); debugStr != "" {
|
||||
if debug, err1 := strconv.ParseBool(debugStr); err1 == nil {
|
||||
opts.debug = &debug
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
keyComponents := strings.SplitN(opts.apiKey, ".", 2)
|
||||
|
||||
if len(keyComponents) != 2 {
|
||||
err = ErrAPIKeyMalformed
|
||||
return
|
||||
}
|
||||
|
||||
client = &Client{
|
||||
keyID: keyComponents[0],
|
||||
keySecret: []byte(keyComponents[1]),
|
||||
}
|
||||
|
||||
if opts.client == nil {
|
||||
client.client = resty.New()
|
||||
} else {
|
||||
client.client = resty.NewWithClient(opts.client)
|
||||
}
|
||||
|
||||
client.client = client.client.SetBaseURL(opts.baseURL)
|
||||
|
||||
if opts.debug != nil {
|
||||
client.client.SetDebug(*opts.debug)
|
||||
client.debug = *opts.debug
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// BatchCreate creates a new BatchCreateService.
|
||||
func (c *Client) BatchCreate() *BatchCreateService {
|
||||
return NewBatchCreateService(c)
|
||||
}
|
||||
|
||||
// BatchGet creates a new BatchGetService.
|
||||
func (c *Client) BatchGet(batchID string) *BatchGetService {
|
||||
return NewBatchGetService(c).SetBatchID(batchID)
|
||||
}
|
||||
|
||||
// BatchCancel creates a new BatchCancelService.
|
||||
func (c *Client) BatchCancel(batchID string) *BatchCancelService {
|
||||
return NewBatchCancelService(c).SetBatchID(batchID)
|
||||
}
|
||||
|
||||
// BatchList creates a new BatchListService.
|
||||
func (c *Client) BatchList() *BatchListService {
|
||||
return NewBatchListService(c)
|
||||
}
|
||||
|
||||
// ChatCompletion creates a new ChatCompletionService.
|
||||
func (c *Client) ChatCompletion(model string) *ChatCompletionService {
|
||||
return NewChatCompletionService(c).SetModel(model)
|
||||
}
|
||||
|
||||
// Embedding embeds a list of text into a vector space.
|
||||
func (c *Client) Embedding(model string) *EmbeddingService {
|
||||
return NewEmbeddingService(c).SetModel(model)
|
||||
}
|
||||
|
||||
// FileCreate creates a new FileCreateService.
|
||||
func (c *Client) FileCreate(purpose string) *FileCreateService {
|
||||
return NewFileCreateService(c).SetPurpose(purpose)
|
||||
}
|
||||
|
||||
// FileEditService creates a new FileEditService.
|
||||
func (c *Client) FileEdit(documentID string) *FileEditService {
|
||||
return NewFileEditService(c).SetDocumentID(documentID)
|
||||
}
|
||||
|
||||
// FileList creates a new FileListService.
|
||||
func (c *Client) FileList(purpose string) *FileListService {
|
||||
return NewFileListService(c).SetPurpose(purpose)
|
||||
}
|
||||
|
||||
// FileDeleteService creates a new FileDeleteService.
|
||||
func (c *Client) FileDelete(documentID string) *FileDeleteService {
|
||||
return NewFileDeleteService(c).SetDocumentID(documentID)
|
||||
}
|
||||
|
||||
// FileGetService creates a new FileGetService.
|
||||
func (c *Client) FileGet(documentID string) *FileGetService {
|
||||
return NewFileGetService(c).SetDocumentID(documentID)
|
||||
}
|
||||
|
||||
// FileDownload creates a new FileDownloadService.
|
||||
func (c *Client) FileDownload(fileID string) *FileDownloadService {
|
||||
return NewFileDownloadService(c).SetFileID(fileID)
|
||||
}
|
||||
|
||||
// FineTuneCreate creates a new fine tune create service
|
||||
func (c *Client) FineTuneCreate(model string) *FineTuneCreateService {
|
||||
return NewFineTuneCreateService(c).SetModel(model)
|
||||
}
|
||||
|
||||
// FineTuneEventList creates a new fine tune event list service
|
||||
func (c *Client) FineTuneEventList(jobID string) *FineTuneEventListService {
|
||||
return NewFineTuneEventListService(c).SetJobID(jobID)
|
||||
}
|
||||
|
||||
// FineTuneGet creates a new fine tune get service
|
||||
func (c *Client) FineTuneGet(jobID string) *FineTuneGetService {
|
||||
return NewFineTuneGetService(c).SetJobID(jobID)
|
||||
}
|
||||
|
||||
// FineTuneList creates a new fine tune list service
|
||||
func (c *Client) FineTuneList() *FineTuneListService {
|
||||
return NewFineTuneListService(c)
|
||||
}
|
||||
|
||||
// FineTuneDelete creates a new fine tune delete service
|
||||
func (c *Client) FineTuneDelete(jobID string) *FineTuneDeleteService {
|
||||
return NewFineTuneDeleteService(c).SetJobID(jobID)
|
||||
}
|
||||
|
||||
// FineTuneCancel creates a new fine tune cancel service
|
||||
func (c *Client) FineTuneCancel(jobID string) *FineTuneCancelService {
|
||||
return NewFineTuneCancelService(c).SetJobID(jobID)
|
||||
}
|
||||
|
||||
// ImageGeneration creates a new image generation service
|
||||
func (c *Client) ImageGeneration(model string) *ImageGenerationService {
|
||||
return NewImageGenerationService(c).SetModel(model)
|
||||
}
|
||||
|
||||
// KnowledgeCreate creates a new knowledge create service
|
||||
func (c *Client) KnowledgeCreate() *KnowledgeCreateService {
|
||||
return NewKnowledgeCreateService(c)
|
||||
}
|
||||
|
||||
// KnowledgeEdit creates a new knowledge edit service
|
||||
func (c *Client) KnowledgeEdit(knowledgeID string) *KnowledgeEditService {
|
||||
return NewKnowledgeEditService(c).SetKnowledgeID(knowledgeID)
|
||||
}
|
||||
|
||||
// KnowledgeList list all the knowledge
|
||||
func (c *Client) KnowledgeList() *KnowledgeListService {
|
||||
return NewKnowledgeListService(c)
|
||||
}
|
||||
|
||||
// KnowledgeDelete creates a new knowledge delete service
|
||||
func (c *Client) KnowledgeDelete(knowledgeID string) *KnowledgeDeleteService {
|
||||
return NewKnowledgeDeleteService(c).SetKnowledgeID(knowledgeID)
|
||||
}
|
||||
|
||||
// KnowledgeGet creates a new knowledge get service
|
||||
func (c *Client) KnowledgeCapacity() *KnowledgeCapacityService {
|
||||
return NewKnowledgeCapacityService(c)
|
||||
}
|
||||
|
||||
// VideoGeneration creates a new video generation service
|
||||
func (c *Client) VideoGeneration(model string) *VideoGenerationService {
|
||||
return NewVideoGenerationService(c).SetModel(model)
|
||||
}
|
||||
|
||||
// AsyncResult creates a new async result get service
|
||||
func (c *Client) AsyncResult(id string) *AsyncResultService {
|
||||
return NewAsyncResultService(c).SetID(id)
|
||||
}
|
17
llm/zhipu/zhipu/client_test.go
Normal file
17
llm/zhipu/zhipu/client_test.go
Normal file
@ -0,0 +1,17 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClientR(t *testing.T) {
|
||||
c, err := NewClient()
|
||||
require.NoError(t, err)
|
||||
// the only free api is to list fine-tuning jobs
|
||||
res, err := c.request(context.Background()).Get("fine_tuning/jobs")
|
||||
require.NoError(t, err)
|
||||
require.True(t, res.IsSuccess())
|
||||
}
|
25
llm/zhipu/zhipu/cog.toml
Normal file
25
llm/zhipu/zhipu/cog.toml
Normal file
@ -0,0 +1,25 @@
|
||||
from_latest_tag = false
|
||||
ignore_merge_commits = false
|
||||
disable_changelog = false
|
||||
disable_bump_commit = false
|
||||
generate_mono_repository_global_tag = true
|
||||
branch_whitelist = []
|
||||
skip_ci = "[skip ci]"
|
||||
skip_untracked = false
|
||||
pre_bump_hooks = []
|
||||
post_bump_hooks = []
|
||||
pre_package_bump_hooks = []
|
||||
post_package_bump_hooks = []
|
||||
tag_prefix = "v"
|
||||
|
||||
[git_hooks]
|
||||
|
||||
[commit_types]
|
||||
|
||||
[changelog]
|
||||
path = "CHANGELOG.md"
|
||||
authors = []
|
||||
|
||||
[bump_profiles]
|
||||
|
||||
[packages]
|
87
llm/zhipu/zhipu/embedding.go
Normal file
87
llm/zhipu/zhipu/embedding.go
Normal file
@ -0,0 +1,87 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
// EmbeddingData is the data for each embedding.
|
||||
type EmbeddingData struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
Object string `json:"object"`
|
||||
}
|
||||
|
||||
// EmbeddingResponse is the response from the embedding service.
|
||||
type EmbeddingResponse struct {
|
||||
Model string `json:"model"`
|
||||
Data []EmbeddingData `json:"data"`
|
||||
Object string `json:"object"`
|
||||
Usage ChatCompletionUsage `json:"usage"`
|
||||
}
|
||||
|
||||
// EmbeddingService embeds a list of text into a vector space.
|
||||
type EmbeddingService struct {
|
||||
client *Client
|
||||
|
||||
model string
|
||||
input string
|
||||
}
|
||||
|
||||
var (
|
||||
_ BatchSupport = &EmbeddingService{}
|
||||
)
|
||||
|
||||
// NewEmbeddingService creates a new EmbeddingService.
|
||||
func NewEmbeddingService(client *Client) *EmbeddingService {
|
||||
return &EmbeddingService{client: client}
|
||||
}
|
||||
|
||||
func (s *EmbeddingService) BatchMethod() string {
|
||||
return "POST"
|
||||
}
|
||||
|
||||
func (s *EmbeddingService) BatchURL() string {
|
||||
return BatchEndpointV4Embeddings
|
||||
}
|
||||
|
||||
func (s *EmbeddingService) BatchBody() any {
|
||||
return s.buildBody()
|
||||
}
|
||||
|
||||
// SetModel sets the model to use for the embedding.
|
||||
func (s *EmbeddingService) SetModel(model string) *EmbeddingService {
|
||||
s.model = model
|
||||
return s
|
||||
}
|
||||
|
||||
// SetInput sets the input text to embed.
|
||||
func (s *EmbeddingService) SetInput(input string) *EmbeddingService {
|
||||
s.input = input
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *EmbeddingService) buildBody() M {
|
||||
return M{"model": s.model, "input": s.input}
|
||||
}
|
||||
|
||||
func (s *EmbeddingService) Do(ctx context.Context) (res EmbeddingResponse, err error) {
|
||||
var (
|
||||
resp *resty.Response
|
||||
apiError APIErrorResponse
|
||||
)
|
||||
|
||||
if resp, err = s.client.request(ctx).
|
||||
SetBody(s.buildBody()).
|
||||
SetResult(&res).
|
||||
SetError(&apiError).
|
||||
Post("embeddings"); err != nil {
|
||||
return
|
||||
}
|
||||
if resp.IsError() {
|
||||
err = apiError
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
21
llm/zhipu/zhipu/embedding_test.go
Normal file
21
llm/zhipu/zhipu/embedding_test.go
Normal file
@ -0,0 +1,21 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEmbeddingService(t *testing.T) {
|
||||
client, err := NewClient()
|
||||
require.NoError(t, err)
|
||||
|
||||
service := client.Embedding("embedding-2")
|
||||
|
||||
resp, err := service.SetInput("你好").Do(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotZero(t, resp.Usage.TotalTokens)
|
||||
require.NotEmpty(t, resp.Data)
|
||||
require.NotEmpty(t, resp.Data[0].Embedding)
|
||||
}
|
58
llm/zhipu/zhipu/error.go
Normal file
58
llm/zhipu/zhipu/error.go
Normal file
@ -0,0 +1,58 @@
|
||||
package zhipu
|
||||
|
||||
type APIError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func (e APIError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
type APIErrorResponse struct {
|
||||
APIError `json:"error"`
|
||||
}
|
||||
|
||||
func (e APIErrorResponse) Error() string {
|
||||
return e.APIError.Error()
|
||||
}
|
||||
|
||||
// GetAPIErrorCode returns the error code of an API error.
|
||||
func GetAPIErrorCode(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
if e, ok := err.(APIError); ok {
|
||||
return e.Code
|
||||
}
|
||||
if e, ok := err.(APIErrorResponse); ok {
|
||||
return e.Code
|
||||
}
|
||||
if e, ok := err.(*APIError); ok && e != nil {
|
||||
return e.Code
|
||||
}
|
||||
if e, ok := err.(*APIErrorResponse); ok && e != nil {
|
||||
return e.Code
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetAPIErrorMessage returns the error message of an API error.
|
||||
func GetAPIErrorMessage(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
if e, ok := err.(APIError); ok {
|
||||
return e.Message
|
||||
}
|
||||
if e, ok := err.(APIErrorResponse); ok {
|
||||
return e.Message
|
||||
}
|
||||
if e, ok := err.(*APIError); ok && e != nil {
|
||||
return e.Message
|
||||
}
|
||||
if e, ok := err.(*APIErrorResponse); ok && e != nil {
|
||||
return e.Message
|
||||
}
|
||||
return err.Error()
|
||||
}
|
38
llm/zhipu/zhipu/error_test.go
Normal file
38
llm/zhipu/zhipu/error_test.go
Normal file
@ -0,0 +1,38 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAPIError(t *testing.T) {
|
||||
err := APIError{
|
||||
Code: "code",
|
||||
Message: "message",
|
||||
}
|
||||
require.Equal(t, "message", err.Error())
|
||||
require.Equal(t, "code", GetAPIErrorCode(err))
|
||||
require.Equal(t, "message", GetAPIErrorMessage(err))
|
||||
}
|
||||
|
||||
func TestAPIErrorResponse(t *testing.T) {
|
||||
err := APIErrorResponse{
|
||||
APIError: APIError{
|
||||
Code: "code",
|
||||
Message: "message",
|
||||
},
|
||||
}
|
||||
require.Equal(t, "message", err.Error())
|
||||
require.Equal(t, "code", GetAPIErrorCode(err))
|
||||
require.Equal(t, "message", GetAPIErrorMessage(err))
|
||||
}
|
||||
|
||||
func TestAPIErrorResponseFromDoc(t *testing.T) {
|
||||
var res APIErrorResponse
|
||||
err := json.Unmarshal([]byte(`{"error":{"code":"1002","message":"Authorization Token非法,请确认Authorization Token正确传递。"}}`), &res)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "1002", res.Code)
|
||||
require.Equal(t, "1002", GetAPIErrorCode(res))
|
||||
}
|
541
llm/zhipu/zhipu/file.go
Normal file
541
llm/zhipu/zhipu/file.go
Normal file
@ -0,0 +1,541 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
FilePurposeFineTune = "fine-tune"
|
||||
FilePurposeRetrieval = "retrieval"
|
||||
FilePurposeBatch = "batch"
|
||||
|
||||
KnowledgeTypeArticle = 1
|
||||
KnowledgeTypeQADocument = 2
|
||||
KnowledgeTypeQASpreadsheet = 3
|
||||
KnowledgeTypeProductDatabaseSpreadsheet = 4
|
||||
KnowledgeTypeCustom = 5
|
||||
)
|
||||
|
||||
// FileCreateService is a service to create a file.
|
||||
type FileCreateService struct {
|
||||
client *Client
|
||||
|
||||
purpose string
|
||||
|
||||
localFile string
|
||||
file io.Reader
|
||||
filename string
|
||||
|
||||
customSeparator *string
|
||||
sentenceSize *int
|
||||
knowledgeID *string
|
||||
}
|
||||
|
||||
// FileCreateKnowledgeSuccessInfo is the success info of the FileCreateKnowledgeResponse.
|
||||
type FileCreateKnowledgeSuccessInfo struct {
|
||||
Filename string `json:"fileName"`
|
||||
DocumentID string `json:"documentId"`
|
||||
}
|
||||
|
||||
// FileCreateKnowledgeFailedInfo is the failed info of the FileCreateKnowledgeResponse.
|
||||
type FileCreateKnowledgeFailedInfo struct {
|
||||
Filename string `json:"fileName"`
|
||||
FailReason string `json:"failReason"`
|
||||
}
|
||||
|
||||
// FileCreateKnowledgeResponse is the response of the FileCreateService.
|
||||
type FileCreateKnowledgeResponse struct {
|
||||
SuccessInfos []FileCreateKnowledgeSuccessInfo `json:"successInfos"`
|
||||
FailedInfos []FileCreateKnowledgeFailedInfo `json:"failedInfos"`
|
||||
}
|
||||
|
||||
// FileCreateFineTuneResponse is the response of the FileCreateService.
|
||||
type FileCreateFineTuneResponse struct {
|
||||
Bytes int64 `json:"bytes"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
Filename string `json:"filename"`
|
||||
Object string `json:"object"`
|
||||
Purpose string `json:"purpose"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
// FileCreateResponse is the response of the FileCreateService.
|
||||
type FileCreateResponse struct {
|
||||
FileCreateFineTuneResponse
|
||||
FileCreateKnowledgeResponse
|
||||
}
|
||||
|
||||
// NewFileCreateService creates a new FileCreateService.
|
||||
func NewFileCreateService(client *Client) *FileCreateService {
|
||||
return &FileCreateService{client: client}
|
||||
}
|
||||
|
||||
// SetLocalFile sets the local_file parameter of the FileCreateService.
|
||||
func (s *FileCreateService) SetLocalFile(localFile string) *FileCreateService {
|
||||
s.localFile = localFile
|
||||
return s
|
||||
}
|
||||
|
||||
// SetFile sets the file parameter of the FileCreateService.
|
||||
func (s *FileCreateService) SetFile(file io.Reader, filename string) *FileCreateService {
|
||||
s.file = file
|
||||
s.filename = filename
|
||||
return s
|
||||
}
|
||||
|
||||
// SetPurpose sets the purpose parameter of the FileCreateService.
|
||||
func (s *FileCreateService) SetPurpose(purpose string) *FileCreateService {
|
||||
s.purpose = purpose
|
||||
return s
|
||||
}
|
||||
|
||||
// SetCustomSeparator sets the custom_separator parameter of the FileCreateService.
|
||||
func (s *FileCreateService) SetCustomSeparator(customSeparator string) *FileCreateService {
|
||||
s.customSeparator = &customSeparator
|
||||
return s
|
||||
}
|
||||
|
||||
// SetSentenceSize sets the sentence_size parameter of the FileCreateService.
|
||||
func (s *FileCreateService) SetSentenceSize(sentenceSize int) *FileCreateService {
|
||||
s.sentenceSize = &sentenceSize
|
||||
return s
|
||||
}
|
||||
|
||||
// SetKnowledgeID sets the knowledge_id parameter of the FileCreateService.
|
||||
func (s *FileCreateService) SetKnowledgeID(knowledgeID string) *FileCreateService {
|
||||
s.knowledgeID = &knowledgeID
|
||||
return s
|
||||
}
|
||||
|
||||
// Do makes the request.
|
||||
func (s *FileCreateService) Do(ctx context.Context) (res FileCreateResponse, err error) {
|
||||
var (
|
||||
resp *resty.Response
|
||||
apiError APIErrorResponse
|
||||
)
|
||||
|
||||
body := map[string]string{"purpose": s.purpose}
|
||||
|
||||
if s.customSeparator != nil {
|
||||
body["custom_separator"] = *s.customSeparator
|
||||
}
|
||||
if s.sentenceSize != nil {
|
||||
body["sentence_size"] = strconv.Itoa(*s.sentenceSize)
|
||||
}
|
||||
if s.knowledgeID != nil {
|
||||
body["knowledge_id"] = *s.knowledgeID
|
||||
}
|
||||
|
||||
file, filename := s.file, s.filename
|
||||
|
||||
if file == nil && s.localFile != "" {
|
||||
var f *os.File
|
||||
if f, err = os.Open(s.localFile); err != nil {
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
file = f
|
||||
filename = filepath.Base(s.localFile)
|
||||
}
|
||||
|
||||
if file == nil {
|
||||
err = errors.New("no file specified")
|
||||
return
|
||||
}
|
||||
|
||||
if resp, err = s.client.request(ctx).
|
||||
SetFileReader("file", filename, file).
|
||||
SetMultipartFormData(body).
|
||||
SetResult(&res).
|
||||
SetError(&apiError).
|
||||
Post("files"); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.IsError() {
|
||||
err = apiError
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// FileEditService is a service to edit a file.
|
||||
type FileEditService struct {
|
||||
client *Client
|
||||
|
||||
documentID string
|
||||
|
||||
knowledgeType *int
|
||||
customSeparator []string
|
||||
sentenceSize *int
|
||||
}
|
||||
|
||||
// NewFileEditService creates a new FileEditService.
|
||||
func NewFileEditService(client *Client) *FileEditService {
|
||||
return &FileEditService{client: client}
|
||||
}
|
||||
|
||||
// SetDocumentID sets the document_id parameter of the FileEditService.
|
||||
func (s *FileEditService) SetDocumentID(documentID string) *FileEditService {
|
||||
s.documentID = documentID
|
||||
return s
|
||||
}
|
||||
|
||||
// SetKnowledgeType sets the knowledge_type parameter of the FileEditService.
|
||||
func (s *FileEditService) SetKnowledgeType(knowledgeType int) *FileEditService {
|
||||
s.knowledgeType = &knowledgeType
|
||||
return s
|
||||
}
|
||||
|
||||
// SetSentenceSize sets the sentence_size parameter of the FileEditService.
|
||||
func (s *FileEditService) SetCustomSeparator(customSeparator ...string) *FileEditService {
|
||||
s.customSeparator = customSeparator
|
||||
return s
|
||||
}
|
||||
|
||||
// SetSentenceSize sets the sentence_size parameter of the FileEditService.
|
||||
func (s *FileEditService) SetSentenceSize(sentenceSize int) *FileEditService {
|
||||
s.sentenceSize = &sentenceSize
|
||||
return s
|
||||
}
|
||||
|
||||
// Do makes the request.
|
||||
func (s *FileEditService) Do(ctx context.Context) (err error) {
|
||||
var (
|
||||
resp *resty.Response
|
||||
apiError APIErrorResponse
|
||||
)
|
||||
|
||||
body := M{}
|
||||
|
||||
if s.knowledgeType != nil {
|
||||
body["knowledge_type"] = strconv.Itoa(*s.knowledgeType)
|
||||
}
|
||||
if len(s.customSeparator) > 0 {
|
||||
body["custom_separator"] = s.customSeparator
|
||||
}
|
||||
if s.sentenceSize != nil {
|
||||
body["sentence_size"] = strconv.Itoa(*s.sentenceSize)
|
||||
}
|
||||
|
||||
if resp, err = s.client.request(ctx).
|
||||
SetPathParam("document_id", s.documentID).
|
||||
SetBody(body).
|
||||
SetError(&apiError).
|
||||
Put("document/{document_id}"); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.IsError() {
|
||||
err = apiError
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// FileListService is a service to list files.
|
||||
type FileListService struct {
|
||||
client *Client
|
||||
|
||||
purpose string
|
||||
|
||||
knowledgeID *string
|
||||
page *int
|
||||
limit *int
|
||||
after *string
|
||||
orderAsc *bool
|
||||
}
|
||||
|
||||
// FileFailInfo is the failed info of the FileListKnowledgeItem.
|
||||
type FileFailInfo struct {
|
||||
EmbeddingCode int `json:"embedding_code"`
|
||||
EmbeddingMsg string `json:"embedding_msg"`
|
||||
}
|
||||
|
||||
// FileListKnowledgeItem is the item of the FileListKnowledgeResponse.
|
||||
type FileListKnowledgeItem struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
URL string `json:"url"`
|
||||
Length int64 `json:"length"`
|
||||
SentenceSize int64 `json:"sentence_size"`
|
||||
CustomSeparator []string `json:"custom_separator"`
|
||||
EmbeddingStat int `json:"embedding_stat"`
|
||||
FailInfo *FileFailInfo `json:"failInfo"`
|
||||
WordNum int64 `json:"word_num"`
|
||||
ParseImage int `json:"parse_image"`
|
||||
}
|
||||
|
||||
// FileListKnowledgeResponse is the response of the FileListService.
|
||||
type FileListKnowledgeResponse struct {
|
||||
Total int `json:"total"`
|
||||
List []FileListKnowledgeItem `json:"list"`
|
||||
}
|
||||
|
||||
// FileListFineTuneItem is the item of the FileListFineTuneResponse.
|
||||
type FileListFineTuneItem struct {
|
||||
Bytes int64 `json:"bytes"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
Filename string `json:"filename"`
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Purpose string `json:"purpose"`
|
||||
}
|
||||
|
||||
// FileListFineTuneResponse is the response of the FileListService.
|
||||
type FileListFineTuneResponse struct {
|
||||
Object string `json:"object"`
|
||||
Data []FileListFineTuneItem `json:"data"`
|
||||
}
|
||||
|
||||
// FileListResponse is the response of the FileListService.
|
||||
type FileListResponse struct {
|
||||
FileListKnowledgeResponse
|
||||
FileListFineTuneResponse
|
||||
}
|
||||
|
||||
// NewFileListService creates a new FileListService.
|
||||
func NewFileListService(client *Client) *FileListService {
|
||||
return &FileListService{client: client}
|
||||
}
|
||||
|
||||
// SetPurpose sets the purpose parameter of the FileListService.
|
||||
func (s *FileListService) SetPurpose(purpose string) *FileListService {
|
||||
s.purpose = purpose
|
||||
return s
|
||||
}
|
||||
|
||||
// SetKnowledgeID sets the knowledge_id parameter of the FileListService.
|
||||
func (s *FileListService) SetKnowledgeID(knowledgeID string) *FileListService {
|
||||
s.knowledgeID = &knowledgeID
|
||||
return s
|
||||
}
|
||||
|
||||
// SetPage sets the page parameter of the FileListService.
|
||||
func (s *FileListService) SetPage(page int) *FileListService {
|
||||
s.page = &page
|
||||
return s
|
||||
}
|
||||
|
||||
// SetLimit sets the limit parameter of the FileListService.
|
||||
func (s *FileListService) SetLimit(limit int) *FileListService {
|
||||
s.limit = &limit
|
||||
return s
|
||||
}
|
||||
|
||||
// SetAfter sets the after parameter of the FileListService.
|
||||
func (s *FileListService) SetAfter(after string) *FileListService {
|
||||
s.after = &after
|
||||
return s
|
||||
}
|
||||
|
||||
// SetOrder sets the order parameter of the FileListService.
|
||||
func (s *FileListService) SetOrder(asc bool) *FileListService {
|
||||
s.orderAsc = &asc
|
||||
return s
|
||||
}
|
||||
|
||||
// Do makes the request.
|
||||
func (s *FileListService) Do(ctx context.Context) (res FileListResponse, err error) {
|
||||
var (
|
||||
resp *resty.Response
|
||||
apiError APIErrorResponse
|
||||
)
|
||||
|
||||
m := map[string]string{
|
||||
"purpose": s.purpose,
|
||||
}
|
||||
|
||||
if s.knowledgeID != nil {
|
||||
m["knowledge_id"] = *s.knowledgeID
|
||||
}
|
||||
if s.page != nil {
|
||||
m["page"] = strconv.Itoa(*s.page)
|
||||
}
|
||||
if s.limit != nil {
|
||||
m["limit"] = strconv.Itoa(*s.limit)
|
||||
}
|
||||
if s.after != nil {
|
||||
m["after"] = *s.after
|
||||
}
|
||||
if s.orderAsc != nil {
|
||||
if *s.orderAsc {
|
||||
m["order"] = "asc"
|
||||
} else {
|
||||
m["order"] = "desc"
|
||||
}
|
||||
}
|
||||
|
||||
if resp, err = s.client.request(ctx).
|
||||
SetQueryParams(m).
|
||||
SetResult(&res).
|
||||
SetError(&apiError).
|
||||
Get("files"); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.IsError() {
|
||||
err = apiError
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// FileDeleteService is a service to delete a file.
|
||||
type FileDeleteService struct {
|
||||
client *Client
|
||||
documentID string
|
||||
}
|
||||
|
||||
// NewFileDeleteService creates a new FileDeleteService.
|
||||
func NewFileDeleteService(client *Client) *FileDeleteService {
|
||||
return &FileDeleteService{client: client}
|
||||
}
|
||||
|
||||
// SetDocumentID sets the document_id parameter of the FileDeleteService.
|
||||
func (s *FileDeleteService) SetDocumentID(documentID string) *FileDeleteService {
|
||||
s.documentID = documentID
|
||||
return s
|
||||
}
|
||||
|
||||
// Do makes the request.
|
||||
func (s *FileDeleteService) Do(ctx context.Context) (err error) {
|
||||
var (
|
||||
resp *resty.Response
|
||||
apiError APIErrorResponse
|
||||
)
|
||||
|
||||
if resp, err = s.client.request(ctx).
|
||||
SetPathParam("document_id", s.documentID).
|
||||
SetError(&apiError).
|
||||
Delete("document/{document_id}"); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.IsError() {
|
||||
err = apiError
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// FileGetService is a service to get a file.
|
||||
type FileGetService struct {
|
||||
client *Client
|
||||
documentID string
|
||||
}
|
||||
|
||||
// FileGetResponse is the response of the FileGetService.
|
||||
type FileGetResponse = FileListKnowledgeItem
|
||||
|
||||
// NewFileGetService creates a new FileGetService.
|
||||
func NewFileGetService(client *Client) *FileGetService {
|
||||
return &FileGetService{client: client}
|
||||
}
|
||||
|
||||
// SetDocumentID sets the document_id parameter of the FileGetService.
|
||||
func (s *FileGetService) SetDocumentID(documentID string) *FileGetService {
|
||||
s.documentID = documentID
|
||||
return s
|
||||
}
|
||||
|
||||
// Do makes the request.
|
||||
func (s *FileGetService) Do(ctx context.Context) (res FileGetResponse, err error) {
|
||||
var (
|
||||
resp *resty.Response
|
||||
apiError APIErrorResponse
|
||||
)
|
||||
|
||||
if resp, err = s.client.request(ctx).
|
||||
SetPathParam("document_id", s.documentID).
|
||||
SetResult(&res).
|
||||
SetError(&apiError).
|
||||
Get("document/{document_id}"); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.IsError() {
|
||||
err = apiError
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// FileDownloadService is a service to download a file.
|
||||
type FileDownloadService struct {
|
||||
client *Client
|
||||
|
||||
fileID string
|
||||
|
||||
writer io.Writer
|
||||
filename string
|
||||
}
|
||||
|
||||
// NewFileDownloadService creates a new FileDownloadService.
|
||||
func NewFileDownloadService(client *Client) *FileDownloadService {
|
||||
return &FileDownloadService{client: client}
|
||||
}
|
||||
|
||||
// SetFileID sets the file_id parameter of the FileDownloadService.
|
||||
func (s *FileDownloadService) SetFileID(fileID string) *FileDownloadService {
|
||||
s.fileID = fileID
|
||||
return s
|
||||
}
|
||||
|
||||
// SetOutput sets the output parameter of the FileDownloadService.
|
||||
func (s *FileDownloadService) SetOutput(w io.Writer) *FileDownloadService {
|
||||
s.writer = w
|
||||
return s
|
||||
}
|
||||
|
||||
// SetOutputFile sets the output_file parameter of the FileDownloadService.
|
||||
func (s *FileDownloadService) SetOutputFile(filename string) *FileDownloadService {
|
||||
s.filename = filename
|
||||
return s
|
||||
}
|
||||
|
||||
// Do makes the request.
|
||||
func (s *FileDownloadService) Do(ctx context.Context) (err error) {
|
||||
var resp *resty.Response
|
||||
|
||||
writer := s.writer
|
||||
|
||||
if writer == nil && s.filename != "" {
|
||||
var f *os.File
|
||||
if f, err = os.Create(s.filename); err != nil {
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
writer = f
|
||||
}
|
||||
|
||||
if writer == nil {
|
||||
return errors.New("no output specified")
|
||||
}
|
||||
|
||||
if resp, err = s.client.request(ctx).
|
||||
SetDoNotParseResponse(true).
|
||||
SetPathParam("file_id", s.fileID).
|
||||
Get("files/{file_id}/content"); err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.RawBody().Close()
|
||||
|
||||
_, err = io.Copy(writer, resp.RawBody())
|
||||
|
||||
return
|
||||
}
|
71
llm/zhipu/zhipu/file_test.go
Normal file
71
llm/zhipu/zhipu/file_test.go
Normal file
@ -0,0 +1,71 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFileServiceFineTune(t *testing.T) {
|
||||
client, err := NewClient()
|
||||
require.NoError(t, err)
|
||||
|
||||
s := client.FileCreate(FilePurposeFineTune)
|
||||
s.SetLocalFile(filepath.Join("testdata", "test-file.jsonl"))
|
||||
|
||||
res, err := s.Do(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotZero(t, res.Bytes)
|
||||
require.NotZero(t, res.CreatedAt)
|
||||
require.NotEmpty(t, res.ID)
|
||||
}
|
||||
|
||||
func TestFileServiceKnowledge(t *testing.T) {
|
||||
client, err := NewClient()
|
||||
require.NoError(t, err)
|
||||
|
||||
s := client.FileCreate(FilePurposeRetrieval)
|
||||
s.SetKnowledgeID(os.Getenv("TEST_KNOWLEDGE_ID"))
|
||||
s.SetLocalFile(filepath.Join("testdata", "test-file.txt"))
|
||||
|
||||
res, err := s.Do(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, res.SuccessInfos)
|
||||
require.NotEmpty(t, res.SuccessInfos[0].DocumentID)
|
||||
require.NotEmpty(t, res.SuccessInfos[0].Filename)
|
||||
|
||||
documentID := res.SuccessInfos[0].DocumentID
|
||||
|
||||
res2, err := client.FileGet(documentID).Do(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, res2.ID)
|
||||
|
||||
err = client.FileEdit(documentID).SetKnowledgeType(KnowledgeTypeCustom).Do(context.Background())
|
||||
require.True(t, err == nil || GetAPIErrorCode(err) == "10019")
|
||||
|
||||
err = client.FileDelete(res.SuccessInfos[0].DocumentID).Do(context.Background())
|
||||
require.True(t, err == nil || GetAPIErrorCode(err) == "10019")
|
||||
}
|
||||
|
||||
func TestFileListServiceKnowledge(t *testing.T) {
|
||||
client, err := NewClient()
|
||||
require.NoError(t, err)
|
||||
|
||||
s := client.FileList(FilePurposeRetrieval).SetKnowledgeID(os.Getenv("TEST_KNOWLEDGE_ID"))
|
||||
res, err := s.Do(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, res.List)
|
||||
}
|
||||
|
||||
func TestFileListServiceFineTune(t *testing.T) {
|
||||
client, err := NewClient()
|
||||
require.NoError(t, err)
|
||||
|
||||
s := client.FileList(FilePurposeFineTune)
|
||||
res, err := s.Do(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, res.Data)
|
||||
}
|
456
llm/zhipu/zhipu/fine_tune.go
Normal file
456
llm/zhipu/zhipu/fine_tune.go
Normal file
@ -0,0 +1,456 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
HyperParameterAuto = "auto"
|
||||
|
||||
FineTuneStatusCreate = "create"
|
||||
FineTuneStatusValidatingFiles = "validating_files"
|
||||
FineTuneStatusQueued = "queued"
|
||||
FineTuneStatusRunning = "running"
|
||||
FineTuneStatusSucceeded = "succeeded"
|
||||
FineTuneStatusFailed = "failed"
|
||||
FineTuneStatusCancelled = "cancelled"
|
||||
)
|
||||
|
||||
// FineTuneItem is the item of the FineTune
|
||||
type FineTuneItem struct {
|
||||
ID string `json:"id"`
|
||||
RequestID string `json:"request_id"`
|
||||
FineTunedModel string `json:"fine_tuned_model"`
|
||||
Status string `json:"status"`
|
||||
Object string `json:"object"`
|
||||
TrainingFile string `json:"training_file"`
|
||||
ValidationFile string `json:"validation_file"`
|
||||
Error APIError `json:"error"`
|
||||
}
|
||||
|
||||
// FineTuneCreateService creates a new fine tune
|
||||
type FineTuneCreateService struct {
|
||||
client *Client
|
||||
|
||||
model string
|
||||
trainingFile string
|
||||
validationFile *string
|
||||
|
||||
learningRateMultiplier *StringOr[float64]
|
||||
batchSize *StringOr[int]
|
||||
nEpochs *StringOr[int]
|
||||
|
||||
suffix *string
|
||||
requestID *string
|
||||
}
|
||||
|
||||
// FineTuneCreateResponse is the response of the FineTuneCreateService
|
||||
type FineTuneCreateResponse = FineTuneItem
|
||||
|
||||
// NewFineTuneCreateService creates a new FineTuneCreateService
|
||||
func NewFineTuneCreateService(client *Client) *FineTuneCreateService {
|
||||
return &FineTuneCreateService{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// SetModel sets the model parameter
|
||||
func (s *FineTuneCreateService) SetModel(model string) *FineTuneCreateService {
|
||||
s.model = model
|
||||
return s
|
||||
}
|
||||
|
||||
// SetTrainingFile sets the trainingFile parameter
|
||||
func (s *FineTuneCreateService) SetTrainingFile(trainingFile string) *FineTuneCreateService {
|
||||
s.trainingFile = trainingFile
|
||||
return s
|
||||
}
|
||||
|
||||
// SetValidationFile sets the validationFile parameter
|
||||
func (s *FineTuneCreateService) SetValidationFile(validationFile string) *FineTuneCreateService {
|
||||
s.validationFile = &validationFile
|
||||
return s
|
||||
}
|
||||
|
||||
// SetLearningRateMultiplier sets the learningRateMultiplier parameter
|
||||
func (s *FineTuneCreateService) SetLearningRateMultiplier(learningRateMultiplier float64) *FineTuneCreateService {
|
||||
s.learningRateMultiplier = &StringOr[float64]{}
|
||||
s.learningRateMultiplier.SetValue(learningRateMultiplier)
|
||||
return s
|
||||
}
|
||||
|
||||
// SetLearningRateMultiplierAuto sets the learningRateMultiplier parameter to auto
|
||||
func (s *FineTuneCreateService) SetLearningRateMultiplierAuto() *FineTuneCreateService {
|
||||
s.learningRateMultiplier = &StringOr[float64]{}
|
||||
s.learningRateMultiplier.SetString(HyperParameterAuto)
|
||||
return s
|
||||
}
|
||||
|
||||
// SetBatchSize sets the batchSize parameter
|
||||
func (s *FineTuneCreateService) SetBatchSize(batchSize int) *FineTuneCreateService {
|
||||
s.batchSize = &StringOr[int]{}
|
||||
s.batchSize.SetValue(batchSize)
|
||||
return s
|
||||
}
|
||||
|
||||
// SetBatchSizeAuto sets the batchSize parameter to auto
|
||||
func (s *FineTuneCreateService) SetBatchSizeAuto() *FineTuneCreateService {
|
||||
s.batchSize = &StringOr[int]{}
|
||||
s.batchSize.SetString(HyperParameterAuto)
|
||||
return s
|
||||
}
|
||||
|
||||
// SetNEpochs sets the nEpochs parameter
|
||||
func (s *FineTuneCreateService) SetNEpochs(nEpochs int) *FineTuneCreateService {
|
||||
s.nEpochs = &StringOr[int]{}
|
||||
s.nEpochs.SetValue(nEpochs)
|
||||
return s
|
||||
}
|
||||
|
||||
// SetNEpochsAuto sets the nEpochs parameter to auto
|
||||
func (s *FineTuneCreateService) SetNEpochsAuto() *FineTuneCreateService {
|
||||
s.nEpochs = &StringOr[int]{}
|
||||
s.nEpochs.SetString(HyperParameterAuto)
|
||||
return s
|
||||
}
|
||||
|
||||
// SetSuffix sets the suffix parameter
|
||||
func (s *FineTuneCreateService) SetSuffix(suffix string) *FineTuneCreateService {
|
||||
s.suffix = &suffix
|
||||
return s
|
||||
}
|
||||
|
||||
// SetRequestID sets the requestID parameter
|
||||
func (s *FineTuneCreateService) SetRequestID(requestID string) *FineTuneCreateService {
|
||||
s.requestID = &requestID
|
||||
return s
|
||||
}
|
||||
|
||||
// Do makes the request
|
||||
func (s *FineTuneCreateService) Do(ctx context.Context) (res FineTuneCreateResponse, err error) {
|
||||
var (
|
||||
resp *resty.Response
|
||||
apiError APIErrorResponse
|
||||
)
|
||||
|
||||
body := M{
|
||||
"model": s.model,
|
||||
"training_file": s.trainingFile,
|
||||
}
|
||||
|
||||
if s.validationFile != nil {
|
||||
body["validation_file"] = *s.validationFile
|
||||
}
|
||||
if s.suffix != nil {
|
||||
body["suffix"] = *s.suffix
|
||||
}
|
||||
if s.requestID != nil {
|
||||
body["request_id"] = *s.requestID
|
||||
}
|
||||
if s.learningRateMultiplier != nil || s.batchSize != nil || s.nEpochs != nil {
|
||||
hp := M{}
|
||||
if s.learningRateMultiplier != nil {
|
||||
hp["learning_rate_multiplier"] = s.learningRateMultiplier
|
||||
}
|
||||
if s.batchSize != nil {
|
||||
hp["batch_size"] = s.batchSize
|
||||
}
|
||||
if s.nEpochs != nil {
|
||||
hp["n_epochs"] = s.nEpochs
|
||||
}
|
||||
body["hyperparameters"] = hp
|
||||
}
|
||||
|
||||
if resp, err = s.client.request(ctx).
|
||||
SetBody(body).
|
||||
SetResult(&res).
|
||||
SetError(&apiError).
|
||||
Post("fine_tuning/jobs"); err != nil {
|
||||
return
|
||||
}
|
||||
if resp.IsError() {
|
||||
err = apiError
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// FineTuneEventListService creates a new fine tune event list
|
||||
type FineTuneEventListService struct {
|
||||
client *Client
|
||||
|
||||
jobID string
|
||||
|
||||
limit *int
|
||||
after *string
|
||||
}
|
||||
|
||||
// FineTuneEventData is the data of the FineTuneEventItem
|
||||
type FineTuneEventData struct {
|
||||
Acc float64 `json:"acc"`
|
||||
Loss float64 `json:"loss"`
|
||||
CurrentSteps int64 `json:"current_steps"`
|
||||
RemainingTime string `json:"remaining_time"`
|
||||
ElapsedTime string `json:"elapsed_time"`
|
||||
TotalSteps int64 `json:"total_steps"`
|
||||
Epoch int64 `json:"epoch"`
|
||||
TrainedTokens int64 `json:"trained_tokens"`
|
||||
LearningRate float64 `json:"learning_rate"`
|
||||
}
|
||||
|
||||
// FineTuneEventItem is the item of the FineTuneEventListResponse
|
||||
type FineTuneEventItem struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Level string `json:"level"`
|
||||
Message string `json:"message"`
|
||||
Object string `json:"object"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
Data FineTuneEventData `json:"data"`
|
||||
}
|
||||
|
||||
// FineTuneEventListResponse is the response of the FineTuneEventListService
|
||||
type FineTuneEventListResponse struct {
|
||||
Data []FineTuneEventItem `json:"data"`
|
||||
HasMore bool `json:"has_more"`
|
||||
Object string `json:"object"`
|
||||
}
|
||||
|
||||
// NewFineTuneEventListService creates a new FineTuneEventListService
|
||||
func NewFineTuneEventListService(client *Client) *FineTuneEventListService {
|
||||
return &FineTuneEventListService{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// SetJobID sets the jobID parameter
|
||||
func (s *FineTuneEventListService) SetJobID(jobID string) *FineTuneEventListService {
|
||||
s.jobID = jobID
|
||||
return s
|
||||
}
|
||||
|
||||
// SetLimit sets the limit parameter
|
||||
func (s *FineTuneEventListService) SetLimit(limit int) *FineTuneEventListService {
|
||||
s.limit = &limit
|
||||
return s
|
||||
}
|
||||
|
||||
// SetAfter sets the after parameter
|
||||
func (s *FineTuneEventListService) SetAfter(after string) *FineTuneEventListService {
|
||||
s.after = &after
|
||||
return s
|
||||
}
|
||||
|
||||
// Do makes the request
|
||||
func (s *FineTuneEventListService) Do(ctx context.Context) (res FineTuneEventListResponse, err error) {
|
||||
var (
|
||||
resp *resty.Response
|
||||
apiError APIErrorResponse
|
||||
)
|
||||
|
||||
req := s.client.request(ctx)
|
||||
|
||||
if s.limit != nil {
|
||||
req.SetQueryParam("limit", strconv.Itoa(*s.limit))
|
||||
}
|
||||
if s.after != nil {
|
||||
req.SetQueryParam("after", *s.after)
|
||||
}
|
||||
|
||||
if resp, err = req.
|
||||
SetPathParam("job_id", s.jobID).
|
||||
SetResult(&res).
|
||||
SetError(&apiError).
|
||||
Get("fine_tuning/jobs/{job_id}/events"); err != nil {
|
||||
return
|
||||
}
|
||||
if resp.IsError() {
|
||||
err = apiError
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// FineTuneGetService creates a new fine tune get
|
||||
type FineTuneGetService struct {
|
||||
client *Client
|
||||
jobID string
|
||||
}
|
||||
|
||||
// NewFineTuneGetService creates a new FineTuneGetService
|
||||
func NewFineTuneGetService(client *Client) *FineTuneGetService {
|
||||
return &FineTuneGetService{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// SetJobID sets the jobID parameter
|
||||
func (s *FineTuneGetService) SetJobID(jobID string) *FineTuneGetService {
|
||||
s.jobID = jobID
|
||||
return s
|
||||
}
|
||||
|
||||
// Do makes the request
|
||||
func (s *FineTuneGetService) Do(ctx context.Context) (res FineTuneItem, err error) {
|
||||
var (
|
||||
resp *resty.Response
|
||||
apiError APIErrorResponse
|
||||
)
|
||||
|
||||
if resp, err = s.client.request(ctx).
|
||||
SetPathParam("job_id", s.jobID).
|
||||
SetResult(&res).
|
||||
SetError(&apiError).
|
||||
Get("fine_tuning/jobs/{job_id}"); err != nil {
|
||||
return
|
||||
}
|
||||
if resp.IsError() {
|
||||
err = apiError
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// FineTuneListService creates a new fine tune list
|
||||
type FineTuneListService struct {
|
||||
client *Client
|
||||
|
||||
limit *int
|
||||
after *string
|
||||
}
|
||||
|
||||
// FineTuneListResponse is the response of the FineTuneListService
|
||||
type FineTuneListResponse struct {
|
||||
Data []FineTuneItem `json:"data"`
|
||||
Object string `json:"object"`
|
||||
}
|
||||
|
||||
// NewFineTuneListService creates a new FineTuneListService
|
||||
func NewFineTuneListService(client *Client) *FineTuneListService {
|
||||
return &FineTuneListService{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// SetLimit sets the limit parameter
|
||||
func (s *FineTuneListService) SetLimit(limit int) *FineTuneListService {
|
||||
s.limit = &limit
|
||||
return s
|
||||
}
|
||||
|
||||
// SetAfter sets the after parameter
|
||||
func (s *FineTuneListService) SetAfter(after string) *FineTuneListService {
|
||||
s.after = &after
|
||||
return s
|
||||
}
|
||||
|
||||
// Do makes the request
|
||||
func (s *FineTuneListService) Do(ctx context.Context) (res FineTuneListResponse, err error) {
|
||||
var (
|
||||
resp *resty.Response
|
||||
apiError APIErrorResponse
|
||||
)
|
||||
|
||||
req := s.client.request(ctx)
|
||||
if s.limit != nil {
|
||||
req.SetQueryParam("limit", strconv.Itoa(*s.limit))
|
||||
}
|
||||
if s.after != nil {
|
||||
req.SetQueryParam("after", *s.after)
|
||||
}
|
||||
|
||||
if resp, err = req.
|
||||
SetResult(&res).
|
||||
SetError(&apiError).
|
||||
Get("fine_tuning/jobs"); err != nil {
|
||||
return
|
||||
}
|
||||
if resp.IsError() {
|
||||
err = apiError
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// FineTuneDeleteService creates a new fine tune delete
|
||||
type FineTuneDeleteService struct {
|
||||
client *Client
|
||||
jobID string
|
||||
}
|
||||
|
||||
// NewFineTuneDeleteService creates a new FineTuneDeleteService
|
||||
func NewFineTuneDeleteService(client *Client) *FineTuneDeleteService {
|
||||
return &FineTuneDeleteService{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// SetJobID sets the jobID parameter
|
||||
func (s *FineTuneDeleteService) SetJobID(jobID string) *FineTuneDeleteService {
|
||||
s.jobID = jobID
|
||||
return s
|
||||
}
|
||||
|
||||
// Do makes the request
|
||||
func (s *FineTuneDeleteService) Do(ctx context.Context) (res FineTuneItem, err error) {
|
||||
var (
|
||||
resp *resty.Response
|
||||
apiError APIErrorResponse
|
||||
)
|
||||
|
||||
if resp, err = s.client.request(ctx).
|
||||
SetPathParam("job_id", s.jobID).
|
||||
SetResult(&res).
|
||||
SetError(&apiError).
|
||||
Delete("fine_tuning/jobs/{job_id}"); err != nil {
|
||||
return
|
||||
}
|
||||
if resp.IsError() {
|
||||
err = apiError
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// FineTuneCancelService creates a new fine tune cancel
|
||||
type FineTuneCancelService struct {
|
||||
client *Client
|
||||
jobID string
|
||||
}
|
||||
|
||||
// NewFineTuneCancelService creates a new FineTuneCancelService
|
||||
func NewFineTuneCancelService(client *Client) *FineTuneCancelService {
|
||||
return &FineTuneCancelService{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// SetJobID sets the jobID parameter
|
||||
func (s *FineTuneCancelService) SetJobID(jobID string) *FineTuneCancelService {
|
||||
s.jobID = jobID
|
||||
return s
|
||||
}
|
||||
|
||||
// Do makes the request
|
||||
func (s *FineTuneCancelService) Do(ctx context.Context) (res FineTuneItem, err error) {
|
||||
var (
|
||||
resp *resty.Response
|
||||
apiError APIErrorResponse
|
||||
)
|
||||
|
||||
if resp, err = s.client.request(ctx).
|
||||
SetPathParam("job_id", s.jobID).
|
||||
SetResult(&res).
|
||||
SetError(&apiError).
|
||||
Post("fine_tuning/jobs/{job_id}/cancel"); err != nil {
|
||||
return
|
||||
}
|
||||
if resp.IsError() {
|
||||
err = apiError
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
3
llm/zhipu/zhipu/fine_tune_test.go
Normal file
3
llm/zhipu/zhipu/fine_tune_test.go
Normal file
@ -0,0 +1,3 @@
|
||||
package zhipu
|
||||
|
||||
// tests not available since lack of budget to test it
|
110
llm/zhipu/zhipu/image_generation.go
Normal file
110
llm/zhipu/zhipu/image_generation.go
Normal file
@ -0,0 +1,110 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
// ImageGenerationService creates a new image generation
|
||||
type ImageGenerationService struct {
|
||||
client *Client
|
||||
|
||||
model string
|
||||
prompt string
|
||||
size string
|
||||
userID string
|
||||
}
|
||||
|
||||
var (
|
||||
_ BatchSupport = &ImageGenerationService{}
|
||||
)
|
||||
|
||||
// ImageGenerationResponse is the response of the ImageGenerationService
|
||||
type ImageGenerationResponse struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []URLItem `json:"data"`
|
||||
}
|
||||
|
||||
// NewImageGenerationService creates a new ImageGenerationService
|
||||
func NewImageGenerationService(client *Client) *ImageGenerationService {
|
||||
return &ImageGenerationService{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ImageGenerationService) BatchMethod() string {
|
||||
return "POST"
|
||||
}
|
||||
|
||||
func (s *ImageGenerationService) BatchURL() string {
|
||||
return BatchEndpointV4ImagesGenerations
|
||||
}
|
||||
|
||||
func (s *ImageGenerationService) BatchBody() any {
|
||||
return s.buildBody()
|
||||
}
|
||||
|
||||
// SetModel sets the model parameter
|
||||
func (s *ImageGenerationService) SetModel(model string) *ImageGenerationService {
|
||||
s.model = model
|
||||
return s
|
||||
}
|
||||
|
||||
// SetPrompt sets the prompt parameter
|
||||
func (s *ImageGenerationService) SetPrompt(prompt string) *ImageGenerationService {
|
||||
s.prompt = prompt
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *ImageGenerationService) SetSize(size string) *ImageGenerationService {
|
||||
s.size = size
|
||||
return s
|
||||
}
|
||||
|
||||
// SetUserID sets the userID parameter
|
||||
func (s *ImageGenerationService) SetUserID(userID string) *ImageGenerationService {
|
||||
s.userID = userID
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *ImageGenerationService) buildBody() M {
|
||||
body := M{
|
||||
"model": s.model,
|
||||
"prompt": s.prompt,
|
||||
}
|
||||
|
||||
if s.userID != "" {
|
||||
body["user_id"] = s.userID
|
||||
}
|
||||
|
||||
if s.size != "" {
|
||||
body["size"] = s.size
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
func (s *ImageGenerationService) Do(ctx context.Context) (res ImageGenerationResponse, err error) {
|
||||
var (
|
||||
resp *resty.Response
|
||||
apiError APIErrorResponse
|
||||
)
|
||||
|
||||
body := s.buildBody()
|
||||
|
||||
if resp, err = s.client.request(ctx).
|
||||
SetBody(body).
|
||||
SetResult(&res).
|
||||
SetError(&apiError).
|
||||
Post("images/generations"); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.IsError() {
|
||||
err = apiError
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
21
llm/zhipu/zhipu/image_generation_test.go
Normal file
21
llm/zhipu/zhipu/image_generation_test.go
Normal file
@ -0,0 +1,21 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestImageGenerationService(t *testing.T) {
|
||||
client, err := NewClient()
|
||||
require.NoError(t, err)
|
||||
|
||||
s := client.ImageGeneration("cogview-3")
|
||||
s.SetPrompt("一只可爱的小猫")
|
||||
|
||||
res, err := s.Do(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, res.Data)
|
||||
t.Log(res.Data[0].URL)
|
||||
}
|
299
llm/zhipu/zhipu/knowledge.go
Normal file
299
llm/zhipu/zhipu/knowledge.go
Normal file
@ -0,0 +1,299 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
KnowledgeEmbeddingIDEmbedding2 = 3
|
||||
)
|
||||
|
||||
// KnowledgeCreateService creates a new knowledge
|
||||
type KnowledgeCreateService struct {
|
||||
client *Client
|
||||
|
||||
embeddingID int
|
||||
name string
|
||||
description *string
|
||||
}
|
||||
|
||||
// KnowledgeCreateResponse is the response of the KnowledgeCreateService
|
||||
type KnowledgeCreateResponse = IDItem
|
||||
|
||||
// NewKnowledgeCreateService creates a new KnowledgeCreateService
|
||||
func NewKnowledgeCreateService(client *Client) *KnowledgeCreateService {
|
||||
return &KnowledgeCreateService{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// SetEmbeddingID sets the embedding id of the knowledge
|
||||
func (s *KnowledgeCreateService) SetEmbeddingID(embeddingID int) *KnowledgeCreateService {
|
||||
s.embeddingID = embeddingID
|
||||
return s
|
||||
}
|
||||
|
||||
// SetName sets the name of the knowledge
|
||||
func (s *KnowledgeCreateService) SetName(name string) *KnowledgeCreateService {
|
||||
s.name = name
|
||||
return s
|
||||
}
|
||||
|
||||
// SetDescription sets the description of the knowledge
|
||||
func (s *KnowledgeCreateService) SetDescription(description string) *KnowledgeCreateService {
|
||||
s.description = &description
|
||||
return s
|
||||
}
|
||||
|
||||
// Do creates the knowledge
|
||||
func (s *KnowledgeCreateService) Do(ctx context.Context) (res KnowledgeCreateResponse, err error) {
|
||||
var (
|
||||
resp *resty.Response
|
||||
apiError APIErrorResponse
|
||||
)
|
||||
body := M{
|
||||
"name": s.name,
|
||||
"embedding_id": s.embeddingID,
|
||||
}
|
||||
if s.description != nil {
|
||||
body["description"] = *s.description
|
||||
}
|
||||
if resp, err = s.client.request(ctx).
|
||||
SetBody(body).
|
||||
SetResult(&res).
|
||||
SetError(&apiError).
|
||||
Post("knowledge"); err != nil {
|
||||
return
|
||||
}
|
||||
if resp.IsError() {
|
||||
err = apiError
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// KnowledgeEditService edits a knowledge
|
||||
type KnowledgeEditService struct {
|
||||
client *Client
|
||||
|
||||
knowledgeID string
|
||||
|
||||
embeddingID *int
|
||||
name *string
|
||||
description *string
|
||||
}
|
||||
|
||||
// NewKnowledgeEditService creates a new KnowledgeEditService
|
||||
func NewKnowledgeEditService(client *Client) *KnowledgeEditService {
|
||||
return &KnowledgeEditService{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// SetKnowledgeID sets the knowledge id
|
||||
func (s *KnowledgeEditService) SetKnowledgeID(knowledgeID string) *KnowledgeEditService {
|
||||
s.knowledgeID = knowledgeID
|
||||
return s
|
||||
}
|
||||
|
||||
// SetName sets the name of the knowledge
|
||||
func (s *KnowledgeEditService) SetName(name string) *KnowledgeEditService {
|
||||
s.name = &name
|
||||
return s
|
||||
}
|
||||
|
||||
// SetEmbeddingID sets the embedding id of the knowledge
|
||||
func (s *KnowledgeEditService) SetEmbeddingID(embeddingID int) *KnowledgeEditService {
|
||||
s.embeddingID = &embeddingID
|
||||
return s
|
||||
}
|
||||
|
||||
// SetDescription sets the description of the knowledge
|
||||
func (s *KnowledgeEditService) SetDescription(description string) *KnowledgeEditService {
|
||||
s.description = &description
|
||||
return s
|
||||
}
|
||||
|
||||
// Do edits the knowledge
|
||||
func (s *KnowledgeEditService) Do(ctx context.Context) (err error) {
|
||||
var (
|
||||
resp *resty.Response
|
||||
apiError APIErrorResponse
|
||||
)
|
||||
body := M{}
|
||||
if s.name != nil {
|
||||
body["name"] = *s.name
|
||||
}
|
||||
if s.description != nil {
|
||||
body["description"] = *s.description
|
||||
}
|
||||
if s.embeddingID != nil {
|
||||
body["embedding_id"] = *s.embeddingID
|
||||
}
|
||||
if resp, err = s.client.request(ctx).
|
||||
SetPathParam("knowledge_id", s.knowledgeID).
|
||||
SetBody(body).
|
||||
SetError(&apiError).
|
||||
Put("knowledge/{knowledge_id}"); err != nil {
|
||||
return
|
||||
}
|
||||
if resp.IsError() {
|
||||
err = apiError
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// KnowledgeListService lists the knowledge
|
||||
type KnowledgeListService struct {
|
||||
client *Client
|
||||
|
||||
page *int
|
||||
size *int
|
||||
}
|
||||
|
||||
// KnowledgeItem is an item in the knowledge list
|
||||
type KnowledgeItem struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Icon string `json:"icon"`
|
||||
Background string `json:"background"`
|
||||
EmbeddingID int `json:"embedding_id"`
|
||||
CustomIdentifier string `json:"custom_identifier"`
|
||||
WordNum int64 `json:"word_num"`
|
||||
Length int64 `json:"length"`
|
||||
DocumentSize int64 `json:"document_size"`
|
||||
}
|
||||
|
||||
// KnowledgeListResponse is the response of the KnowledgeListService
|
||||
type KnowledgeListResponse struct {
|
||||
List []KnowledgeItem `json:"list"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
|
||||
// NewKnowledgeListService creates a new KnowledgeListService
|
||||
func NewKnowledgeListService(client *Client) *KnowledgeListService {
|
||||
return &KnowledgeListService{client: client}
|
||||
}
|
||||
|
||||
// SetPage sets the page of the knowledge list
|
||||
func (s *KnowledgeListService) SetPage(page int) *KnowledgeListService {
|
||||
s.page = &page
|
||||
return s
|
||||
}
|
||||
|
||||
// SetSize sets the size of the knowledge list
|
||||
func (s *KnowledgeListService) SetSize(size int) *KnowledgeListService {
|
||||
s.size = &size
|
||||
return s
|
||||
}
|
||||
|
||||
// Do lists the knowledge
|
||||
func (s *KnowledgeListService) Do(ctx context.Context) (res KnowledgeListResponse, err error) {
|
||||
var (
|
||||
resp *resty.Response
|
||||
apiError APIErrorResponse
|
||||
)
|
||||
req := s.client.request(ctx)
|
||||
if s.page != nil {
|
||||
req.SetQueryParam("page", strconv.Itoa(*s.page))
|
||||
}
|
||||
if s.size != nil {
|
||||
req.SetQueryParam("size", strconv.Itoa(*s.size))
|
||||
}
|
||||
if resp, err = req.
|
||||
SetResult(&res).
|
||||
SetError(&apiError).
|
||||
Get("knowledge"); err != nil {
|
||||
return
|
||||
}
|
||||
if resp.IsError() {
|
||||
err = apiError
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// KnowledgeDeleteService deletes a knowledge
|
||||
type KnowledgeDeleteService struct {
|
||||
client *Client
|
||||
|
||||
knowledgeID string
|
||||
}
|
||||
|
||||
// NewKnowledgeDeleteService creates a new KnowledgeDeleteService
|
||||
func NewKnowledgeDeleteService(client *Client) *KnowledgeDeleteService {
|
||||
return &KnowledgeDeleteService{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// SetKnowledgeID sets the knowledge id
|
||||
func (s *KnowledgeDeleteService) SetKnowledgeID(knowledgeID string) *KnowledgeDeleteService {
|
||||
s.knowledgeID = knowledgeID
|
||||
return s
|
||||
}
|
||||
|
||||
// Do deletes the knowledge
|
||||
func (s *KnowledgeDeleteService) Do(ctx context.Context) (err error) {
|
||||
var (
|
||||
resp *resty.Response
|
||||
apiError APIErrorResponse
|
||||
)
|
||||
if resp, err = s.client.request(ctx).
|
||||
SetPathParam("knowledge_id", s.knowledgeID).
|
||||
SetError(&apiError).
|
||||
Delete("knowledge/{knowledge_id}"); err != nil {
|
||||
return
|
||||
}
|
||||
if resp.IsError() {
|
||||
err = apiError
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// KnowledgeCapacityService query the capacity of the knowledge
|
||||
type KnowledgeCapacityService struct {
|
||||
client *Client
|
||||
}
|
||||
|
||||
// KnowledgeCapacityItem is an item in the knowledge capacity
|
||||
type KnowledgeCapacityItem struct {
|
||||
WordNum int64 `json:"word_num"`
|
||||
Length int64 `json:"length"`
|
||||
}
|
||||
|
||||
// KnowledgeCapacityResponse is the response of the KnowledgeCapacityService
|
||||
type KnowledgeCapacityResponse struct {
|
||||
Used KnowledgeCapacityItem `json:"used"`
|
||||
Total KnowledgeCapacityItem `json:"total"`
|
||||
}
|
||||
|
||||
// SetKnowledgeID sets the knowledge id
|
||||
func NewKnowledgeCapacityService(client *Client) *KnowledgeCapacityService {
|
||||
return &KnowledgeCapacityService{client: client}
|
||||
}
|
||||
|
||||
// Do query the capacity of the knowledge
|
||||
func (s *KnowledgeCapacityService) Do(ctx context.Context) (res KnowledgeCapacityResponse, err error) {
|
||||
var (
|
||||
resp *resty.Response
|
||||
apiError APIErrorResponse
|
||||
)
|
||||
if resp, err = s.client.request(ctx).
|
||||
SetResult(&res).
|
||||
SetError(&apiError).
|
||||
Get("knowledge/capacity"); err != nil {
|
||||
return
|
||||
}
|
||||
if resp.IsError() {
|
||||
err = apiError
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
50
llm/zhipu/zhipu/knowledge_test.go
Normal file
50
llm/zhipu/zhipu/knowledge_test.go
Normal file
@ -0,0 +1,50 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestKnowledgeCapacity(t *testing.T) {
|
||||
client, err := NewClient()
|
||||
require.NoError(t, err)
|
||||
|
||||
s := client.KnowledgeCapacity()
|
||||
res, err := s.Do(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, res.Total.Length)
|
||||
require.NotEmpty(t, res.Total.WordNum)
|
||||
}
|
||||
|
||||
func TestKnowledgeServiceAll(t *testing.T) {
|
||||
client, err := NewClient()
|
||||
require.NoError(t, err)
|
||||
|
||||
s := client.KnowledgeCreate()
|
||||
s.SetName("test")
|
||||
s.SetDescription("test description")
|
||||
s.SetEmbeddingID(KnowledgeEmbeddingIDEmbedding2)
|
||||
|
||||
res, err := s.Do(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, res.ID)
|
||||
|
||||
s2 := client.KnowledgeList()
|
||||
res2, err := s2.Do(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, res2.List)
|
||||
require.Equal(t, res.ID, res2.List[0].ID)
|
||||
|
||||
s3 := client.KnowledgeEdit(res.ID)
|
||||
s3.SetDescription("test description 2")
|
||||
s3.SetName("test 2")
|
||||
s3.SetEmbeddingID(KnowledgeEmbeddingIDEmbedding2)
|
||||
err = s3.Do(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
s4 := client.KnowledgeDelete(res.ID)
|
||||
err = s4.Do(context.Background())
|
||||
require.NoError(t, err)
|
||||
}
|
54
llm/zhipu/zhipu/string_or.go
Normal file
54
llm/zhipu/zhipu/string_or.go
Normal file
@ -0,0 +1,54 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// StringOr is a struct that can be either a string or a value of type T.
|
||||
type StringOr[T any] struct {
|
||||
String *string
|
||||
Value *T
|
||||
}
|
||||
|
||||
var (
|
||||
_ json.Marshaler = StringOr[float64]{}
|
||||
_ json.Unmarshaler = &StringOr[float64]{}
|
||||
)
|
||||
|
||||
// SetString sets the string value of the struct.
|
||||
func (f *StringOr[T]) SetString(v string) {
|
||||
f.String = &v
|
||||
f.Value = nil
|
||||
}
|
||||
|
||||
// SetValue sets the value of the struct.
|
||||
func (f *StringOr[T]) SetValue(v T) {
|
||||
f.String = nil
|
||||
f.Value = &v
|
||||
}
|
||||
|
||||
func (f StringOr[T]) MarshalJSON() ([]byte, error) {
|
||||
if f.Value != nil {
|
||||
return json.Marshal(f.Value)
|
||||
}
|
||||
return json.Marshal(f.String)
|
||||
}
|
||||
|
||||
func (f *StringOr[T]) UnmarshalJSON(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
if bytes.Equal(data, []byte("null")) {
|
||||
return nil
|
||||
}
|
||||
if data[0] == '"' {
|
||||
f.String = new(string)
|
||||
f.Value = nil
|
||||
return json.Unmarshal(data, f.String)
|
||||
} else {
|
||||
f.Value = new(T)
|
||||
f.String = nil
|
||||
return json.Unmarshal(data, f.Value)
|
||||
}
|
||||
}
|
37
llm/zhipu/zhipu/string_or_test.go
Normal file
37
llm/zhipu/zhipu/string_or_test.go
Normal file
@ -0,0 +1,37 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStringOr(t *testing.T) {
|
||||
data := struct {
|
||||
Item *StringOr[int] `json:"item,omitempty"`
|
||||
}{}
|
||||
data.Item = &StringOr[int]{}
|
||||
data.Item.SetString("test")
|
||||
|
||||
b, err := json.Marshal(data)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, `{"item":"test"}`, string(b))
|
||||
|
||||
data.Item.SetValue(1)
|
||||
b, err = json.Marshal(data)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, `{"item":1}`, string(b))
|
||||
|
||||
err = json.Unmarshal([]byte(`{"item":"test2"}`), &data)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, data.Item.String)
|
||||
require.Nil(t, data.Item.Value)
|
||||
require.Equal(t, "test2", *data.Item.String)
|
||||
|
||||
err = json.Unmarshal([]byte(`{"item":2}`), &data)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, data.Item.String)
|
||||
require.NotNil(t, data.Item.Value)
|
||||
require.Equal(t, 2, *data.Item.Value)
|
||||
}
|
3
llm/zhipu/zhipu/testdata/test-file.jsonl
vendored
Normal file
3
llm/zhipu/zhipu/testdata/test-file.jsonl
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
{"messages": [{"role": "system", "content": "你是一位乐于助人,知识渊博的全能AI助手。"},{"role": "user", "content": "请介绍一下AlphaGo的基本原理。"}, {"role": "assistant", "content": "AlphaGo利用深度神经网络与树搜索算法相结合,通过对弈数据进行训练,逐步提升游戏水平,能够对围棋局面进行判断并选择最佳落子位置。"}]}
|
||||
{"messages": [{"role": "system", "content": "你是一位乐于助人,知识渊博的全能AI助手。"},{"role": "user", "content": "请列举一些健康饮食的重要原则。"}, {"role": "assistant", "content": "健康饮食的重要原则包括:膳食应营养均衡,合理摄入各种营养素;多吃新鲜蔬果,适量饮水;少吃糖、油、盐,忌吃过量快餐;养成规律生活作息习惯,适量运动。"}]}
|
||||
{"messages": [{"role": "system", "content": "你是一位乐于助人,知识渊博的全能AI助手。"},{"role": "user", "content": "请给出一些改善睡眠质量的建议。"}, {"role": "assistant", "content": "改善睡眠质量的建议包括:保证睡眠时间,避免睡眠不足;睡前放松身心,远离刺激;养成规律作息时间,不要频繁熬夜;适量运动,但不要运动过度;睡前可以喝一杯热牛奶等温和饮料。"}]}
|
1
llm/zhipu/zhipu/testdata/test-file.txt
vendored
Normal file
1
llm/zhipu/zhipu/testdata/test-file.txt
vendored
Normal file
@ -0,0 +1 @@
|
||||
The quick brown fox jumps over the lazy dog
|
22
llm/zhipu/zhipu/util.go
Normal file
22
llm/zhipu/zhipu/util.go
Normal file
@ -0,0 +1,22 @@
|
||||
package zhipu
|
||||
|
||||
// URLItem is a struct that contains a URL.
|
||||
type URLItem struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
}
|
||||
|
||||
// IDItem is a struct that contains an ID.
|
||||
type IDItem struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
}
|
||||
|
||||
// Ptr returns a pointer to the value passed in.
|
||||
// Example:
|
||||
//
|
||||
// web_search_enable = zhipu.Ptr(false)
|
||||
func Ptr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
// M is a shorthand for map[string]any.
|
||||
type M = map[string]any
|
3
llm/zhipu/zhipu/util_test.go
Normal file
3
llm/zhipu/zhipu/util_test.go
Normal file
@ -0,0 +1,3 @@
|
||||
package zhipu
|
||||
|
||||
// nothing to test
|
125
llm/zhipu/zhipu/video_generation.go
Normal file
125
llm/zhipu/zhipu/video_generation.go
Normal file
@ -0,0 +1,125 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
VideoGenerationTaskStatusProcessing = "PROCESSING"
|
||||
VideoGenerationTaskStatusSuccess = "SUCCESS"
|
||||
VideoGenerationTaskStatusFail = "FAIL"
|
||||
)
|
||||
|
||||
// VideoGenerationService creates a new video generation
|
||||
type VideoGenerationService struct {
|
||||
client *Client
|
||||
|
||||
model string
|
||||
prompt string
|
||||
userID string
|
||||
imageURL string
|
||||
requestID string
|
||||
}
|
||||
|
||||
var (
|
||||
_ BatchSupport = &VideoGenerationService{}
|
||||
)
|
||||
|
||||
// VideoGenerationResponse is the response of the VideoGenerationService
|
||||
type VideoGenerationResponse struct {
|
||||
RequestID string `json:"request_id"`
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
TaskStatus string `json:"task_status"`
|
||||
}
|
||||
|
||||
func NewVideoGenerationService(client *Client) *VideoGenerationService {
|
||||
return &VideoGenerationService{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *VideoGenerationService) BatchMethod() string {
|
||||
return "POST"
|
||||
}
|
||||
|
||||
func (s *VideoGenerationService) BatchURL() string {
|
||||
return BatchEndpointV4VideosGenerations
|
||||
}
|
||||
|
||||
func (s *VideoGenerationService) BatchBody() any {
|
||||
return s.buildBody()
|
||||
}
|
||||
|
||||
// SetModel sets the model parameter
|
||||
func (s *VideoGenerationService) SetModel(model string) *VideoGenerationService {
|
||||
s.model = model
|
||||
return s
|
||||
}
|
||||
|
||||
// SetPrompt sets the prompt parameter
|
||||
func (s *VideoGenerationService) SetPrompt(prompt string) *VideoGenerationService {
|
||||
s.prompt = prompt
|
||||
return s
|
||||
}
|
||||
|
||||
// SetUserID sets the userID parameter
|
||||
func (s *VideoGenerationService) SetUserID(userID string) *VideoGenerationService {
|
||||
s.userID = userID
|
||||
return s
|
||||
}
|
||||
|
||||
// SetImageURL sets the imageURL parameter
|
||||
func (s *VideoGenerationService) SetImageURL(imageURL string) *VideoGenerationService {
|
||||
s.imageURL = imageURL
|
||||
return s
|
||||
}
|
||||
|
||||
// SetRequestID sets the requestID parameter
|
||||
func (s *VideoGenerationService) SetRequestID(requestID string) *VideoGenerationService {
|
||||
s.requestID = requestID
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *VideoGenerationService) buildBody() M {
|
||||
body := M{
|
||||
"model": s.model,
|
||||
"prompt": s.prompt,
|
||||
}
|
||||
if s.userID != "" {
|
||||
body["user_id"] = s.userID
|
||||
}
|
||||
if s.imageURL != "" {
|
||||
body["image_url"] = s.imageURL
|
||||
}
|
||||
if s.requestID != "" {
|
||||
body["request_id"] = s.requestID
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func (s *VideoGenerationService) Do(ctx context.Context) (res VideoGenerationResponse, err error) {
|
||||
var (
|
||||
resp *resty.Response
|
||||
apiError APIErrorResponse
|
||||
)
|
||||
|
||||
body := s.buildBody()
|
||||
|
||||
if resp, err = s.client.request(ctx).
|
||||
SetBody(body).
|
||||
SetResult(&res).
|
||||
SetError(&apiError).
|
||||
Post("videos/generations"); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.IsError() {
|
||||
err = apiError
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
38
llm/zhipu/zhipu/video_generation_test.go
Normal file
38
llm/zhipu/zhipu/video_generation_test.go
Normal file
@ -0,0 +1,38 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestVideoGeneration(t *testing.T) {
|
||||
client, err := NewClient()
|
||||
require.NoError(t, err)
|
||||
|
||||
s := client.VideoGeneration("cogvideox")
|
||||
s.SetPrompt("一只可爱的小猫")
|
||||
|
||||
res, err := s.Do(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, res.TaskStatus)
|
||||
require.NotEmpty(t, res.ID)
|
||||
t.Log(res.ID)
|
||||
|
||||
for {
|
||||
res, err := client.AsyncResult(res.ID).Do(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, res.TaskStatus)
|
||||
if res.TaskStatus == VideoGenerationTaskStatusSuccess {
|
||||
require.NotEmpty(t, res.VideoResult)
|
||||
t.Log(res.VideoResult[0].URL)
|
||||
t.Log(res.VideoResult[0].CoverImageURL)
|
||||
}
|
||||
if res.TaskStatus != VideoGenerationTaskStatusProcessing {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Second * 5)
|
||||
}
|
||||
}
|
BIN
llm/zhipu/zhipu/wechat-donation.png
Normal file
BIN
llm/zhipu/zhipu/wechat-donation.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 46 KiB |
9
tests/ask.js
Normal file
9
tests/ask.js
Normal file
@ -0,0 +1,9 @@
|
||||
import {glm, gpt} from './lib/ai'
|
||||
import console from './lib/console'
|
||||
import chat from './mod/chat'
|
||||
|
||||
function main(...args) {
|
||||
let ask = args[0]
|
||||
if (!ask) ask = console.input('请输入问题:')
|
||||
return chat.fast(ask)
|
||||
}
|
@ -2,51 +2,48 @@ package tests
|
||||
|
||||
import (
|
||||
"apigo.cc/ai/ai"
|
||||
_ "apigo.cc/ai/openai"
|
||||
_ "apigo.cc/ai/zhipu"
|
||||
"apigo.cc/ai/ai/js"
|
||||
"apigo.cc/ai/ai/llm"
|
||||
"fmt"
|
||||
"github.com/ssgo/config"
|
||||
"github.com/ssgo/u"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var agentConf = map[string]ai.AgentConfig{}
|
||||
|
||||
func init() {
|
||||
_ = config.LoadConfig("agent", &agentConf)
|
||||
}
|
||||
|
||||
func TestAgent(t *testing.T) {
|
||||
//testChat(t, "openai", "https://api.keya.pw/v1", keys.Openai)
|
||||
ai.Init()
|
||||
//testChat(t, "gpt")
|
||||
// 尚未支持
|
||||
//testCode(t, "openai", "https://api.keya.pw/v1", keys.Openai)
|
||||
//testCode(t, "gpt")
|
||||
// keya 消耗过大,计费不合理
|
||||
//testAskWithImage(t, "openai", "https://api.keya.pw/v1", keys.Openai, "4032.jpg")
|
||||
//testAskWithImage(t, "gpt", "4032.jpg")
|
||||
// keya 不支持
|
||||
//testMakeImage(t, "openai", "https://api.keya.pw/v1", keys.Openai, "冬天大雪纷飞,一个男人身穿军绿色棉大衣,戴着红色围巾和绿色帽子走在铺面大雪的小镇路上", "")
|
||||
//testMakeImage(t, "gpt", "冬天大雪纷飞,一个男人身穿军绿色棉大衣,戴着红色围巾和绿色帽子走在铺面大雪的小镇路上", "")
|
||||
|
||||
testChat(t, "zhipu")
|
||||
//testCode(t, "zhipu", "", keys.Zhipu)
|
||||
//testSearch(t, "zhipu", "", keys.Zhipu)
|
||||
//testChat(t, "glm")
|
||||
//testCode(t, "glm", "", keys.Zhipu)
|
||||
//testSearch(t, "glm", "", keys.Zhipu)
|
||||
|
||||
// 测试图片识别
|
||||
//testAskWithImage(t, "zhipu", "", keys.Zhipu, "4032.jpg")
|
||||
//testAskWithImage(t, "glm", "", keys.Zhipu, "4032.jpg")
|
||||
|
||||
// 视频似乎尚不支持 glm-4v-plus
|
||||
//testAskWithVideo(t, "zhipu", "", keys.Zhipu, "glm-4v", "1080.mp4")
|
||||
//testAskWithVideo(t, "glm", "", keys.Zhipu, "glm-4v", "1080.mp4")
|
||||
|
||||
//testMakeImage(t, "zhipu", "", keys.Zhipu, "冬天大雪纷飞,一个男人身穿军绿色棉大衣,戴着红色围巾和绿色帽子走在铺面大雪的小镇路上", "")
|
||||
//testMakeVideo(t, "zhipu", "", keys.Zhipu, "大雪纷飞,男人蹦蹦跳跳", "https://aigc-files.bigmodel.cn/api/cogview/20240904133130c4b7121019724aa3_0.png")
|
||||
//testMakeImage(t, "glm", "", keys.Zhipu, "冬天大雪纷飞,一个男人身穿军绿色棉大衣,戴着红色围巾和绿色帽子走在铺面大雪的小镇路上", "")
|
||||
//testMakeVideo(t, "glm", "", keys.Zhipu, "大雪纷飞,男人蹦蹦跳跳", "https://aigc-files.bigmodel.cn/api/cogview/20240904133130c4b7121019724aa3_0.png")
|
||||
|
||||
testJS(t)
|
||||
//testFile(t)
|
||||
}
|
||||
|
||||
func testChat(t *testing.T, agent string) {
|
||||
ag := ai.GetAgent(agent)
|
||||
func testChat(t *testing.T, llmName string) {
|
||||
lm := llm.Get(llmName)
|
||||
|
||||
if ag == nil {
|
||||
if lm == nil {
|
||||
t.Fatal("agent is nil")
|
||||
}
|
||||
|
||||
r, usage, err := ag.FastAsk(ai.Messages().User().Text("你是什么模型,请给出具体名称、版本号").Make(), func(text string) {
|
||||
r, usage, err := lm.FastAsk(llm.Messages().User().Text("你是什么模型,请给出具体名称、版本号").Make(), func(text string) {
|
||||
fmt.Print(u.BCyan(text))
|
||||
fmt.Print(" ")
|
||||
})
|
||||
@ -60,14 +57,14 @@ func testChat(t *testing.T, agent string) {
|
||||
fmt.Println("usage:", u.JsonP(usage))
|
||||
}
|
||||
|
||||
func testCode(t *testing.T, agent string) {
|
||||
ag := ai.GetAgent(agent)
|
||||
func testCode(t *testing.T, llmName string) {
|
||||
lm := llm.Get(llmName)
|
||||
|
||||
if ag == nil {
|
||||
if lm == nil {
|
||||
t.Fatal("agent is nil")
|
||||
}
|
||||
|
||||
r, usage, err := ag.CodeInterpreterAsk(ai.Messages().User().Text("计算[5,10,20,700,99,310,978,100]的平均值和方差。").Make(), func(text string) {
|
||||
r, usage, err := lm.CodeInterpreterAsk(llm.Messages().User().Text("计算[5,10,20,700,99,310,978,100]的平均值和方差。").Make(), func(text string) {
|
||||
fmt.Print(u.BCyan(text))
|
||||
fmt.Print(" ")
|
||||
})
|
||||
@ -81,14 +78,14 @@ func testCode(t *testing.T, agent string) {
|
||||
fmt.Println("usage:", u.JsonP(usage))
|
||||
}
|
||||
|
||||
func testSearch(t *testing.T, agent string) {
|
||||
ag := ai.GetAgent(agent)
|
||||
func testSearch(t *testing.T, llmName string) {
|
||||
lm := llm.Get(llmName)
|
||||
|
||||
if ag == nil {
|
||||
if lm == nil {
|
||||
t.Fatal("agent is nil")
|
||||
}
|
||||
|
||||
r, usage, err := ag.WebSearchAsk(ai.Messages().User().Text("今天上海的天气怎么样?").Make(), func(text string) {
|
||||
r, usage, err := lm.WebSearchAsk(llm.Messages().User().Text("今天上海的天气怎么样?").Make(), func(text string) {
|
||||
fmt.Print(u.BCyan(text))
|
||||
fmt.Print(" ")
|
||||
})
|
||||
@ -102,10 +99,10 @@ func testSearch(t *testing.T, agent string) {
|
||||
fmt.Println("usage:", u.JsonP(usage))
|
||||
}
|
||||
|
||||
func testAskWithImage(t *testing.T, agent, imageFile string) {
|
||||
ag := ai.GetAgent(agent)
|
||||
func testAskWithImage(t *testing.T, llmName, imageFile string) {
|
||||
lm := llm.Get(llmName)
|
||||
|
||||
if ag == nil {
|
||||
if lm == nil {
|
||||
t.Fatal("agent is nil")
|
||||
}
|
||||
|
||||
@ -115,7 +112,7 @@ func testAskWithImage(t *testing.T, agent, imageFile string) {
|
||||
3、正在用什么软件播放什么歌?谁演唱的?歌曲的大意是?
|
||||
4、后面的浏览器中正在浏览什么内容?猜测一下我浏览这个网页是想干嘛?
|
||||
`
|
||||
r, usage, err := ag.MultiAsk(ai.Messages().User().Text(ask).Image("data:image/jpeg;base64,"+u.Base64(u.ReadFileBytesN(imageFile))).Make(), func(text string) {
|
||||
r, usage, err := lm.MultiAsk(llm.Messages().User().Text(ask).Image("data:image/jpeg;base64,"+u.Base64(u.ReadFileBytesN(imageFile))).Make(), func(text string) {
|
||||
fmt.Print(u.BCyan(text))
|
||||
fmt.Print(" ")
|
||||
})
|
||||
@ -129,10 +126,10 @@ func testAskWithImage(t *testing.T, agent, imageFile string) {
|
||||
fmt.Println("usage:", u.JsonP(usage))
|
||||
}
|
||||
|
||||
func testAskWithVideo(t *testing.T, agent, videoFile string) {
|
||||
ag := ai.GetAgent(agent)
|
||||
func testAskWithVideo(t *testing.T, llmName, videoFile string) {
|
||||
lm := llm.Get(llmName)
|
||||
|
||||
if ag == nil {
|
||||
if lm == nil {
|
||||
t.Fatal("agent is nil")
|
||||
}
|
||||
|
||||
@ -141,7 +138,7 @@ func testAskWithVideo(t *testing.T, agent, videoFile string) {
|
||||
4、后面的浏览器中正在浏览什么内容?猜测一下我浏览这个网页是想干嘛?
|
||||
`
|
||||
|
||||
r, usage, err := ag.MultiAsk(ai.Messages().User().Text(ask).Video("data:video/mp4,"+u.Base64(u.ReadFileBytesN(videoFile))).Make(), func(text string) {
|
||||
r, usage, err := lm.MultiAsk(llm.Messages().User().Text(ask).Video("data:video/mp4,"+u.Base64(u.ReadFileBytesN(videoFile))).Make(), func(text string) {
|
||||
fmt.Print(u.BCyan(text))
|
||||
fmt.Print(" ")
|
||||
})
|
||||
@ -155,14 +152,16 @@ func testAskWithVideo(t *testing.T, agent, videoFile string) {
|
||||
fmt.Println("usage:", u.JsonP(usage))
|
||||
}
|
||||
|
||||
func testMakeImage(t *testing.T, agent, prompt, refImage string) {
|
||||
ag := ai.GetAgent(agent)
|
||||
func testMakeImage(t *testing.T, llmName, prompt, refImage string) {
|
||||
lm := llm.Get(llmName)
|
||||
|
||||
if ag == nil {
|
||||
if lm == nil {
|
||||
t.Fatal("agent is nil")
|
||||
}
|
||||
|
||||
r, err := ag.FastMakeImage(prompt, "1024x1024", "")
|
||||
r, err := lm.FastMakeImage(prompt, llm.GCConfig{
|
||||
Size: "1024x1024",
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("发生错误", err.Error())
|
||||
@ -174,14 +173,17 @@ func testMakeImage(t *testing.T, agent, prompt, refImage string) {
|
||||
}
|
||||
}
|
||||
|
||||
func testMakeVideo(t *testing.T, agent, prompt, refImage string) {
|
||||
ag := ai.GetAgent(agent)
|
||||
func testMakeVideo(t *testing.T, llmName, prompt, refImage string) {
|
||||
lm := llm.Get(llmName)
|
||||
|
||||
if ag == nil {
|
||||
if lm == nil {
|
||||
t.Fatal("agent is nil")
|
||||
}
|
||||
|
||||
r, covers, err := ag.FastMakeVideo(prompt, "1280x720", refImage)
|
||||
r, covers, err := lm.FastMakeVideo(prompt, llm.GCConfig{
|
||||
Size: "1280x720",
|
||||
Ref: refImage,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("发生错误", err.Error())
|
||||
@ -192,3 +194,31 @@ func testMakeVideo(t *testing.T, agent, prompt, refImage string) {
|
||||
fmt.Println("result:", i, v, covers[i])
|
||||
}
|
||||
}
|
||||
|
||||
func testJS(t *testing.T) {
|
||||
r1, err := js.RunFile("test.js", "1+2=4吗")
|
||||
if err != nil {
|
||||
t.Fatal("发生错误", err.Error())
|
||||
}
|
||||
r := js.ChatResult{}
|
||||
u.Convert(r1, &r)
|
||||
|
||||
fmt.Println()
|
||||
fmt.Println("result:", r.Result)
|
||||
fmt.Println("usage:", u.JsonP(r.TokenUsage))
|
||||
}
|
||||
|
||||
func testFile(t *testing.T) {
|
||||
r1, err := js.Run(`
|
||||
import fs from './lib/file'
|
||||
import out from './lib/console'
|
||||
let r = fs.read('test.js')
|
||||
return r
|
||||
`, "test.js")
|
||||
if err != nil {
|
||||
t.Fatal("发生错误", err.Error())
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Println("result:", r1)
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
agent:
|
||||
openai:
|
||||
llm:
|
||||
gpt1:
|
||||
apiKey: ...
|
||||
# endpoint: ...
|
||||
# endpoint: https://.... # api base url
|
||||
# llm: openai # registered llm name, if not specified, it will use the key name as llm name
|
||||
zhipu:
|
||||
apiKey: ...
|
||||
|
24
tests/go.mod
24
tests/go.mod
@ -1,24 +0,0 @@
|
||||
module tests
|
||||
|
||||
go 1.22
|
||||
|
||||
require (
|
||||
apigo.cc/ai/ai v0.0.0
|
||||
apigo.cc/ai/openai v0.0.1
|
||||
apigo.cc/ai/zhipu v0.0.1
|
||||
github.com/ssgo/config v1.7.7
|
||||
github.com/ssgo/u v1.7.7
|
||||
)
|
||||
|
||||
require (
|
||||
apigo.cc/ai/agent v0.0.1 // indirect
|
||||
github.com/go-resty/resty/v2 v2.14.0 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1 // indirect
|
||||
github.com/sashabaranov/go-openai v1.29.1 // indirect
|
||||
github.com/ssgo/log v1.7.7 // indirect
|
||||
github.com/ssgo/standard v1.7.7 // indirect
|
||||
golang.org/x/net v0.29.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
replace apigo.cc/ai/ai v0.0.0 => ../
|
14
tests/mod/chat.js
Normal file
14
tests/mod/chat.js
Normal file
@ -0,0 +1,14 @@
|
||||
import {glm} from '../lib/ai'
|
||||
import console from '../lib/console'
|
||||
|
||||
function fast(...args) {
|
||||
if(!args[0]) throw new Error('no ask')
|
||||
let r = glm.fastAsk(args[0], r => {
|
||||
console.print(r)
|
||||
})
|
||||
console.println()
|
||||
return r
|
||||
}
|
||||
|
||||
// exports.fast = fast
|
||||
module.exports = {fast}
|
6
tests/test.js
Normal file
6
tests/test.js
Normal file
@ -0,0 +1,6 @@
|
||||
import {glm, gpt} from './lib/ai'
|
||||
import chat from './mod/chat'
|
||||
|
||||
function main(...args) {
|
||||
return chat.fast(args[0])
|
||||
}
|
Loading…
Reference in New Issue
Block a user