diff --git a/.gitignore b/.gitignore index 53669e9..b313d61 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .* go.sum env.yml +/tests/lib diff --git a/README.md b/README.md new file mode 100644 index 0000000..d2d4249 --- /dev/null +++ b/README.md @@ -0,0 +1,204 @@ +# 低代码AI能力工具包 + +## 命令行工具 + +### Install + +```shell +go install apigo.cc/ai/ai/ai@latest +``` + +### Usage + +```shell +ai -h | --help show usage +ai -e | --export export ai.ts file to ./lib for develop +ai test | test.js run test.js, if not specified, run ./ai.js +``` + +### Sample + +#### test.js + +```javascript +import {glm} from './lib/ai' +import console from './lib/console' + +function main(...args) { + if(!args[0]) throw new Error('no ask') + let r = glm.fastAsk(args[0], r => { + console.print(r) + }) + console.println() + return r +} +``` + +#### run sample + +```shell +ai test "你好" +``` + +### Module + +#### mod/glm.js + +```javascript +import {glm} from './lib/ai' +import console from './lib/console' + +function chat(...args) { + if(!args[0]) throw new Error('no prompt') + return glm.fastAsk(args[0]) +} + +function draw(...args) { + if(!args[0]) throw new Error('no ask') + return glm.makeImage(args[0], {size:'1024x1024'}) +} + +module.exports = {chat, draw} +``` + +#### test.js + +```javascript +import glm from './mod/glm' + +function main(...args) { + console.println(glm.chat(args[0])) +} +``` + +#### run sample + +```shell +ai test "你好" +``` + +### Configure + +#### llm.yml + +```yaml +openai: + apiKey: ... +zhipu: + apiKey: ... +``` + +#### or use env.yml + +#### llm.yml + +```yaml +llm: + openai: + apiKey: ... + zhipu: + apiKey: ... +``` + +#### encrypt apiKey + +install sskey + +```shell +go install github.com/ssgo/tool/sskey@latest +sskey -e 'your apiKey' +``` + +copy url base64 format encrypted apiKey into llm.yml or env.yml + +#### config with special endpoint + +```yaml +llm: + openai: + apiKey: ... + endpoint: https://api.openai.com/v1 +``` + +#### config multi api + +```yaml +llm: + glm: + apiKey: ... + llm: zhipu + glm2: + apiKey: ... + endpoint: https://...... + llm: zhipu +``` + + +## 调用 JavaScript API + +### Install + +```shell +go get apigo.cc/ai/ai +``` + +### Usage + +```go +package main + +import ( + "apigo.cc/ai/ai/js" + "fmt" +) + +func main() { + result, err := js.Run(`return ai.glm.fastAsk(args[0])`, "", "你好") + // js.RunFile + // js.StartFromFile + // js.StartFromCode + if err != nil { + fmt.Println(err.Error()) + } else if result != nil { + fmt.Println(result) + } +} +``` + + +## 调用 Go API + +### Install + +```shell +go get apigo.cc/ai/ai +``` + +### Usage + +```go +package main + +import ( + "apigo.cc/ai/ai" + "apigo.cc/ai/ai/llm" + "fmt" +) + +func main() { + ai.Init() + glm := llm.Get("zhipu") + + r, usage, err := glm.FastAsk(llm.Messages().User().Text("你是什么模型").Make(), func(text string) { + fmt.Print(text) + }) + + if err != nil { + fmt.Println(err) + } else { + fmt.Println() + fmt.Println("result:", r) + fmt.Println("usage:", usage) + } +} +``` diff --git a/agent.go b/agent.go deleted file mode 100644 index 7667ceb..0000000 --- a/agent.go +++ /dev/null @@ -1,89 +0,0 @@ -package ai - -import ( - "apigo.cc/ai/agent" - "github.com/ssgo/config" - "github.com/ssgo/u" - "sync" -) - -var confAes = u.NewAes([]byte("?GQ$0K0GgLdO=f+~L68PLm$uhKr4'=tV"), []byte("VFs7@sK61cj^f?HZ")) -var keysIsSet = false - -func SetSSKey(key, iv []byte) { - if !keysIsSet { - confAes = u.NewAes(key, iv) - keysIsSet = true - } -} - -const ( - TypeText = agent.TypeText - TypeImage = agent.TypeImage - TypeVideo = agent.TypeVideo - RoleSystem = agent.RoleSystem - RoleUser = agent.RoleUser - RoleAssistant = agent.RoleAssistant - RoleTool = agent.RoleTool - ToolCodeInterpreter = agent.ToolCodeInterpreter - ToolWebSearch = agent.ToolWebSearch -) - -type APIConfig struct { - Endpoint string - ApiKey string - DefaultChatModelConfig ChatModelConfig -} - -type AgentConfig struct { - ApiKey string - Endpoint string - Agent string - ChatConfig ChatModelConfig -} - -var agentConfigs map[string]*AgentConfig -var agentConfigsLock = sync.RWMutex{} - -func GetAgent(name string) agent.Agent { - ag := agent.GetAgent(name) - if ag != nil { - return ag - } - - var agConf *AgentConfig - if agentConfigs == nil { - agConfs := make(map[string]*AgentConfig) - config.LoadConfig("agent", &agConfs) - agConf = agConfs[name] - if agConf != nil { - agentConfigsLock.Lock() - agentConfigs = agConfs - agentConfigsLock.Unlock() - } - } else { - agentConfigsLock.RLock() - agConf = agentConfigs[name] - agentConfigsLock.RUnlock() - } - - if agConf == nil { - return nil - } - - if agConf.Agent == "" { - agConf.Agent = name - } - - return agent.CreateAgent(name, agConf.Agent, agent.APIConfig{ - Endpoint: agConf.Endpoint, - ApiKey: confAes.DecryptUrlBase64ToString(agConf.ApiKey), - DefaultChatModelConfig: agent.ChatModelConfig{ - Model: agConf.ChatConfig.Model, - MaxTokens: agConf.ChatConfig.MaxTokens, - Temperature: agConf.ChatConfig.Temperature, - TopP: agConf.ChatConfig.TopP, - Tools: agConf.ChatConfig.Tools, - }, - }) -} diff --git a/ai.go b/ai.go new file mode 100644 index 0000000..5c26858 --- /dev/null +++ b/ai.go @@ -0,0 +1,61 @@ +package ai + +import ( + "apigo.cc/ai/ai/llm" + _ "apigo.cc/ai/ai/llm/openai" + _ "apigo.cc/ai/ai/llm/zhipu" + "github.com/ssgo/config" + "github.com/ssgo/u" + "os" +) + +var confAes = u.NewAes([]byte("?GQ$0K0GgLdO=f+~L68PLm$uhKr4'=tV"), []byte("VFs7@sK61cj^f?HZ")) +var keysIsSet = false +var isInit = false + +func SetSSKey(key, iv []byte) { + if !keysIsSet { + confAes = u.NewAes(key, iv) + keysIsSet = true + } +} + +func Init() { + InitFrom("") +} + +func InitFrom(filePath string) { + if !isInit { + isInit = true + list := map[string]*struct { + Endpoint string + ApiKey string + DefaultChatModelConfig llm.ChatConfig + Llm string + }{} + savedPath := "" + if filePath != "" { + curPath, _ := os.Getwd() + if curPath != filePath { + savedPath = curPath + _ = os.Chdir(filePath) + config.ResetConfigEnv() + } + } + _ = config.LoadConfig("llm", &list) + if savedPath != "" { + _ = os.Chdir(savedPath) + } + for name, llmConf := range list { + if llmConf.Llm == "" { + llmConf.Llm = name + } + llmConf.ApiKey = confAes.DecryptUrlBase64ToString(llmConf.ApiKey) + llm.Create(name, llmConf.Llm, llm.Config{ + Endpoint: llmConf.Endpoint, + ApiKey: llmConf.ApiKey, + ChatConfig: llmConf.DefaultChatModelConfig, + }) + } + } +} diff --git a/ai/ai.go b/ai/ai.go new file mode 100644 index 0000000..446101e --- /dev/null +++ b/ai/ai.go @@ -0,0 +1,59 @@ +package main + +import ( + "apigo.cc/ai/ai/js" + "fmt" + "github.com/ssgo/u" + "os" + "strings" +) + +func main() { + if len(os.Args) > 1 && (os.Args[1] == "-e" || os.Args[1] == "export") { + imports, err := js.ExportForDev() + if err != nil { + fmt.Println(err.Error()) + } else { + fmt.Println(`exported to ./lib +example: +`) + fmt.Println(u.Cyan(imports + ` +// import customMod from './customMod' + +function main(...args) { + // TODO + return null +} +`)) + } + return + } else { + jsFile := "ai.js" + if len(os.Args) > 1 { + jsFile = os.Args[1] + if !strings.HasSuffix(jsFile, ".js") && !u.FileExists(jsFile) { + jsFile = jsFile + ".js" + } + } + if u.FileExists(jsFile) { + args := make([]any, len(os.Args)-2) + for i := 2; i < len(os.Args); i++ { + args[i-2] = os.Args[i] + } + result, err := js.RunFile(jsFile, args...) + if err != nil { + fmt.Println(err.Error()) + } else if result != nil { + fmt.Println(u.JsonP(result)) + } + return + } + } + + fmt.Println(`Usage: +ai -h | --help show usage +ai -e | --export export ai.ts file for develop +ai test | test.js run test.js, if not specified, run ai.js +`) + return +} diff --git a/go.mod b/go.mod index 48dcd11..643c0f0 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,29 @@ module apigo.cc/ai/ai go 1.22 require ( - apigo.cc/ai/agent v0.0.1 + github.com/dop251/goja v0.0.0-20240828124009-016eb7256539 + github.com/go-resty/resty/v2 v2.15.0 + github.com/golang-jwt/jwt/v5 v5.2.1 + github.com/sashabaranov/go-openai v1.29.2 github.com/ssgo/config v1.7.7 + github.com/ssgo/log v1.7.7 github.com/ssgo/u v1.7.7 + github.com/stretchr/testify v1.9.0 ) -require gopkg.in/yaml.v3 v3.0.1 // indirect +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dlclark/regexp2 v1.11.4 // indirect + github.com/dop251/goja_nodejs v0.0.0-20240728170619-29b559befffc // indirect + github.com/go-sourcemap/sourcemap v2.1.4+incompatible // indirect + github.com/google/pprof v0.0.0-20240910150728-a0b0bb1d4134 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/ssgo/standard v1.7.7 // indirect + golang.org/x/net v0.29.0 // indirect + golang.org/x/text v0.18.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +replace ( + github.com/ssgo/config v1.7.7 => ../../ssgo/config +) diff --git a/js/ai.go b/js/ai.go new file mode 100644 index 0000000..0f3aff8 --- /dev/null +++ b/js/ai.go @@ -0,0 +1,291 @@ +package js + +import ( + "apigo.cc/ai/ai/llm" + "github.com/dop251/goja" + "github.com/ssgo/u" + "reflect" + "strings" +) + +type ChatResult struct { + llm.TokenUsage + Result string + Error string +} + +type AIGCResult struct { + Result string + Preview string + Results []string + Previews []string + Error string +} + +func requireAI(lm llm.LLM) map[string]any { + return map[string]any{ + "ask": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value { + conf, cb := getAskArgs(args.This, vm, args.Arguments) + result, usage, err := lm.Ask(makeChatMessages(args.Arguments), conf, cb) + return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)}) + }, + "fastAsk": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value { + _, cb := getAskArgs(args.This, vm, args.Arguments) + result, usage, err := lm.FastAsk(makeChatMessages(args.Arguments), cb) + return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)}) + }, + "longAsk": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value { + _, cb := getAskArgs(args.This, vm, args.Arguments) + result, usage, err := lm.LongAsk(makeChatMessages(args.Arguments), cb) + return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)}) + }, + "batterAsk": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value { + _, cb := getAskArgs(args.This, vm, args.Arguments) + result, usage, err := lm.BatterAsk(makeChatMessages(args.Arguments), cb) + return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)}) + }, + "bestAsk": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value { + _, cb := getAskArgs(args.This, vm, args.Arguments) + result, usage, err := lm.BestAsk(makeChatMessages(args.Arguments), cb) + return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)}) + }, + + "multiAsk": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value { + _, cb := getAskArgs(args.This, vm, args.Arguments) + result, usage, err := lm.MultiAsk(makeChatMessages(args.Arguments), cb) + return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)}) + }, + "bestMultiAsk": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value { + _, cb := getAskArgs(args.This, vm, args.Arguments) + result, usage, err := lm.BestMultiAsk(makeChatMessages(args.Arguments), cb) + return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)}) + }, + + "codeInterpreterAsk": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value { + _, cb := getAskArgs(args.This, vm, args.Arguments) + result, usage, err := lm.CodeInterpreterAsk(makeChatMessages(args.Arguments), cb) + return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)}) + }, + "webSearchAsk": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value { + _, cb := getAskArgs(args.This, vm, args.Arguments) + result, usage, err := lm.WebSearchAsk(makeChatMessages(args.Arguments), cb) + return vm.ToValue(ChatResult{TokenUsage: usage, Result: result, Error: getErrorStr(err)}) + }, + + "makeImage": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value { + prompt, conf := getAIGCArgs(args.Arguments) + results, err := lm.MakeImage(prompt, conf) + return makeAIGCResult(vm, results, nil, err) + }, + "fastMakeImage": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value { + prompt, conf := getAIGCArgs(args.Arguments) + results, err := lm.FastMakeImage(prompt, conf) + return makeAIGCResult(vm, results, nil, err) + }, + "bestMakeImage": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value { + prompt, conf := getAIGCArgs(args.Arguments) + results, err := lm.BestMakeImage(prompt, conf) + return makeAIGCResult(vm, results, nil, err) + }, + + "makeVideo": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value { + prompt, conf := getAIGCArgs(args.Arguments) + results, previews, err := lm.MakeVideo(prompt, conf) + return makeAIGCResult(vm, results, previews, err) + }, + "fastMakeVideo": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value { + prompt, conf := getAIGCArgs(args.Arguments) + results, previews, err := lm.FastMakeVideo(prompt, conf) + return makeAIGCResult(vm, results, previews, err) + }, + "bestMakeVideo": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value { + prompt, conf := getAIGCArgs(args.Arguments) + results, previews, err := lm.BestMakeVideo(prompt, conf) + return makeAIGCResult(vm, results, previews, err) + }, + + "support": lm.Support(), + } +} + +func getErrorStr(err error) string { + if err != nil { + return err.Error() + } + return "" +} + +func makeAIGCResult(vm *goja.Runtime, results []string, previews []string, err error) goja.Value { + result := "" + preview := "" + if len(results) > 0 { + result = results[0] + } else { + results = make([]string, 0) + } + if len(previews) > 0 { + preview = previews[0] + } else { + previews = make([]string, 0) + } + return vm.ToValue(AIGCResult{ + Result: result, + Preview: preview, + Results: results, + Previews: previews, + Error: getErrorStr(err), + }) +} + +func getAIGCArgs(args []goja.Value) (string, llm.GCConfig) { + prompt := "" + var config llm.GCConfig + if len(args) > 0 { + prompt = u.String(args[0].Export()) + if len(args) > 1 { + u.Convert(args[1].Export(), &config) + } + } + return prompt, config +} + +func getAskArgs(thisArg goja.Value, vm *goja.Runtime, args []goja.Value) (llm.ChatConfig, func(string)) { + var chatConfig llm.ChatConfig + var callback func(answer string) + if len(args) > 0 { + for i := 1; i < len(args); i++ { + if cb, ok := goja.AssertFunction(args[i]); ok { + callback = func(answer string) { + _, _ = cb(thisArg, vm.ToValue(answer)) + } + } else { + switch args[i].ExportType().Kind() { + case reflect.Map, reflect.Struct: + u.Convert(args[i].Export(), &chatConfig) + default: + chatConfig.Model = u.String(args[i].Export()) + } + } + } + } + return chatConfig, callback +} + +func makeChatMessages(args []goja.Value) []llm.ChatMessage { + out := make([]llm.ChatMessage, 0) + if len(args) > 0 { + v := args[0].Export() + vv := reflect.ValueOf(v) + t := args[0].ExportType() + lastRoleIsUser := false + switch t.Kind() { + // 数组,根据成员类型处理 + // 字符串: + // 含有媒体:单条多模态消息 + // 无媒体:多条文本消息 + // 数组:多条消息(第一个成员不是 role 则自动生成) + // 对象:多条消息(无 role 则自动生成)(支持 content 或 contents) + // 结构:转换为 llm.ChatMessage + // 对象:单条消息(支持 content 或 contents) + // 结构:转换为 llm.ChatMessage + // 字符串:单条文本消息 + case reflect.Slice: + hasSub := false + hasMulti := false + for i := 0; i < vv.Len(); i++ { + if vv.Index(i).Kind() == reflect.Slice || vv.Index(i).Kind() == reflect.Map || vv.Index(i).Kind() == reflect.Struct { + hasSub = true + break + } + if vv.Index(i).Kind() == reflect.String { + str := vv.Index(i).String() + if strings.HasPrefix(str, "data:") || strings.HasPrefix(str, "https://") || strings.HasPrefix(str, "http://") { + hasMulti = true + } + } + } + if hasSub || !hasMulti { + // 有子对象或纯文本数组 + var defaultRole string + for i := 0; i < vv.Len(); i++ { + lastRoleIsUser = !lastRoleIsUser + if lastRoleIsUser { + defaultRole = llm.RoleUser + } else { + defaultRole = llm.RoleAssistant + } + vv2 := vv.Index(i) + switch vv2.Kind() { + case reflect.Slice: + out = append(out, makeChatMessageFromSlice(vv2, defaultRole)) + case reflect.Map: + out = append(out, makeChatMessageFromSlice(vv2, defaultRole)) + case reflect.Struct: + item := llm.ChatMessage{} + u.Convert(vv2.Interface(), &item) + out = append(out, item) + default: + out = append(out, llm.ChatMessage{Role: llm.RoleUser, Contents: []llm.ChatMessageContent{makeChatMessageContent(u.String(vv2.Interface()))}}) + } + lastRoleIsUser = out[len(out)-1].Role != llm.RoleUser + } + } else { + // 单条多模态消息 + out = append(out, makeChatMessageFromSlice(vv, llm.RoleUser)) + } + case reflect.Map: + out = append(out, makeChatMessageFromMap(vv, llm.RoleUser)) + case reflect.Struct: + item := llm.ChatMessage{} + u.Convert(v, &item) + out = append(out, item) + default: + out = append(out, llm.ChatMessage{Role: llm.RoleUser, Contents: []llm.ChatMessageContent{makeChatMessageContent(u.String(v))}}) + } + } + return out +} + +func makeChatMessageFromSlice(vv reflect.Value, defaultRole string) llm.ChatMessage { + role := u.String(vv.Index(0).Interface()) + j := 0 + if role == llm.RoleUser || role == llm.RoleAssistant || role == llm.RoleSystem || role == llm.RoleTool { + j = 1 + } else { + role = defaultRole + } + contents := make([]llm.ChatMessageContent, 0) + for ; j < vv.Len(); j++ { + contents = append(contents, makeChatMessageContent(u.String(vv.Index(j).Interface()))) + } + return llm.ChatMessage{Role: role, Contents: contents} +} + +func makeChatMessageFromMap(vv reflect.Value, defaultRole string) llm.ChatMessage { + role := u.String(vv.MapIndex(reflect.ValueOf("role")).Interface()) + if role == "" { + role = defaultRole + } + contents := make([]llm.ChatMessageContent, 0) + content := u.String(vv.MapIndex(reflect.ValueOf("content")).Interface()) + if content != "" { + contents = append(contents, makeChatMessageContent(content)) + } else { + contentsV := vv.MapIndex(reflect.ValueOf("contents")) + if contentsV.IsValid() && contentsV.Kind() == reflect.Slice { + for i := 0; i < contentsV.Len(); i++ { + contents = append(contents, makeChatMessageContent(u.String(contentsV.Index(i).Interface()))) + } + } + } + return llm.ChatMessage{Role: role, Contents: contents} +} + +func makeChatMessageContent(contnet string) llm.ChatMessageContent { + if strings.HasPrefix(contnet, "data:image/") || ((strings.HasPrefix(contnet, "https://") || strings.HasPrefix(contnet, "http://")) && (strings.HasSuffix(contnet, ".png") || strings.HasSuffix(contnet, ".jpg") || strings.HasSuffix(contnet, ".jpeg") || strings.HasSuffix(contnet, ".gif") || strings.HasSuffix(contnet, ".svg"))) { + return llm.ChatMessageContent{Type: llm.TypeImage, Content: contnet} + } else if strings.HasPrefix(contnet, "data:video/") || ((strings.HasPrefix(contnet, "https://") || strings.HasPrefix(contnet, "http://")) && (strings.HasSuffix(contnet, ".mp4") || strings.HasSuffix(contnet, ".mov") || strings.HasSuffix(contnet, ".m4v") || strings.HasSuffix(contnet, ".avi") || strings.HasSuffix(contnet, ".wmv"))) { + return llm.ChatMessageContent{Type: llm.TypeVideo, Content: contnet} + } + return llm.ChatMessageContent{Type: llm.TypeText, Content: contnet} +} diff --git a/js/console.go b/js/console.go new file mode 100644 index 0000000..fe5ec5d --- /dev/null +++ b/js/console.go @@ -0,0 +1,81 @@ +package js + +import ( + "fmt" + "github.com/dop251/goja" + "github.com/ssgo/u" + "strings" +) + +func requireConsole() map[string]any { + return map[string]any{ + "print": func(args goja.FunctionCall) goja.Value { + consolePrint(args, "print", nil) + return nil + }, + "println": func(args goja.FunctionCall) goja.Value { + consolePrint(args, "log", nil) + return nil + }, + "log": func(args goja.FunctionCall) goja.Value { + consolePrint(args, "log", nil) + return nil + }, + "info": func(args goja.FunctionCall) goja.Value { + consolePrint(args, "info", nil) + return nil + }, + "warn": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value { + consolePrint(args, "warn", vm) + return nil + }, + "error": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value { + consolePrint(args, "error", vm) + return nil + }, + "debug": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value { + consolePrint(args, "debug", vm) + return nil + }, + "input": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value { + consolePrint(args, "print", nil) + line := "" + _, _ = fmt.Scanln(&line) + return vm.ToValue(line) + }, + } +} + +func consolePrint(args goja.FunctionCall, typ string, vm *goja.Runtime) { + arr := make([]any, len(args.Arguments)) + textColor := u.TextNone + switch typ { + case "info": + textColor = u.TextCyan + case "warn": + textColor = u.TextYellow + case "error": + textColor = u.TextRed + } + + for i, arg := range args.Arguments { + if textColor != u.TextNone { + arr[i] = u.Color(u.StringP(arg.Export()), textColor, u.BgNone) + } else { + arr[i] = u.StringP(arg.Export()) + } + } + + if (typ == "warn" || typ == "error" || typ == "debug") && vm != nil { + callStacks := make([]string, 0) + for _, stack := range vm.CaptureCallStack(0, nil) { + callStacks = append(callStacks, u.Color(" "+stack.Position().String(), textColor, u.BgNone)) + } + fmt.Println(arr...) + fmt.Println(strings.Join(callStacks, "\n")) + } else if typ == "print" { + fmt.Print(arr...) + } else { + fmt.Println(arr...) + } +} diff --git a/js/file.go b/js/file.go new file mode 100644 index 0000000..90f7d7c --- /dev/null +++ b/js/file.go @@ -0,0 +1,48 @@ +package js + +import ( + "errors" + "github.com/dop251/goja" + "github.com/ssgo/u" +) + +func requireFile() map[string]any { + return map[string]any{ + "read": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value { + if len(args.Arguments) < 1 { + return vm.NewGoError(errors.New("arguments need 1 given " + u.String(len(args.Arguments)))) + } + if r, err := u.ReadFile(u.String(args.Arguments[0].Export())); err == nil { + return vm.ToValue(r) + } else { + return vm.NewGoError(err) + } + }, + "write": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value { + if len(args.Arguments) < 2 { + return vm.NewGoError(errors.New("arguments need 2 given " + u.String(len(args.Arguments)))) + } + if err := u.WriteFileBytes(u.String(args.Arguments[0].Export()), u.Bytes(args.Arguments[0].Export())); err == nil { + return nil + } else { + return vm.NewGoError(err) + } + }, + "dir": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value { + if len(args.Arguments) < 1 { + return vm.NewGoError(errors.New("arguments need 1 given " + u.String(len(args.Arguments)))) + } + if r, err := u.ReadDir(u.String(args.Arguments[0].Export())); err == nil { + return vm.ToValue(r) + } else { + return vm.NewGoError(err) + } + }, + "stat": func(args goja.FunctionCall, vm *goja.Runtime) goja.Value { + if len(args.Arguments) < 1 { + return vm.NewGoError(errors.New("arguments need 1 given " + u.String(len(args.Arguments)))) + } + return vm.ToValue(u.GetFileInfo(u.String(args.Arguments[0].Export()))) + }, + } +} diff --git a/js/js.go b/js/js.go new file mode 100644 index 0000000..662eb57 --- /dev/null +++ b/js/js.go @@ -0,0 +1,238 @@ +package js + +import ( + "apigo.cc/ai/ai" + "apigo.cc/ai/ai/llm" + "bytes" + _ "embed" + "encoding/json" + "errors" + "fmt" + "github.com/dop251/goja" + "github.com/dop251/goja_nodejs/require" + "github.com/ssgo/u" + "path/filepath" + "regexp" + "strings" + "text/template" +) + +//go:embed lib/ai.ts +var aiTS string + +//go:embed lib/console.ts +var consoleTS string + +//go:embed lib/file.ts +var fileTS string + +func RunFile(file string, args ...any) (any, error) { + return Run(u.ReadFileN(file), file, args...) +} + +func Run(code string, refFile string, args ...any) (any, error) { + var r any + js, err := StartFromCode(code, refFile) + if err == nil { + r, err = js.Run(args...) + } + return r, err +} + +var importModMatcher = regexp.MustCompile(`(?im)^\s*import\s+(.+?)\s+from\s+['"](.+?)['"]`) +var importLibMatcher = regexp.MustCompile(`(?im)^\s*(import)\s+(.+?)\s+from\s+['"][./\\\w:]+lib[/\\](.+?)(\.ts)?['"]`) +var requireLibMatcher = regexp.MustCompile(`(?im)^\s*(const|let|var)\s+(.+?)\s*=\s*require\s*\(\s*['"][./\\\w:]+lib[/\\](.+?)(\.ts)?['"]\s*\)`) +var checkMainMatcher = regexp.MustCompile(`(?im)^\s*function\s+main\s*\(`) + +type JS struct { + vm *goja.Runtime + required map[string]bool + file string + srcCode string + code string +} + +func (js *JS) requireMod(name string) error { + var err error + if name == "console" || name == "" { + if !js.required["console"] { + js.required["console"] = true + err = js.vm.Set("console", requireConsole()) + } + } + if err == nil && (name == "file" || name == "") { + if !js.required["file"] { + js.required["file"] = true + err = js.vm.Set("file", requireFile()) + } + } + if err == nil && (name == "ai" || name == "") { + if !js.required["ai"] { + js.required["ai"] = true + aiList := make(map[string]any) + for name, lm := range llm.List() { + aiList[name] = requireAI(lm) + } + err = js.vm.Set("ai", aiList) + } + } + return err +} + +func (js *JS) makeImport(matcher *regexp.Regexp, code string) (string, int, error) { + var modErr error + importCount := 0 + code = matcher.ReplaceAllStringFunc(code, func(str string) string { + if m := matcher.FindStringSubmatch(str); m != nil && len(m) > 3 { + optName := m[1] + if optName == "import" { + optName = "let" + } + varName := m[2] + modName := m[3] + importCount++ + if modErr == nil { + if err := js.requireMod(modName); err != nil { + modErr = err + } + } + if varName != modName { + return fmt.Sprintf("%s %s = %s", optName, varName, modName) + } + } + return "" + }) + return code, importCount, modErr +} + +func StartFromFile(file string) (*JS, error) { + return StartFromCode(u.ReadFileN(file), file) +} + +func StartFromCode(code, refFile string) (*JS, error) { + if refFile == "" { + refFile = "main.js" + } + + if absFile, err := filepath.Abs(refFile); err == nil { + refFile = absFile + } + + ai.InitFrom(filepath.Dir(refFile)) + + js := &JS{ + vm: goja.New(), + required: map[string]bool{}, + file: refFile, + srcCode: code, + code: code, + } + + // 按需加载引用 + var importCount int + var modErr error + js.code, importCount, modErr = js.makeImport(importLibMatcher, js.code) + if modErr == nil { + importCount1 := importCount + js.code, importCount, modErr = js.makeImport(requireLibMatcher, js.code) + importCount += importCount1 + } + + // 将 import 转换为 require + js.code = importModMatcher.ReplaceAllString(js.code, "let $1 = require('$2')") + + // 如果没有import,默认import所有 + if modErr == nil && importCount == 0 { + modErr = js.requireMod("") + } + if modErr != nil { + return nil, modErr + } + + //fmt.Println(u.BCyan(js.code)) + + // 处理模块引用 + require.NewRegistryWithLoader(func(path string) ([]byte, error) { + refPath := filepath.Join(filepath.Dir(js.file), path) + if !strings.HasSuffix(refPath, ".js") && !u.FileExists(refPath) { + refPath += ".js" + } + modCode, err := u.ReadFile(refPath) + if err != nil { + return nil, err + } + modCode, _, _ = js.makeImport(importLibMatcher, modCode) + modCode, _, _ = js.makeImport(requireLibMatcher, modCode) + return []byte(modCode), modErr + }).Enable(js.vm) + + // 初始化主函数 + if !checkMainMatcher.MatchString(js.code) { + js.code = "function main(...args){" + js.code + "}" + } + if _, err := js.vm.RunScript("main", js.code); err != nil { + return nil, err + } + + return js, nil +} + +func (js *JS) Run(args ...any) (any, error) { + // 解析参数 + for i, arg := range args { + if str, ok := arg.(string); ok { + var v interface{} + if err := json.Unmarshal([]byte(str), &v); err == nil { + args[i] = v + } + } + } + + if err := js.vm.Set("__args", args); err != nil { + return nil, err + } + jsResult, err := js.vm.RunScript(js.file, "main(...__args)") + + var result any + if err == nil { + if jsResult != nil && !jsResult.Equals(goja.Undefined()) { + result = jsResult.Export() + } + } + return result, err +} + +type Exports struct { + LLMList []string +} + +func ExportForDev() (string, error) { + ai.Init() + if len(llm.List()) == 0 && !u.FileExists("env.yml") && !u.FileExists("env.json") && !u.FileExists("llm.yml") && !u.FileExists("llm.json") { + return "", errors.New("no llm config found, please run `ai -e` on env.yml or llm.yml path") + } + exports := Exports{} + for name, _ := range llm.List() { + exports.LLMList = append(exports.LLMList, name) + } + + exportFile := filepath.Join("lib", "ai.ts") + var tpl *template.Template + var err error + if tpl, err = template.New(exportFile).Parse(aiTS); err == nil { + buf := bytes.NewBuffer(make([]byte, 0)) + if err = tpl.Execute(buf, exports); err == nil { + err = u.WriteFileBytes(exportFile, buf.Bytes()) + } + } + if err != nil { + return "", err + } + + _ = u.WriteFile(filepath.Join("lib", "console.ts"), consoleTS) + _ = u.WriteFile(filepath.Join("lib", "file.ts"), fileTS) + + return `import {` + strings.Join(exports.LLMList, ", ") + `} from './lib/ai' +import console from './lib/console' +import file from './lib/file'`, nil +} diff --git a/js/lib/ai.ts b/js/lib/ai.ts new file mode 100644 index 0000000..6139132 --- /dev/null +++ b/js/lib/ai.ts @@ -0,0 +1,81 @@ +// just for develop + +{{range .LLMList}} +let {{.}}: LLM +{{end}} + +export default { +{{range .LLMList}} + {{.}}, +{{end}} +} + +interface ChatModelConfig { + model: string + ratio: number + maxTokens: number + temperature: number + topP: number + tools: Object +} + +interface ChatResult { + result: string + askTokens: number + answerTokens: number + totalTokens: number + error: string +} + +interface GCResult { + result: string + preview: string + results: Array + previews: Array + error: string +} + +interface Support { + ask: boolean + askWithImage: boolean + askWithVideo: boolean + askWithCodeInterpreter: boolean + askWithWebSearch: boolean + makeImage: boolean + makeVideo: boolean + models: Array +} + +interface LLM { + ask(messages: any, config?: ChatModelConfig, callback?: (answer: string) => void): ChatResult + + fastAsk(messages: any, callback?: (answer: string) => void): ChatResult + + longAsk(messages: any, callback?: (answer: string) => void): ChatResult + + batterAsk(messages: any, callback?: (answer: string) => void): ChatResult + + bestAsk(messages: any, callback?: (answer: string) => void): ChatResult + + multiAsk(messages: any, callback?: (answer: string) => void): ChatResult + + bestMultiAsk(messages: any, callback?: (answer: string) => void): ChatResult + + codeInterpreterAsk(messages: any, callback?: (answer: string) => void): ChatResult + + webSearchAsk(messages: any, callback?: (answer: string) => void): ChatResult + + makeImage(model: string, prompt: string, size?: string, refImage?: string): GCResult + + fastMakeImage(prompt: string, size?: string, refImage?: string): GCResult + + bestMakeImage(prompt: string, size?: string, refImage?: string): GCResult + + makeVideo(arg2: string, arg3: string, arg4: string, arg5: string): GCResult + + fastMakeVideo(prompt: string, size?: string, refImage?: string): GCResult + + bestMakeVideo(prompt: string, size?: string, refImage?: string): GCResult + + support: Support +} diff --git a/js/lib/console.ts b/js/lib/console.ts new file mode 100644 index 0000000..8c40017 --- /dev/null +++ b/js/lib/console.ts @@ -0,0 +1,37 @@ +// just for develop + +export default { + print, + println, + log, + debug, + info, + warn, + error, + input, +} + +function print(...data: any[]): void { +} + +function println(...data: any[]): void { +} + +function log(...data: any[]): void { +} + +function debug(...data: any[]): void { +} + +function info(...data: any[]): void { +} + +function warn(...data: any[]): void { +} + +function error(...data: any[]): void { +} + +function input(...data: any[]): string { + return '' +} diff --git a/js/lib/file.ts b/js/lib/file.ts new file mode 100644 index 0000000..0cdc810 --- /dev/null +++ b/js/lib/file.ts @@ -0,0 +1,32 @@ +// just for develop + +export default { + read, + write, + dir, + stat +} + +function read(filename: string): string { + return '' +} + +function write(filename: string, data: any): void { +} + + +function dir(filename: string): Array { + return null +} + +function stat(filename: string): FileInfo { + return null +} + +interface FileInfo { + Name: string + FullName: string + IsDir: boolean + Size: number + ModTime: number +} diff --git a/chat.go b/llm/chat.go similarity index 52% rename from chat.go rename to llm/chat.go index 95bdbcd..1bda649 100644 --- a/chat.go +++ b/llm/chat.go @@ -1,15 +1,68 @@ -package ai +package llm -import "apigo.cc/ai/agent" +type ChatMessage struct { + Role string + Contents []ChatMessageContent +} -type ChatMessage = agent.ChatMessage -type ChatMessageContent = agent.ChatMessageContent -type ChatModelConfig struct { - Model string - MaxTokens int - Temperature float64 - TopP float64 - Tools map[string]any +type ChatMessageContent struct { + Type string // text, image, audio, video + Content string +} + +type ChatConfig struct { + defaultConfig *ChatConfig + Model string + Ratio float64 + MaxTokens int + Temperature float64 + TopP float64 + Tools map[string]any +} + +func (chatConfig *ChatConfig) SetDefault(config *ChatConfig) { + chatConfig.defaultConfig = config +} + +func (chatConfig *ChatConfig) GetModel() string { + if chatConfig.Model == "" && chatConfig.defaultConfig != nil { + return chatConfig.defaultConfig.Model + } + return chatConfig.Model +} + +func (chatConfig *ChatConfig) GetMaxTokens() int { + if chatConfig.MaxTokens == 0 && chatConfig.defaultConfig != nil { + return chatConfig.defaultConfig.MaxTokens + } + return chatConfig.MaxTokens +} + +func (chatConfig *ChatConfig) GetTemperature() float64 { + if chatConfig.Temperature == 0 && chatConfig.defaultConfig != nil { + return chatConfig.defaultConfig.Temperature + } + return chatConfig.Temperature +} + +func (chatConfig *ChatConfig) GetTopP() float64 { + if chatConfig.TopP == 0 && chatConfig.defaultConfig != nil { + return chatConfig.defaultConfig.TopP + } + return chatConfig.TopP +} + +func (chatConfig *ChatConfig) GetTools() map[string]any { + if chatConfig.Tools == nil && chatConfig.defaultConfig != nil { + return chatConfig.defaultConfig.Tools + } + return chatConfig.Tools +} + +type TokenUsage struct { + AskTokens int64 + AnswerTokens int64 + TotalTokens int64 } type MessagesMaker struct { diff --git a/llm/gc.go b/llm/gc.go new file mode 100644 index 0000000..f6b3abd --- /dev/null +++ b/llm/gc.go @@ -0,0 +1,33 @@ +package llm + +type GCConfig struct { + defaultConfig *GCConfig + Model string + Size string + Ref string +} + +func (gcConfig *GCConfig) SetDefault(config *GCConfig) { + gcConfig.defaultConfig = config +} + +func (gcConfig *GCConfig) GetModel() string { + if gcConfig.Model == "" && gcConfig.defaultConfig != nil { + return gcConfig.defaultConfig.Model + } + return gcConfig.Model +} + +func (gcConfig *GCConfig) GetSize() string { + if gcConfig.Size == "" && gcConfig.defaultConfig != nil { + return gcConfig.defaultConfig.Size + } + return gcConfig.Size +} + +func (gcConfig *GCConfig) GetRef() string { + if gcConfig.Ref == "" && gcConfig.defaultConfig != nil { + return gcConfig.defaultConfig.Ref + } + return gcConfig.Ref +} diff --git a/llm/llm.go b/llm/llm.go new file mode 100644 index 0000000..4c947c7 --- /dev/null +++ b/llm/llm.go @@ -0,0 +1,96 @@ +package llm + +import "sync" + +const ( + TypeText = "text" + TypeImage = "image" + TypeVideo = "video" + RoleSystem = "system" + RoleUser = "user" + RoleAssistant = "assistant" + RoleTool = "tool" + ToolCodeInterpreter = "codeInterpreter" + ToolWebSearch = "webSearch" +) + +type Support struct { + Ask bool + AskWithImage bool + AskWithVideo bool + AskWithCodeInterpreter bool + AskWithWebSearch bool + MakeImage bool + MakeVideo bool + Models []string +} + +type Config struct { + Endpoint string + ApiKey string + ChatConfig ChatConfig + GCConfig GCConfig +} + +type LLM interface { + Support() Support + Ask(messages []ChatMessage, config ChatConfig, callback func(answer string)) (string, TokenUsage, error) + FastAsk(messages []ChatMessage, callback func(answer string)) (string, TokenUsage, error) + LongAsk(messages []ChatMessage, callback func(answer string)) (string, TokenUsage, error) + BatterAsk(messages []ChatMessage, callback func(answer string)) (string, TokenUsage, error) + BestAsk(messages []ChatMessage, callback func(answer string)) (string, TokenUsage, error) + MultiAsk(messages []ChatMessage, callback func(answer string)) (string, TokenUsage, error) + BestMultiAsk(messages []ChatMessage, callback func(answer string)) (string, TokenUsage, error) + CodeInterpreterAsk(messages []ChatMessage, callback func(answer string)) (string, TokenUsage, error) + WebSearchAsk(messages []ChatMessage, callback func(answer string)) (string, TokenUsage, error) + MakeImage(prompt string, config GCConfig) ([]string, error) + FastMakeImage(prompt string, config GCConfig) ([]string, error) + BestMakeImage(prompt string, config GCConfig) ([]string, error) + MakeVideo(prompt string, config GCConfig) ([]string, []string, error) + FastMakeVideo(prompt string, config GCConfig) ([]string, []string, error) + BestMakeVideo(prompt string, config GCConfig) ([]string, []string, error) +} + +var llmMakers = map[string]func(Config) LLM{} +var llmMakersLock = sync.RWMutex{} + +var llms = map[string]LLM{} +var llmsLock = sync.RWMutex{} + +func Register(llmId string, maker func(Config) LLM) { + llmMakersLock.Lock() + llmMakers[llmId] = maker + llmMakersLock.Unlock() +} + +func Create(name, llmId string, config Config) LLM { + llmMakersLock.RLock() + maker := llmMakers[llmId] + llmMakersLock.RUnlock() + + if maker != nil { + llm := maker(config) + llmsLock.Lock() + llms[name] = llm + llmsLock.Unlock() + return llm + } + return nil +} + +func Get(name string) LLM { + llmsLock.RLock() + llm := llms[name] + llmsLock.RUnlock() + return llm +} + +func List() map[string]LLM { + list := map[string]LLM{} + llmsLock.RLock() + for name, llm := range llms { + list[name] = llm + } + llmsLock.RUnlock() + return list +} diff --git a/llm/openai/chat.go b/llm/openai/chat.go new file mode 100644 index 0000000..da5d6b8 --- /dev/null +++ b/llm/openai/chat.go @@ -0,0 +1,165 @@ +package openai + +import ( + "apigo.cc/ai/ai/llm" + "context" + "github.com/sashabaranov/go-openai" + "github.com/ssgo/log" + "strings" +) + +func (lm *LLM) FastAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) { + return lm.Ask(messages, llm.ChatConfig{ + Model: ModelGPT_4o_mini_2024_07_18, + }, callback) +} + +func (lm *LLM) LongAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) { + return lm.Ask(messages, llm.ChatConfig{ + Model: ModelGPT_4_32k_0613, + }, callback) +} + +func (lm *LLM) BatterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) { + return lm.Ask(messages, llm.ChatConfig{ + Model: ModelGPT_4_turbo, + }, callback) +} + +func (lm *LLM) BestAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) { + return lm.Ask(messages, llm.ChatConfig{ + Model: ModelGPT_4o_2024_08_06, + }, callback) +} + +func (lm *LLM) MultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) { + return lm.Ask(messages, llm.ChatConfig{ + Model: ModelGPT_4o_mini_2024_07_18, + }, callback) +} + +func (lm *LLM) BestMultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) { + return lm.Ask(messages, llm.ChatConfig{ + Model: ModelGPT_4o_2024_08_06, + }, callback) +} + +func (lm *LLM) CodeInterpreterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) { + return lm.Ask(messages, llm.ChatConfig{ + Model: ModelGPT_4o, + Tools: map[string]any{llm.ToolCodeInterpreter: nil}, + }, callback) +} + +func (lm *LLM) WebSearchAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) { + return lm.Ask(messages, llm.ChatConfig{ + Model: ModelGPT_4o_mini_2024_07_18, + Tools: map[string]any{llm.ToolWebSearch: nil}, + }, callback) +} + +func (lm *LLM) Ask(messages []llm.ChatMessage, config llm.ChatConfig, callback func(answer string)) (string, llm.TokenUsage, error) { + openaiConf := openai.DefaultConfig(lm.config.ApiKey) + if lm.config.Endpoint != "" { + openaiConf.BaseURL = lm.config.Endpoint + } + + config.SetDefault(&lm.config.ChatConfig) + + agentMessages := make([]openai.ChatCompletionMessage, len(messages)) + for i, msg := range messages { + var contents []openai.ChatMessagePart + if msg.Contents != nil { + contents = make([]openai.ChatMessagePart, len(msg.Contents)) + for j, inPart := range msg.Contents { + part := openai.ChatMessagePart{} + part.Type = TypeMap[inPart.Type] + switch inPart.Type { + case llm.TypeText: + part.Text = inPart.Content + case llm.TypeImage: + part.ImageURL = &openai.ChatMessageImageURL{ + URL: inPart.Content, + Detail: openai.ImageURLDetailAuto, + } + } + contents[j] = part + } + } + agentMessages[i] = openai.ChatCompletionMessage{ + Role: RoleMap[msg.Role], + MultiContent: contents, + } + } + + opt := openai.ChatCompletionRequest{ + Model: config.GetModel(), + Messages: agentMessages, + MaxTokens: config.GetMaxTokens(), + Temperature: float32(config.GetTemperature()), + TopP: float32(config.GetTopP()), + StreamOptions: &openai.StreamOptions{ + IncludeUsage: true, + }, + } + + for name := range config.GetTools() { + switch name { + case llm.ToolCodeInterpreter: + opt.Tools = append(opt.Tools, openai.Tool{Type: "code_interpreter"}) + case llm.ToolWebSearch: + } + } + + c := openai.NewClientWithConfig(openaiConf) + if callback != nil { + opt.Stream = true + r, err := c.CreateChatCompletionStream(context.Background(), opt) + if err == nil { + results := make([]string, 0) + usage := llm.TokenUsage{} + for { + if r2, err := r.Recv(); err == nil { + if r2.Choices != nil { + for _, ch := range r2.Choices { + text := ch.Delta.Content + callback(text) + results = append(results, text) + } + } + if r2.Usage != nil { + usage.AskTokens += int64(r2.Usage.PromptTokens) + usage.AnswerTokens += int64(r2.Usage.CompletionTokens) + usage.TotalTokens += int64(r2.Usage.TotalTokens) + } + } else { + break + } + } + _ = r.Close() + return strings.Join(results, ""), usage, nil + } else { + log.DefaultLogger.Error(err.Error()) + return "", llm.TokenUsage{}, err + } + } else { + r, err := c.CreateChatCompletion(context.Background(), opt) + + if err == nil { + results := make([]string, 0) + if r.Choices != nil { + for _, ch := range r.Choices { + results = append(results, ch.Message.Content) + } + } + return strings.Join(results, ""), llm.TokenUsage{ + AskTokens: int64(r.Usage.PromptTokens), + AnswerTokens: int64(r.Usage.CompletionTokens), + TotalTokens: int64(r.Usage.TotalTokens), + }, nil + } else { + //fmt.Println(u.BMagenta(err.Error()), u.BMagenta(u.JsonP(r))) + return "", llm.TokenUsage{}, err + } + } +} diff --git a/llm/openai/config.go b/llm/openai/config.go new file mode 100644 index 0000000..d8f0dc1 --- /dev/null +++ b/llm/openai/config.go @@ -0,0 +1,81 @@ +package openai + +import ( + "apigo.cc/ai/ai/llm" + "github.com/sashabaranov/go-openai" +) + +type LLM struct { + config llm.Config +} + +var TypeMap = map[string]openai.ChatMessagePartType{ + llm.TypeText: openai.ChatMessagePartTypeText, + llm.TypeImage: openai.ChatMessagePartTypeImageURL, + //llm.TypeVideo: "video_url", +} +var RoleMap = map[string]string{ + llm.RoleSystem: openai.ChatMessageRoleSystem, + llm.RoleUser: openai.ChatMessageRoleUser, + llm.RoleAssistant: openai.ChatMessageRoleAssistant, + llm.RoleTool: openai.ChatMessageRoleTool, +} + +const ( + ModelGPT_4_32k_0613 = "gpt-4-32k-0613" + ModelGPT_4_32k_0314 = "gpt-4-32k-0314" + ModelGPT_4_32k = "gpt-4-32k" + ModelGPT_4_0613 = "gpt-4-0613" + ModelGPT_4_0314 = "gpt-4-0314" + ModelGPT_4o = "gpt-4o" + ModelGPT_4o_2024_05_13 = "gpt-4o-2024-05-13" + ModelGPT_4o_2024_08_06 = "gpt-4o-2024-08-06" + ModelGPT_4o_mini = "gpt-4o-mini" + ModelGPT_4o_mini_2024_07_18 = "gpt-4o-mini-2024-07-18" + ModelGPT_4_turbo = "gpt-4-turbo" + ModelGPT_4_turbo_2024_04_09 = "gpt-4-turbo-2024-04-09" + ModelGPT_4_0125_preview = "gpt-4-0125-preview" + ModelGPT_4_1106_preview = "gpt-4-1106-preview" + ModelGPT_4_turbo_preview = "gpt-4-turbo-preview" + ModelGPT_4_vision_preview = "gpt-4-vision-preview" + ModelGPT_4 = "gpt-4" + ModelGPT_3_5_turbo_0125 = "gpt-3.5-turbo-0125" + ModelGPT_3_5_turbo_1106 = "gpt-3.5-turbo-1106" + ModelGPT_3_5_turbo_0613 = "gpt-3.5-turbo-0613" + ModelGPT_3_5_turbo_0301 = "gpt-3.5-turbo-0301" + ModelGPT_3_5_turbo_16k = "gpt-3.5-turbo-16k" + ModelGPT_3_5_turbo_16k_0613 = "gpt-3.5-turbo-16k-0613" + ModelGPT_3_5_turbo = "gpt-3.5-turbo" + ModelGPT_3_5_turbo_instruct = "gpt-3.5-turbo-instruct" + ModelDavinci_002 = "davinci-002" + ModelCurie = "curie" + ModelCurie_002 = "curie-002" + ModelAda_002 = "ada-002" + ModelBabbage_002 = "babbage-002" + ModelCode_davinci_002 = "code-davinci-002" + ModelCode_cushman_001 = "code-cushman-001" + ModelCode_davinci_001 = "code-davinci-001" + ModelDallE2Std = "dall-e-2" + ModelDallE2HD = "dall-e-2-hd" + ModelDallE3Std = "dall-e-3" + ModelDallE3HD = "dall-e-3-hd" +) + +func (ag *LLM) Support() llm.Support { + return llm.Support{ + Ask: true, + AskWithImage: true, + AskWithVideo: false, + AskWithCodeInterpreter: true, + AskWithWebSearch: false, + MakeImage: true, + MakeVideo: false, + Models: []string{ModelGPT_4_32k_0613, ModelGPT_4_32k_0314, ModelGPT_4_32k, ModelGPT_4_0613, ModelGPT_4_0314, ModelGPT_4o, ModelGPT_4o_2024_05_13, ModelGPT_4o_2024_08_06, ModelGPT_4o_mini, ModelGPT_4o_mini_2024_07_18, ModelGPT_4_turbo, ModelGPT_4_turbo_2024_04_09, ModelGPT_4_0125_preview, ModelGPT_4_1106_preview, ModelGPT_4_turbo_preview, ModelGPT_4_vision_preview, ModelGPT_4, ModelGPT_3_5_turbo_0125, ModelGPT_3_5_turbo_1106, ModelGPT_3_5_turbo_0613, ModelGPT_3_5_turbo_0301, ModelGPT_3_5_turbo_16k, ModelGPT_3_5_turbo_16k_0613, ModelGPT_3_5_turbo, ModelGPT_3_5_turbo_instruct, ModelDavinci_002, ModelCurie, ModelCurie_002, ModelAda_002, ModelBabbage_002, ModelCode_davinci_002, ModelCode_cushman_001, ModelCode_davinci_001, ModelDallE2Std, ModelDallE2HD, ModelDallE3Std, ModelDallE3HD}, + } +} + +func init() { + llm.Register("openai", func(config llm.Config) llm.LLM { + return &LLM{config: config} + }) +} diff --git a/llm/openai/gc.go b/llm/openai/gc.go new file mode 100644 index 0000000..a042e1b --- /dev/null +++ b/llm/openai/gc.go @@ -0,0 +1,73 @@ +package openai + +import ( + "apigo.cc/ai/ai/llm" + "context" + "github.com/sashabaranov/go-openai" + "strings" +) + +// func (lm *LLM) FastMakeImage(prompt, size, refImage string) ([]string, error) { +// return lm.MakeImage(ModelDallE3Std, prompt, size, refImage) +// } +// +// func (lm *LLM) BestMakeImage(prompt, size, refImage string) ([]string, error) { +// return lm.MakeImage(ModelDallE3HD, prompt, size, refImage) +// } +// +// func (lm *LLM) MakeImage(model, prompt, size, refImage string) ([]string, error) { +func (lm *LLM) FastMakeImage(prompt string, config llm.GCConfig) ([]string, error) { + config.Model = ModelDallE3Std + return lm.MakeImage(prompt, config) +} + +func (lm *LLM) BestMakeImage(prompt string, config llm.GCConfig) ([]string, error) { + config.Model = ModelDallE3HD + return lm.MakeImage(prompt, config) +} + +func (lm *LLM) MakeImage(prompt string, config llm.GCConfig) ([]string, error) { + openaiConf := openai.DefaultConfig(lm.config.ApiKey) + if lm.config.Endpoint != "" { + openaiConf.BaseURL = lm.config.Endpoint + } + c := openai.NewClientWithConfig(openaiConf) + style := openai.CreateImageStyleVivid + if (!strings.Contains(prompt, "vivid") || !strings.Contains(prompt, "生动的")) && (strings.Contains(prompt, "natural") || strings.Contains(prompt, "自然的")) { + style = openai.CreateImageStyleNatural + } + quality := openai.CreateImageQualityStandard + if strings.HasSuffix(config.Model, "-hd") { + quality = openai.CreateImageQualityHD + config.Model = config.Model[0 : len(config.Model)-3] + } + r, err := c.CreateImage(context.Background(), openai.ImageRequest{ + Prompt: prompt, + Model: config.Model, + Quality: quality, + Size: config.Size, + Style: style, + ResponseFormat: openai.CreateImageResponseFormatURL, + }) + if err == nil { + results := make([]string, 0) + for _, item := range r.Data { + results = append(results, item.URL) + } + return results, nil + } else { + return nil, err + } +} + +func (lm *LLM) FastMakeVideo(prompt string, config llm.GCConfig) ([]string, []string, error) { + return lm.MakeVideo(prompt, config) +} + +func (lm *LLM) BestMakeVideo(prompt string, config llm.GCConfig) ([]string, []string, error) { + return lm.MakeVideo(prompt, config) +} + +func (lm *LLM) MakeVideo(prompt string, config llm.GCConfig) ([]string, []string, error) { + return nil, nil, nil +} diff --git a/llm/zhipu/chat.go b/llm/zhipu/chat.go new file mode 100644 index 0000000..51d9732 --- /dev/null +++ b/llm/zhipu/chat.go @@ -0,0 +1,137 @@ +package zhipu + +import ( + "apigo.cc/ai/ai/llm" + "apigo.cc/ai/ai/llm/zhipu/zhipu" + "context" + "strings" +) + +func (lm *LLM) FastAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) { + return lm.Ask(messages, llm.ChatConfig{ + Model: ModelGLM4Flash, + }, callback) +} + +func (lm *LLM) LongAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) { + return lm.Ask(messages, llm.ChatConfig{ + Model: ModelGLM4Long, + }, callback) +} + +func (lm *LLM) BatterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) { + return lm.Ask(messages, llm.ChatConfig{ + Model: ModelGLM4Plus, + }, callback) +} + +func (lm *LLM) BestAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) { + return lm.Ask(messages, llm.ChatConfig{ + Model: ModelGLM40520, + }, callback) +} + +func (lm *LLM) MultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) { + return lm.Ask(messages, llm.ChatConfig{ + Model: ModelGLM4VPlus, + }, callback) +} + +func (lm *LLM) BestMultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) { + return lm.Ask(messages, llm.ChatConfig{ + Model: ModelGLM4V, + }, callback) +} + +func (lm *LLM) CodeInterpreterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) { + return lm.Ask(messages, llm.ChatConfig{ + Model: ModelGLM4AllTools, + Tools: map[string]any{llm.ToolCodeInterpreter: nil}, + }, callback) +} + +func (lm *LLM) WebSearchAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.TokenUsage, error) { + return lm.Ask(messages, llm.ChatConfig{ + Model: ModelGLM4AllTools, + Tools: map[string]any{llm.ToolWebSearch: nil}, + }, callback) +} + +func (lm *LLM) Ask(messages []llm.ChatMessage, config llm.ChatConfig, callback func(answer string)) (string, llm.TokenUsage, error) { + config.SetDefault(&lm.config.ChatConfig) + c, err := zhipu.NewClient(zhipu.WithAPIKey(lm.config.ApiKey), zhipu.WithBaseURL(lm.config.Endpoint)) + if err != nil { + return "", llm.TokenUsage{}, err + } + + cc := c.ChatCompletion(config.GetModel()) + for _, msg := range messages { + var contents []zhipu.ChatCompletionMultiContent + if msg.Contents != nil { + contents = make([]zhipu.ChatCompletionMultiContent, len(msg.Contents)) + for j, inPart := range msg.Contents { + part := zhipu.ChatCompletionMultiContent{} + part.Type = NameMap[inPart.Type] + switch inPart.Type { + case llm.TypeText: + part.Text = inPart.Content + case llm.TypeImage: + part.ImageURL = &zhipu.URLItem{URL: inPart.Content} + case llm.TypeVideo: + part.VideoURL = &zhipu.URLItem{URL: inPart.Content} + } + contents[j] = part + } + } + cc.AddMessage(zhipu.ChatCompletionMultiMessage{ + Role: NameMap[msg.Role], + Content: contents, + }) + } + + for name := range config.GetTools() { + switch name { + case llm.ToolCodeInterpreter: + cc.AddTool(zhipu.ChatCompletionToolCodeInterpreter{}) + case llm.ToolWebSearch: + cc.AddTool(zhipu.ChatCompletionToolWebBrowser{}) + } + } + + if config.GetMaxTokens() != 0 { + cc.SetMaxTokens(config.GetMaxTokens()) + } + if config.GetTemperature() != 0 { + cc.SetTemperature(config.GetTemperature()) + } + if config.GetTopP() != 0 { + cc.SetTopP(config.GetTopP()) + } + if callback != nil { + cc.SetStreamHandler(func(r2 zhipu.ChatCompletionResponse) error { + if r2.Choices != nil { + for _, ch := range r2.Choices { + text := ch.Delta.Content + callback(text) + } + } + return nil + }) + } + + if r, err := cc.Do(context.Background()); err == nil { + results := make([]string, 0) + if r.Choices != nil { + for _, ch := range r.Choices { + results = append(results, ch.Message.Content) + } + } + return strings.Join(results, ""), llm.TokenUsage{ + AskTokens: r.Usage.PromptTokens, + AnswerTokens: r.Usage.CompletionTokens, + TotalTokens: r.Usage.TotalTokens, + }, nil + } else { + return "", llm.TokenUsage{}, err + } +} diff --git a/llm/zhipu/config.go b/llm/zhipu/config.go new file mode 100644 index 0000000..d6dfb50 --- /dev/null +++ b/llm/zhipu/config.go @@ -0,0 +1,60 @@ +package zhipu + +import ( + "apigo.cc/ai/ai/llm" + "apigo.cc/ai/ai/llm/zhipu/zhipu" +) + +type LLM struct { + config llm.Config +} + +var NameMap = map[string]string{ + llm.TypeText: zhipu.MultiContentTypeText, + llm.TypeImage: zhipu.MultiContentTypeImageURL, + llm.TypeVideo: zhipu.MultiContentTypeVideoURL, + llm.RoleSystem: zhipu.RoleSystem, + llm.RoleUser: zhipu.RoleUser, + llm.RoleAssistant: zhipu.RoleAssistant, + llm.RoleTool: zhipu.RoleTool, +} + +const ( + ModelGLM4Plus = "GLM-4-Plus" + ModelGLM40520 = "GLM-4-0520" + ModelGLM4Long = "GLM-4-Long" + ModelGLM4AirX = "GLM-4-AirX" + ModelGLM4Air = "GLM-4-Air" + ModelGLM4Flash = "GLM-4-Flash" + ModelGLM4AllTools = "GLM-4-AllTools" + ModelGLM4 = "GLM-4" + ModelGLM4VPlus = "GLM-4V-Plus" + ModelGLM4V = "GLM-4V" + ModelCogVideoX = "CogVideoX" + ModelCogView3Plus = "CogView-3-Plus" + ModelCogView3 = "CogView-3" + ModelEmbedding3 = "Embedding-3" + ModelEmbedding2 = "Embedding-2" + ModelCharGLM3 = "CharGLM-3" + ModelEmohaa = "Emohaa" + ModelCodeGeeX4 = "CodeGeeX-4" +) + +func (lm *LLM) Support() llm.Support { + return llm.Support{ + Ask: true, + AskWithImage: true, + AskWithVideo: true, + AskWithCodeInterpreter: true, + AskWithWebSearch: true, + MakeImage: true, + MakeVideo: true, + Models: []string{ModelGLM4Plus, ModelGLM40520, ModelGLM4Long, ModelGLM4AirX, ModelGLM4Air, ModelGLM4Flash, ModelGLM4AllTools, ModelGLM4, ModelGLM4VPlus, ModelGLM4V, ModelCogVideoX, ModelCogView3Plus, ModelCogView3, ModelEmbedding3, ModelEmbedding2, ModelCharGLM3, ModelEmohaa, ModelCodeGeeX4}, + } +} + +func init() { + llm.Register("zhipu", func(config llm.Config) llm.LLM { + return &LLM{config: config} + }) +} diff --git a/llm/zhipu/gc.go b/llm/zhipu/gc.go new file mode 100644 index 0000000..82f404a --- /dev/null +++ b/llm/zhipu/gc.go @@ -0,0 +1,88 @@ +package zhipu + +import ( + "apigo.cc/ai/ai/llm" + "apigo.cc/ai/ai/llm/zhipu/zhipu" + "context" + "errors" + "time" +) + +func (lm *LLM) FastMakeImage(prompt string, config llm.GCConfig) ([]string, error) { + config.Model = ModelCogView3Plus + return lm.MakeImage(prompt, config) +} + +func (lm *LLM) BestMakeImage(prompt string, config llm.GCConfig) ([]string, error) { + config.Model = ModelCogView3 + return lm.MakeImage(prompt, config) +} + +func (lm *LLM) MakeImage(prompt string, config llm.GCConfig) ([]string, error) { + c, err := zhipu.NewClient(zhipu.WithAPIKey(lm.config.ApiKey), zhipu.WithBaseURL(lm.config.Endpoint)) + if err != nil { + return nil, err + } + + cc := c.ImageGeneration(config.Model).SetPrompt(prompt) + if config.Size != "" { + cc.SetSize(config.Size) + } + + if r, err := cc.Do(context.Background()); err == nil { + results := make([]string, 0) + for _, item := range r.Data { + results = append(results, item.URL) + } + return results, nil + } else { + return nil, err + } +} + +func (lm *LLM) FastMakeVideo(prompt string, config llm.GCConfig) ([]string, []string, error) { + config.Model = ModelCogVideoX + return lm.MakeVideo(prompt, config) +} + +func (lm *LLM) BestMakeVideo(prompt string, config llm.GCConfig) ([]string, []string, error) { + config.Model = ModelCogVideoX + return lm.MakeVideo(prompt, config) +} + +func (lm *LLM) MakeVideo(prompt string, config llm.GCConfig) ([]string, []string, error) { + c, err := zhipu.NewClient(zhipu.WithAPIKey(lm.config.ApiKey), zhipu.WithBaseURL(lm.config.Endpoint)) + if err != nil { + return nil, nil, err + } + + cc := c.VideoGeneration(config.Model).SetPrompt(prompt) + if config.Ref != "" { + cc.SetImageURL(config.Ref) + } + + if resp, err := cc.Do(context.Background()); err == nil { + for i := 0; i < 1200; i++ { + r, err := c.AsyncResult(resp.ID).Do(context.Background()) + if err != nil { + return nil, nil, err + } + if r.TaskStatus == zhipu.VideoGenerationTaskStatusSuccess { + covers := make([]string, 0) + results := make([]string, 0) + for _, item := range r.VideoResult { + results = append(results, item.URL) + covers = append(covers, item.CoverImageURL) + } + return results, covers, nil + } + if r.TaskStatus == zhipu.VideoGenerationTaskStatusFail { + return nil, nil, errors.New("fail on task " + resp.ID) + } + time.Sleep(3 * time.Second) + } + return nil, nil, errors.New("timeout on task " + resp.ID) + } else { + return nil, nil, err + } +} diff --git a/llm/zhipu/zhipu/CHANGELOG.md b/llm/zhipu/zhipu/CHANGELOG.md new file mode 100644 index 0000000..8e1033f --- /dev/null +++ b/llm/zhipu/zhipu/CHANGELOG.md @@ -0,0 +1,108 @@ +# Changelog +All notable changes to this project will be documented in this file. See [conventional commits](https://www.conventionalcommits.org/) for commit guidelines. + +- - - +## v0.1.2 - 2024-08-15 +#### Bug Fixes +- add FinishReasonStopSequence - (01b4201) - GUO YANKE +#### Documentation +- update README.md [skip ci] - (e48a88b) - GUO YANKE +#### Features +- add videos/generations - (7261999) - GUO YANKE +#### Miscellaneous Chores +- relaxing go version to 1.18 - (6acc17c) - GUO YANKE + +- - - + +## v0.1.1 - 2024-07-17 +#### Documentation +- update README.md [skip ci] - (695432a) - GUO YANKE +#### Features +- add support for GLM-4-AllTools - (9627a36) - GUO YANKE + +- - - + +## v0.1.0 - 2024-06-28 +#### Bug Fixes +- rename client function for batch list - (40ac05f) - GUO YANKE +#### Documentation +- update README.md [skip ci] - (6ce5754) - GUO YANKE +#### Features +- add knowledge capacity service - (4ce62b3) - GUO YANKE +#### Refactoring +- update batch service - (b92d438) - GUO YANKE +- update chat completion service - (19dd77f) - GUO YANKE +- update embedding service - (c1bbc2d) - GUO YANKE +- update file services - (7ef4d87) - GUO YANKE +- update fine tune services, using APIError - (15aed88) - GUO YANKE +- update fine tune services - (664523b) - GUO YANKE +- update image generation service - (a18e028) - GUO YANKE +- update knowledge services - (c7bfb73) - GUO YANKE + +- - - + +## v0.0.6 - 2024-06-28 +#### Features +- add batch support for result reader - (c062095) - GUO YANKE +- add fine tune services - (f172f51) - GUO YANKE +- add knowledge service - (09792b5) - GUO YANKE + +- - - + +## v0.0.5 - 2024-06-28 +#### Bug Fixes +- api error parsing - (60a17f4) - GUO YANKE +#### Features +- add batch service - (389aec3) - GUO YANKE +- add batch support for chat completions, image generations and embeddings - (c017ffd) - GUO YANKE +- add file edit/get/delete service - (8a4d309) - GUO YANKE +- add file create serivce - (6d2140b) - GUO YANKE + +- - - + +## v0.0.4 - 2024-06-26 +#### Bug Fixes +- remove Client.R(), hide resty for future removal - (dc2a4ca) - GUO YANKE +#### Features +- add meta support for charglm - (fdd20e7) - GUO YANKE +- add client option to custom http client - (c62d6a9) - GUO YANKE + +- - - + +## v0.0.3 - 2024-06-26 +#### Features +- add image generation service - (9f3f54f) - GUO YANKE +- add support for vision models - (2dcd82a) - GUO YANKE +- add embedding service - (f57806a) - GUO YANKE + +- - - + +## v0.0.2 - 2024-06-26 +#### Bug Fixes +- **(deps)** update golang-jwt/jwt to v5 - (2f76a57) - GUO YANKE +#### Features +- add constants for roles - (3d08a72) - GUO YANKE + +- - - + +## v0.0.1 - 2024-06-26 +#### Bug Fixes +- add json tag "omitempty" to various types - (bf81097) - GUO YANKE +#### Continuous Integration +- add github action workflows for testing - (5a64987) - GUO YANKE +#### Documentation +- update README.md [skip ci] - (d504f57) - GUO YANKE +#### Features +- add chat completion in stream mode - (130fe1d) - GUO YANKE +- add chat completion in non-stream mode - (2326e37) - GUO YANKE +- support debug option while creating client - (0f104d8) - GUO YANKE +- add APIError and APIErrorResponse - (1886d85) - GUO YANKE +- add client struct - (710d8e8) - GUO YANKE +#### Refactoring +- change signature of Client#createJWT since there is no reason to fail - (f0d7887) - GUO YANKE +#### Tests +- add client_test.go - (a3fc217) - GUO YANKE + +- - - + +Changelog generated by [cocogitto](https://github.com/cocogitto/cocogitto). \ No newline at end of file diff --git a/llm/zhipu/zhipu/LICENSE b/llm/zhipu/zhipu/LICENSE new file mode 100644 index 0000000..67dc60b --- /dev/null +++ b/llm/zhipu/zhipu/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Yanke G. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/llm/zhipu/zhipu/README.md b/llm/zhipu/zhipu/README.md new file mode 100644 index 0000000..555e351 --- /dev/null +++ b/llm/zhipu/zhipu/README.md @@ -0,0 +1,280 @@ +# zhipu + +[![Go Reference](https://pkg.go.dev/badge/github.com/yankeguo/zhipu.svg)](https://pkg.go.dev/github.com/yankeguo/zhipu) +[![go](https://github.com/yankeguo/zhipu/actions/workflows/go.yml/badge.svg)](https://github.com/yankeguo/zhipu/actions/workflows/go.yml) + +[中文文档](README.zh.md) + +A 3rd-Party Golang Client Library for Zhipu AI Platform + +## Usage + +### Install the package + +```bash +go get -u github.com/yankeguo/zhipu +``` + +### Create a client + +```go +// this will use environment variables ZHIPUAI_API_KEY +client, err := zhipu.NewClient() +// or you can specify the API key +client, err = zhipu.NewClient(zhipu.WithAPIKey("your api key")) +``` + +### Use the client + +**ChatCompletion** + +```go +service := client.ChatCompletion("glm-4-flash"). + AddMessage(zhipu.ChatCompletionMessage{ + Role: "user", + Content: "你好", + }) + +res, err := service.Do(context.Background()) + +if err != nil { + zhipu.GetAPIErrorCode(err) // get the API error code +} else { + println(res.Choices[0].Message.Content) +} +``` + +**ChatCompletion (Stream)** + +```go +service := client.ChatCompletion("glm-4-flash"). + AddMessage(zhipu.ChatCompletionMessage{ + Role: "user", + Content: "你好", + }).SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error { + println(chunk.Choices[0].Delta.Content) + return nil + }) + +res, err := service.Do(context.Background()) + +if err != nil { + zhipu.GetAPIErrorCode(err) // get the API error code +} else { + // this package will combine the stream chunks and build a final result mimicking the non-streaming API + println(res.Choices[0].Message.Content) +} +``` + +**ChatCompletion (Stream with GLM-4-AllTools)** + +```go +// CodeInterpreter +s := client.ChatCompletion("GLM-4-AllTools") +s.AddMessage(zhipu.ChatCompletionMultiMessage{ + Role: "user", + Content: []zhipu.ChatCompletionMultiContent{ + { + Type: "text", + Text: "计算[5,10,20,700,99,310,978,100]的平均值和方差。", + }, + }, +}) +s.AddTool(zhipu.ChatCompletionToolCodeInterpreter{ + Sandbox: zhipu.Ptr(CodeInterpreterSandboxAuto), +}) +s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error { + for _, c := range chunk.Choices { + for _, tc := range c.Delta.ToolCalls { + if tc.Type == ToolTypeCodeInterpreter && tc.CodeInterpreter != nil { + if tc.CodeInterpreter.Input != "" { + // DO SOMETHING + } + if len(tc.CodeInterpreter.Outputs) > 0 { + // DO SOMETHING + } + } + } + } + return nil +}) + +// WebBrowser +// CAUTION: NOT 'WebSearch' +s := client.ChatCompletion("GLM-4-AllTools") +s.AddMessage(zhipu.ChatCompletionMultiMessage{ + Role: "user", + Content: []zhipu.ChatCompletionMultiContent{ + { + Type: "text", + Text: "搜索下本周深圳天气如何", + }, + }, +}) +s.AddTool(zhipu.ChatCompletionToolWebBrowser{}) +s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error { + for _, c := range chunk.Choices { + for _, tc := range c.Delta.ToolCalls { + if tc.Type == ToolTypeWebBrowser && tc.WebBrowser != nil { + if tc.WebBrowser.Input != "" { + // DO SOMETHING + } + if len(tc.WebBrowser.Outputs) > 0 { + // DO SOMETHING + } + } + } + } + return nil +}) +s.Do(context.Background()) + +// DrawingTool +s := client.ChatCompletion("GLM-4-AllTools") +s.AddMessage(zhipu.ChatCompletionMultiMessage{ + Role: "user", + Content: []zhipu.ChatCompletionMultiContent{ + { + Type: "text", + Text: "画一个正弦函数图像", + }, + }, +}) +s.AddTool(zhipu.ChatCompletionToolDrawingTool{}) +s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error { + for _, c := range chunk.Choices { + for _, tc := range c.Delta.ToolCalls { + if tc.Type == ToolTypeDrawingTool && tc.DrawingTool != nil { + if tc.DrawingTool.Input != "" { + // DO SOMETHING + } + if len(tc.DrawingTool.Outputs) > 0 { + // DO SOMETHING + } + } + } + } + return nil +}) +s.Do(context.Background()) +``` + +**Embedding** + +```go +service := client.Embedding("embedding-v2").SetInput("你好呀") +service.Do(context.Background()) +``` + +**Image Generation** + +```go +service := client.ImageGeneration("cogview-3").SetPrompt("一只可爱的小猫咪") +service.Do(context.Background()) +``` + +**Video Generation** + +```go +service := client.VideoGeneration("cogvideox").SetPrompt("一只可爱的小猫咪") +resp, err := service.Do(context.Background()) + +for { + result, err := client.AsyncResult(resp.ID).Do(context.Background()) + + if result.TaskStatus == zhipu.VideoGenerationTaskStatusSuccess { + _ = result.VideoResult[0].URL + _ = result.VideoResult[0].CoverImageURL + break + } + + if result.TaskStatus != zhipu.VideoGenerationTaskStatusProcessing { + break + } + + time.Sleep(5 * time.Second) +} +``` + +**Upload File (Retrieval)** + +```go +service := client.FileCreate(zhipu.FilePurposeRetrieval) +service.SetLocalFile(filepath.Join("testdata", "test-file.txt")) +service.SetKnowledgeID("your-knowledge-id") + +service.Do(context.Background()) +``` + +**Upload File (Fine-Tune)** + +```go +service := client.FileCreate(zhipu.FilePurposeFineTune) +service.SetLocalFile(filepath.Join("testdata", "test-file.jsonl")) +service.Do(context.Background()) +``` + +**Batch Create** + +```go +service := client.BatchCreate(). + SetInputFileID("fileid"). + SetCompletionWindow(zhipu.BatchCompletionWindow24h). + SetEndpoint(BatchEndpointV4ChatCompletions) +service.Do(context.Background()) +``` + +**Knowledge Base** + +```go +client.KnowledgeCreate("") +client.KnowledgeEdit("") +``` + +**Fine Tune** + +```go +client.FineTuneCreate("") +``` + +### Batch Support + +**Batch File Writer** + +```go +f, err := os.OpenFile("batch.jsonl", os.O_CREATE|os.O_WRONLY, 0644) + +bw := zhipu.NewBatchFileWriter(f) + +bw.Add("action_1", client.ChatCompletion("glm-4-flash"). + AddMessage(zhipu.ChatCompletionMessage{ + Role: "user", + Content: "你好", + })) +bw.Add("action_2", client.Embedding("embedding-v2").SetInput("你好呀")) +bw.Add("action_3", client.ImageGeneration("cogview-3").SetPrompt("一只可爱的小猫咪")) +``` + +**Batch Result Reader** + +```go +br := zhipu.NewBatchResultReader[zhipu.ChatCompletionResponse](r) + +for { + var res zhipu.BatchResult[zhipu.ChatCompletionResponse] + err := br.Read(&res) + if err != nil { + break + } +} +``` + +## Donation + +Executing unit tests will actually call the ChatGLM API and consume my quota. Please donate and thank you for your support! + + + +## Credits + +GUO YANKE, MIT License diff --git a/llm/zhipu/zhipu/README.zh.md b/llm/zhipu/zhipu/README.zh.md new file mode 100644 index 0000000..2563f05 --- /dev/null +++ b/llm/zhipu/zhipu/README.zh.md @@ -0,0 +1,278 @@ +# zhipu + +[![Go Reference](https://pkg.go.dev/badge/github.com/yankeguo/zhipu.svg)](https://pkg.go.dev/github.com/yankeguo/zhipu) +[![go](https://github.com/yankeguo/zhipu/actions/workflows/go.yml/badge.svg)](https://github.com/yankeguo/zhipu/actions/workflows/go.yml) + +Zhipu AI 平台第三方 Golang 客户端库 + +## 用法 + +### 安装库 + +```bash +go get -u github.com/yankeguo/zhipu +``` + +### 创建客户端 + +```go +// 默认使用环境变量 ZHIPUAI_API_KEY +client, err := zhipu.NewClient() +// 或者手动指定密钥 +client, err = zhipu.NewClient(zhipu.WithAPIKey("your api key")) +``` + +### 使用客户端 + +**ChatCompletion(大语言模型)** + +```go +service := client.ChatCompletion("glm-4-flash"). + AddMessage(zhipu.ChatCompletionMessage{ + Role: "user", + Content: "你好", + }) + +res, err := service.Do(context.Background()) + +if err != nil { + zhipu.GetAPIErrorCode(err) // get the API error code +} else { + println(res.Choices[0].Message.Content) +} +``` + +**ChatCompletion(流式调用大语言模型)** + +```go +service := client.ChatCompletion("glm-4-flash"). + AddMessage(zhipu.ChatCompletionMessage{ + Role: "user", + Content: "你好", + }).SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error { + println(chunk.Choices[0].Delta.Content) + return nil + }) + +res, err := service.Do(context.Background()) + +if err != nil { + zhipu.GetAPIErrorCode(err) // get the API error code +} else { + // this package will combine the stream chunks and build a final result mimicking the non-streaming API + println(res.Choices[0].Message.Content) +} +``` + +**ChatCompletion(流式调用大语言工具模型GLM-4-AllTools)** + +```go +// CodeInterpreter +s := client.ChatCompletion("GLM-4-AllTools") +s.AddMessage(zhipu.ChatCompletionMultiMessage{ + Role: "user", + Content: []zhipu.ChatCompletionMultiContent{ + { + Type: "text", + Text: "计算[5,10,20,700,99,310,978,100]的平均值和方差。", + }, + }, +}) +s.AddTool(zhipu.ChatCompletionToolCodeInterpreter{ + Sandbox: zhipu.Ptr(CodeInterpreterSandboxAuto), +}) +s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error { + for _, c := range chunk.Choices { + for _, tc := range c.Delta.ToolCalls { + if tc.Type == ToolTypeCodeInterpreter && tc.CodeInterpreter != nil { + if tc.CodeInterpreter.Input != "" { + // DO SOMETHING + } + if len(tc.CodeInterpreter.Outputs) > 0 { + // DO SOMETHING + } + } + } + } + return nil +}) + +// WebBrowser +// CAUTION: NOT 'WebSearch' +s := client.ChatCompletion("GLM-4-AllTools") +s.AddMessage(zhipu.ChatCompletionMultiMessage{ + Role: "user", + Content: []zhipu.ChatCompletionMultiContent{ + { + Type: "text", + Text: "搜索下本周深圳天气如何", + }, + }, +}) +s.AddTool(zhipu.ChatCompletionToolWebBrowser{}) +s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error { + for _, c := range chunk.Choices { + for _, tc := range c.Delta.ToolCalls { + if tc.Type == ToolTypeWebBrowser && tc.WebBrowser != nil { + if tc.WebBrowser.Input != "" { + // DO SOMETHING + } + if len(tc.WebBrowser.Outputs) > 0 { + // DO SOMETHING + } + } + } + } + return nil +}) +s.Do(context.Background()) + +// DrawingTool +s := client.ChatCompletion("GLM-4-AllTools") +s.AddMessage(zhipu.ChatCompletionMultiMessage{ + Role: "user", + Content: []zhipu.ChatCompletionMultiContent{ + { + Type: "text", + Text: "画一个正弦函数图像", + }, + }, +}) +s.AddTool(zhipu.ChatCompletionToolDrawingTool{}) +s.SetStreamHandler(func(chunk zhipu.ChatCompletionResponse) error { + for _, c := range chunk.Choices { + for _, tc := range c.Delta.ToolCalls { + if tc.Type == ToolTypeDrawingTool && tc.DrawingTool != nil { + if tc.DrawingTool.Input != "" { + // DO SOMETHING + } + if len(tc.DrawingTool.Outputs) > 0 { + // DO SOMETHING + } + } + } + } + return nil +}) +s.Do(context.Background()) +``` + +**Embedding** + +```go +service := client.Embedding("embedding-v2").SetInput("你好呀") +service.Do(context.Background()) +``` + +**ImageGeneration(图像生成)** + +```go +service := client.ImageGeneration("cogview-3").SetPrompt("一只可爱的小猫咪") +service.Do(context.Background()) +``` + +**VideoGeneration(视频生成)** + +```go +service := client.VideoGeneration("cogvideox").SetPrompt("一只可爱的小猫咪") +resp, err := service.Do(context.Background()) + +for { + result, err := client.AsyncResult(resp.ID).Do(context.Background()) + + if result.TaskStatus == zhipu.VideoGenerationTaskStatusSuccess { + _ = result.VideoResult[0].URL + _ = result.VideoResult[0].CoverImageURL + break + } + + if result.TaskStatus != zhipu.VideoGenerationTaskStatusProcessing { + break + } + + time.Sleep(5 * time.Second) +} +``` + +**UploadFile(上传文件用于取回)** + +```go +service := client.FileCreate(zhipu.FilePurposeRetrieval) +service.SetLocalFile(filepath.Join("testdata", "test-file.txt")) +service.SetKnowledgeID("your-knowledge-id") + +service.Do(context.Background()) +``` + +**UploadFile(上传文件用于微调)** + +```go +service := client.FileCreate(zhipu.FilePurposeFineTune) +service.SetLocalFile(filepath.Join("testdata", "test-file.jsonl")) +service.Do(context.Background()) +``` + +**BatchCreate(创建批量任务)** + +```go +service := client.BatchCreate(). + SetInputFileID("fileid"). + SetCompletionWindow(zhipu.BatchCompletionWindow24h). + SetEndpoint(BatchEndpointV4ChatCompletions) +service.Do(context.Background()) +``` + +**KnowledgeBase(知识库)** + +```go +client.KnowledgeCreate("") +client.KnowledgeEdit("") +``` + +**FineTune(微调)** + +```go +client.FineTuneCreate("") +``` + +### 批量任务辅助工具 + +**批量任务文件创建** + +```go +f, err := os.OpenFile("batch.jsonl", os.O_CREATE|os.O_WRONLY, 0644) + +bw := zhipu.NewBatchFileWriter(f) + +bw.Add("action_1", client.ChatCompletion("glm-4-flash"). + AddMessage(zhipu.ChatCompletionMessage{ + Role: "user", + Content: "你好", + })) +bw.Add("action_2", client.Embedding("embedding-v2").SetInput("你好呀")) +bw.Add("action_3", client.ImageGeneration("cogview-3").SetPrompt("一只可爱的小猫咪")) +``` + +**批量任务结果解析** + +```go +br := zhipu.NewBatchResultReader[zhipu.ChatCompletionResponse](r) + +for { + var res zhipu.BatchResult[zhipu.ChatCompletionResponse] + err := br.Read(&res) + if err != nil { + break + } +} +``` + +## 赞助 + +执行单元测试会真实调用GLM接口,消耗我充值的额度,开发不易,请微信扫码捐赠,感谢您的支持! + + + +## 许可证 + +GUO YANKE, MIT License diff --git a/llm/zhipu/zhipu/async_result.go b/llm/zhipu/zhipu/async_result.go new file mode 100644 index 0000000..5e51db1 --- /dev/null +++ b/llm/zhipu/zhipu/async_result.go @@ -0,0 +1,63 @@ +package zhipu + +import ( + "context" + + "github.com/go-resty/resty/v2" +) + +// AsyncResultService creates a new async result get service +type AsyncResultService struct { + client *Client + + id string +} + +// AsyncResultVideo is the video result of the AsyncResultService +type AsyncResultVideo struct { + URL string `json:"url"` + CoverImageURL string `json:"cover_image_url"` +} + +// AsyncResultResponse is the response of the AsyncResultService +type AsyncResultResponse struct { + Model string `json:"model"` + TaskStatus string `json:"task_status"` + RequestID string `json:"request_id"` + ID string `json:"id"` + VideoResult []AsyncResultVideo `json:"video_result"` +} + +// NewAsyncResultService creates a new async result get service +func NewAsyncResultService(client *Client) *AsyncResultService { + return &AsyncResultService{ + client: client, + } +} + +// SetID sets the id parameter +func (s *AsyncResultService) SetID(id string) *AsyncResultService { + s.id = id + return s +} + +func (s *AsyncResultService) Do(ctx context.Context) (res AsyncResultResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + if resp, err = s.client.request(ctx). + SetResult(&res). + SetError(&apiError). + Get("async-result/" + s.id); err != nil { + return + } + + if resp.IsError() { + err = apiError + return + } + + return +} diff --git a/llm/zhipu/zhipu/async_result_test.go b/llm/zhipu/zhipu/async_result_test.go new file mode 100644 index 0000000..f4a78b5 --- /dev/null +++ b/llm/zhipu/zhipu/async_result_test.go @@ -0,0 +1 @@ +package zhipu diff --git a/llm/zhipu/zhipu/batch.go b/llm/zhipu/zhipu/batch.go new file mode 100644 index 0000000..d661f56 --- /dev/null +++ b/llm/zhipu/zhipu/batch.go @@ -0,0 +1,258 @@ +package zhipu + +import ( + "context" + "encoding/json" + "strconv" + + "github.com/go-resty/resty/v2" +) + +const ( + BatchEndpointV4ChatCompletions = "/v4/chat/completions" + BatchEndpointV4ImagesGenerations = "/v4/images/generations" + BatchEndpointV4Embeddings = "/v4/embeddings" + BatchEndpointV4VideosGenerations = "/v4/videos/generations" + + BatchCompletionWindow24h = "24h" +) + +// BatchRequestCounts represents the counts of the batch requests. +type BatchRequestCounts struct { + Total int64 `json:"total"` + Completed int64 `json:"completed"` + Failed int64 `json:"failed"` +} + +// BatchItem represents a batch item. +type BatchItem struct { + ID string `json:"id"` + Object any `json:"object"` + Endpoint string `json:"endpoint"` + InputFileID string `json:"input_file_id"` + CompletionWindow string `json:"completion_window"` + Status string `json:"status"` + OutputFileID string `json:"output_file_id"` + ErrorFileID string `json:"error_file_id"` + CreatedAt int64 `json:"created_at"` + InProgressAt int64 `json:"in_progress_at"` + ExpiresAt int64 `json:"expires_at"` + FinalizingAt int64 `json:"finalizing_at"` + CompletedAt int64 `json:"completed_at"` + FailedAt int64 `json:"failed_at"` + ExpiredAt int64 `json:"expired_at"` + CancellingAt int64 `json:"cancelling_at"` + CancelledAt int64 `json:"cancelled_at"` + RequestCounts BatchRequestCounts `json:"request_counts"` + Metadata json.RawMessage `json:"metadata"` +} + +// BatchCreateService is a service to create a batch. +type BatchCreateService struct { + client *Client + + inputFileID string + endpoint string + completionWindow string + metadata any +} + +// NewBatchCreateService creates a new BatchCreateService. +func NewBatchCreateService(client *Client) *BatchCreateService { + return &BatchCreateService{client: client} +} + +// SetInputFileID sets the input file id for the batch. +func (s *BatchCreateService) SetInputFileID(inputFileID string) *BatchCreateService { + s.inputFileID = inputFileID + return s +} + +// SetEndpoint sets the endpoint for the batch. +func (s *BatchCreateService) SetEndpoint(endpoint string) *BatchCreateService { + s.endpoint = endpoint + return s +} + +// SetCompletionWindow sets the completion window for the batch. +func (s *BatchCreateService) SetCompletionWindow(window string) *BatchCreateService { + s.completionWindow = window + return s +} + +// SetMetadata sets the metadata for the batch. +func (s *BatchCreateService) SetMetadata(metadata any) *BatchCreateService { + s.metadata = metadata + return s +} + +// Do executes the batch create service. +func (s *BatchCreateService) Do(ctx context.Context) (res BatchItem, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + if resp, err = s.client.request(ctx). + SetBody(M{ + "input_file_id": s.inputFileID, + "endpoint": s.endpoint, + "completion_window": s.completionWindow, + "metadata": s.metadata, + }). + SetResult(&res). + SetError(&apiError). + Post("batches"); err != nil { + return + } + + if resp.IsError() { + err = apiError + } + + return +} + +// BatchGetService is a service to get a batch. +type BatchGetService struct { + client *Client + batchID string +} + +// BatchGetResponse represents the response of the batch get service. +type BatchGetResponse = BatchItem + +// NewBatchGetService creates a new BatchGetService. +func NewBatchGetService(client *Client) *BatchGetService { + return &BatchGetService{client: client} +} + +// SetBatchID sets the batch id for the batch get service. +func (s *BatchGetService) SetBatchID(batchID string) *BatchGetService { + s.batchID = batchID + return s +} + +// Do executes the batch get service. +func (s *BatchGetService) Do(ctx context.Context) (res BatchGetResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + if resp, err = s.client.request(ctx). + SetPathParam("batch_id", s.batchID). + SetResult(&res). + SetError(&apiError). + Get("batches/{batch_id}"); err != nil { + return + } + + if resp.IsError() { + err = apiError + } + + return +} + +// BatchCancelService is a service to cancel a batch. +type BatchCancelService struct { + client *Client + batchID string +} + +// NewBatchCancelService creates a new BatchCancelService. +func NewBatchCancelService(client *Client) *BatchCancelService { + return &BatchCancelService{client: client} +} + +// SetBatchID sets the batch id for the batch cancel service. +func (s *BatchCancelService) SetBatchID(batchID string) *BatchCancelService { + s.batchID = batchID + return s +} + +// Do executes the batch cancel service. +func (s *BatchCancelService) Do(ctx context.Context) (err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + if resp, err = s.client.request(ctx). + SetPathParam("batch_id", s.batchID). + SetBody(M{}). + SetError(&apiError). + Post("batches/{batch_id}/cancel"); err != nil { + return + } + + if resp.IsError() { + err = apiError + } + + return +} + +// BatchListService is a service to list batches. +type BatchListService struct { + client *Client + + after *string + limit *int +} + +// BatchListResponse represents the response of the batch list service. +type BatchListResponse struct { + Object string `json:"object"` + Data []BatchItem `json:"data"` + FirstID string `json:"first_id"` + LastID string `json:"last_id"` + HasMore bool `json:"has_more"` +} + +// NewBatchListService creates a new BatchListService. +func NewBatchListService(client *Client) *BatchListService { + return &BatchListService{client: client} +} + +// SetAfter sets the after cursor for the batch list service. +func (s *BatchListService) SetAfter(after string) *BatchListService { + s.after = &after + return s +} + +// SetLimit sets the limit for the batch list service. +func (s *BatchListService) SetLimit(limit int) *BatchListService { + s.limit = &limit + return s +} + +// Do executes the batch list service. +func (s *BatchListService) Do(ctx context.Context) (res BatchListResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + req := s.client.request(ctx) + if s.after != nil { + req.SetQueryParam("after", *s.after) + } + if s.limit != nil { + req.SetQueryParam("limit", strconv.Itoa(*s.limit)) + } + + if resp, err = req. + SetResult(&res). + SetError(&apiError). + Get("batches"); err != nil { + return + } + + if resp.IsError() { + err = apiError + } + + return +} diff --git a/llm/zhipu/zhipu/batch_support.go b/llm/zhipu/zhipu/batch_support.go new file mode 100644 index 0000000..9427ef4 --- /dev/null +++ b/llm/zhipu/zhipu/batch_support.go @@ -0,0 +1,63 @@ +package zhipu + +import ( + "encoding/json" + "io" +) + +// BatchSupport is the interface for services with batch support. +type BatchSupport interface { + BatchMethod() string + BatchURL() string + BatchBody() any +} + +// BatchFileWriter is a writer for batch files. +type BatchFileWriter struct { + w io.Writer + je *json.Encoder +} + +// NewBatchFileWriter creates a new BatchFileWriter. +func NewBatchFileWriter(w io.Writer) *BatchFileWriter { + return &BatchFileWriter{w: w, je: json.NewEncoder(w)} +} + +// Write writes a batch file. +func (b *BatchFileWriter) Write(customID string, s BatchSupport) error { + return b.je.Encode(M{ + "custom_id": customID, + "method": s.BatchMethod(), + "url": s.BatchURL(), + "body": s.BatchBody(), + }) +} + +// BatchResultResponse is the response of a batch result. +type BatchResultResponse[T any] struct { + StatusCode int `json:"status_code"` + Body T `json:"body"` +} + +// BatchResult is the result of a batch. +type BatchResult[T any] struct { + ID string `json:"id"` + CustomID string `json:"custom_id"` + Response BatchResultResponse[T] `json:"response"` +} + +// BatchResultReader reads batch results. +type BatchResultReader[T any] struct { + r io.Reader + jd *json.Decoder +} + +// NewBatchResultReader creates a new BatchResultReader. +func NewBatchResultReader[T any](r io.Reader) *BatchResultReader[T] { + return &BatchResultReader[T]{r: r, jd: json.NewDecoder(r)} +} + +// Read reads a batch result. +func (r *BatchResultReader[T]) Read(out *BatchResult[T]) error { + return r.jd.Decode(out) +} diff --git a/llm/zhipu/zhipu/batch_support_test.go b/llm/zhipu/zhipu/batch_support_test.go new file mode 100644 index 0000000..59a19c6 --- /dev/null +++ b/llm/zhipu/zhipu/batch_support_test.go @@ -0,0 +1,73 @@ +package zhipu + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBatchFileWriter(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + buf := &bytes.Buffer{} + + w := NewBatchFileWriter(buf) + err = w.Write("batch-1", client.ChatCompletion("a").AddMessage(ChatCompletionMessage{ + Role: "user", Content: "hello", + })) + require.NoError(t, err) + err = w.Write("batch-2", client.Embedding("c").SetInput("whoa")) + require.NoError(t, err) + err = w.Write("batch-3", client.ImageGeneration("d").SetPrompt("whoa")) + require.NoError(t, err) + + require.Equal(t, `{"body":{"messages":[{"role":"user","content":"hello"}],"model":"a"},"custom_id":"batch-1","method":"POST","url":"/v4/chat/completions"} +{"body":{"input":"whoa","model":"c"},"custom_id":"batch-2","method":"POST","url":"/v4/embeddings"} +{"body":{"model":"d","prompt":"whoa"},"custom_id":"batch-3","method":"POST","url":"/v4/images/generations"} +`, buf.String()) +} + +func TestBatchResultReader(t *testing.T) { + result := ` + {"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":26,"prompt_tokens":89,"total_tokens":115},"model":"glm-4","id":"8668357533850320547","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"订单处理慢\"\n}\n'''"}}],"request_id":"615-request-1"}},"custom_id":"request-1","id":"batch_1791490810192076800"} +{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":22,"prompt_tokens":94,"total_tokens":116},"model":"glm-4","id":"8668368425887509080","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"产品缺陷\"\n}\n'''"}}],"request_id":"616-request-2"}},"custom_id":"request-2","id":"batch_1791490810192076800"} +{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":25,"prompt_tokens":86,"total_tokens":111},"model":"glm-4","id":"8668355815863214980","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"正面\",\n \"特定问题标注\": \"性价比\"\n}\n'''"}}],"request_id":"617-request-3"}},"custom_id":"request-3","id":"batch_1791490810192076800"} +{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":28,"prompt_tokens":89,"total_tokens":117},"model":"glm-4","id":"8668355815863214981","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"说明文档不清晰\"\n}\n'''"}}],"request_id":"618-request-4"}},"custom_id":"request-4","id":"batch_1791490810192076800"} + +{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":26,"prompt_tokens":88,"total_tokens":114},"model":"glm-4","id":"8668357533850320546","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"中性\",\n \"特定问题标注\": \"价格问题\"\n}\n'''"}}],"request_id":"619-request-5"}},"custom_id":"request-5","id":"batch_1791490810192076800"} + +{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":26,"prompt_tokens":90,"total_tokens":116},"model":"glm-4","id":"8668356159460662846","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"配送延迟\"\n}\n'''"}}],"request_id":"620-request-6"}},"custom_id":"request-6","id":"batch_1791490810192076800"} + + +{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":27,"prompt_tokens":88,"total_tokens":115},"model":"glm-4","id":"8668357671289274638","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"产品描述不符\"\n}\n'''"}}],"request_id":"621-request-7"}},"custom_id":"request-7","id":"batch_1791490810192076800"} +{"response":{"status_code":200,"body":{"created":1715959702,"usage":{"completion_tokens":26,"prompt_tokens":87,"total_tokens":113},"model":"glm-4","id":"8668355644064514872","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"正面\",\n \"特定问题标注\": \"客服态度\"\n}\n'''"}}],"request_id":"622-request-8"}},"custom_id":"request-8","id":"batch_1791490810192076800"} + {"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":29,"prompt_tokens":90,"total_tokens":119},"model":"glm-4","id":"8668357671289274639","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"负面\",\n \"特定问题标注\": \"包装问题, 产品损坏\"\n}\n'''"}}],"request_id":"623-request-9"}},"custom_id":"request-9","id":"batch_1791490810192076800"} +{"response":{"status_code":200,"body":{"created":1715959701,"usage":{"completion_tokens":27,"prompt_tokens":87,"total_tokens":114},"model":"glm-4","id":"8668355644064514871","choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","content":"'''json\n{\n \"分类标签\": \"正面\",\n \"特定问题标注\": \"产品描述不符\"\n}\n'''"}}],"request_id":"624-request-10"}},"custom_id":"request-10","id":"batch_1791490810192076800"} +` + + brr := NewBatchResultReader[ChatCompletionResponse](bytes.NewReader([]byte(result))) + + var count int + + for { + var res BatchResult[ChatCompletionResponse] + + err := brr.Read(&res) + + if err != nil { + if err == io.EOF { + err = nil + } + require.Equal(t, 10, count) + require.NoError(t, err) + break + } + + require.Equal(t, 200, res.Response.StatusCode) + + count++ + } +} diff --git a/llm/zhipu/zhipu/batch_test.go b/llm/zhipu/zhipu/batch_test.go new file mode 100644 index 0000000..8ce5aa5 --- /dev/null +++ b/llm/zhipu/zhipu/batch_test.go @@ -0,0 +1,59 @@ +package zhipu + +import ( + "bytes" + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBatchServiceAll(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + buf := &bytes.Buffer{} + + bfw := NewBatchFileWriter(buf) + err = bfw.Write("batch_1", client.ChatCompletion("glm-4-flash").AddMessage(ChatCompletionMessage{ + Role: RoleUser, Content: "你好呀", + })) + require.NoError(t, err) + err = bfw.Write("batch_2", client.ChatCompletion("glm-4-flash").AddMessage(ChatCompletionMessage{ + Role: RoleUser, Content: "你叫什么名字", + })) + require.NoError(t, err) + + res, err := client.FileCreate(FilePurposeBatch).SetFile(bytes.NewReader(buf.Bytes()), "batch.jsonl").Do(context.Background()) + require.NoError(t, err) + + fileID := res.FileCreateFineTuneResponse.ID + require.NotEmpty(t, fileID) + + res1, err := client.BatchCreate(). + SetInputFileID(fileID). + SetCompletionWindow(BatchCompletionWindow24h). + SetEndpoint(BatchEndpointV4ChatCompletions).Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res1.ID) + + res2, err := client.BatchGet(res1.ID).Do(context.Background()) + require.NoError(t, err) + require.Equal(t, res2.ID, res1.ID) + + res3, err := client.BatchList().Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res3.Data) + + err = client.BatchCancel(res1.ID).Do(context.Background()) + require.NoError(t, err) +} + +func TestBatchListService(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + res, err := client.BatchList().Do(context.Background()) + require.NoError(t, err) + t.Log(res) +} diff --git a/llm/zhipu/zhipu/chat_completion.go b/llm/zhipu/zhipu/chat_completion.go new file mode 100644 index 0000000..b00db85 --- /dev/null +++ b/llm/zhipu/zhipu/chat_completion.go @@ -0,0 +1,577 @@ +package zhipu + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "io" + + "github.com/go-resty/resty/v2" +) + +const ( + RoleSystem = "system" + RoleUser = "user" + RoleAssistant = "assistant" + RoleTool = "tool" + + ToolChoiceAuto = "auto" + + FinishReasonStop = "stop" + FinishReasonStopSequence = "stop_sequence" + FinishReasonToolCalls = "tool_calls" + FinishReasonLength = "length" + FinishReasonSensitive = "sensitive" + FinishReasonNetworkError = "network_error" + + ToolTypeFunction = "function" + ToolTypeWebSearch = "web_search" + ToolTypeRetrieval = "retrieval" + + MultiContentTypeText = "text" + MultiContentTypeImageURL = "image_url" + MultiContentTypeVideoURL = "video_url" + + // New in GLM-4-AllTools + ToolTypeCodeInterpreter = "code_interpreter" + ToolTypeDrawingTool = "drawing_tool" + ToolTypeWebBrowser = "web_browser" + + CodeInterpreterSandboxNone = "none" + CodeInterpreterSandboxAuto = "auto" + + ChatCompletionStatusFailed = "failed" + ChatCompletionStatusCompleted = "completed" + ChatCompletionStatusRequiresAction = "requires_action" +) + +// ChatCompletionTool is the interface for chat completion tool +type ChatCompletionTool interface { + isChatCompletionTool() +} + +// ChatCompletionToolFunction is the function for chat completion +type ChatCompletionToolFunction struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters any `json:"parameters"` +} + +func (ChatCompletionToolFunction) isChatCompletionTool() {} + +// ChatCompletionToolRetrieval is the retrieval for chat completion +type ChatCompletionToolRetrieval struct { + KnowledgeID string `json:"knowledge_id"` + PromptTemplate string `json:"prompt_template,omitempty"` +} + +func (ChatCompletionToolRetrieval) isChatCompletionTool() {} + +// ChatCompletionToolWebSearch is the web search for chat completion +type ChatCompletionToolWebSearch struct { + Enable *bool `json:"enable,omitempty"` + SearchQuery string `json:"search_query,omitempty"` + SearchResult bool `json:"search_result,omitempty"` +} + +func (ChatCompletionToolWebSearch) isChatCompletionTool() {} + +// ChatCompletionToolCodeInterpreter is the code interpreter for chat completion +// only in GLM-4-AllTools +type ChatCompletionToolCodeInterpreter struct { + Sandbox *string `json:"sandbox,omitempty"` +} + +func (ChatCompletionToolCodeInterpreter) isChatCompletionTool() {} + +// ChatCompletionToolDrawingTool is the drawing tool for chat completion +// only in GLM-4-AllTools +type ChatCompletionToolDrawingTool struct { + // no fields +} + +func (ChatCompletionToolDrawingTool) isChatCompletionTool() {} + +// ChatCompletionToolWebBrowser is the web browser for chat completion +type ChatCompletionToolWebBrowser struct { + // no fields +} + +func (ChatCompletionToolWebBrowser) isChatCompletionTool() {} + +// ChatCompletionUsage is the usage for chat completion +type ChatCompletionUsage struct { + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + TotalTokens int64 `json:"total_tokens"` +} + +// ChatCompletionWebSearch is the web search result for chat completion +type ChatCompletionWebSearch struct { + Icon string `json:"icon"` + Title string `json:"title"` + Link string `json:"link"` + Media string `json:"media"` + Content string `json:"content"` +} + +// ChatCompletionToolCallFunction is the function for chat completion tool call +type ChatCompletionToolCallFunction struct { + Name string `json:"name"` + Arguments json.RawMessage `json:"arguments"` +} + +// ChatCompletionToolCallCodeInterpreterOutput is the output for chat completion tool call code interpreter +type ChatCompletionToolCallCodeInterpreterOutput struct { + Type string `json:"type"` + Logs string `json:"logs"` + File string `json:"file"` +} + +// ChatCompletionToolCallCodeInterpreter is the code interpreter for chat completion tool call +type ChatCompletionToolCallCodeInterpreter struct { + Input string `json:"input"` + Outputs []ChatCompletionToolCallCodeInterpreterOutput `json:"outputs"` +} + +// ChatCompletionToolCallDrawingToolOutput is the output for chat completion tool call drawing tool +type ChatCompletionToolCallDrawingToolOutput struct { + Image string `json:"image"` +} + +// ChatCompletionToolCallDrawingTool is the drawing tool for chat completion tool call +type ChatCompletionToolCallDrawingTool struct { + Input string `json:"input"` + Outputs []ChatCompletionToolCallDrawingToolOutput `json:"outputs"` +} + +// ChatCompletionToolCallWebBrowserOutput is the output for chat completion tool call web browser +type ChatCompletionToolCallWebBrowserOutput struct { + Title string `json:"title"` + Link string `json:"link"` + Content string `json:"content"` +} + +// ChatCompletionToolCallWebBrowser is the web browser for chat completion tool call +type ChatCompletionToolCallWebBrowser struct { + Input string `json:"input"` + Outputs []ChatCompletionToolCallWebBrowserOutput `json:"outputs"` +} + +// ChatCompletionToolCall is the tool call for chat completion +type ChatCompletionToolCall struct { + ID string `json:"id"` + Type string `json:"type"` + Function *ChatCompletionToolCallFunction `json:"function,omitempty"` + CodeInterpreter *ChatCompletionToolCallCodeInterpreter `json:"code_interpreter,omitempty"` + DrawingTool *ChatCompletionToolCallDrawingTool `json:"drawing_tool,omitempty"` + WebBrowser *ChatCompletionToolCallWebBrowser `json:"web_browser,omitempty"` +} + +type ChatCompletionMessageType interface { + isChatCompletionMessageType() +} + +// ChatCompletionMessage is the message for chat completion +type ChatCompletionMessage struct { + Role string `json:"role"` + Content string `json:"content,omitempty"` + ToolCalls []ChatCompletionToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` +} + +func (ChatCompletionMessage) isChatCompletionMessageType() {} + +type ChatCompletionMultiContent struct { + Type string `json:"type"` + Text string `json:"text"` + ImageURL *URLItem `json:"image_url,omitempty"` + VideoURL *URLItem `json:"video_url,omitempty"` +} + +// ChatCompletionMultiMessage is the multi message for chat completion +type ChatCompletionMultiMessage struct { + Role string `json:"role"` + Content []ChatCompletionMultiContent `json:"content"` +} + +func (ChatCompletionMultiMessage) isChatCompletionMessageType() {} + +// ChatCompletionMeta is the meta for chat completion +type ChatCompletionMeta struct { + UserInfo string `json:"user_info"` + BotInfo string `json:"bot_info"` + UserName string `json:"user_name"` + BotName string `json:"bot_name"` +} + +// ChatCompletionChoice is the choice for chat completion +type ChatCompletionChoice struct { + Index int `json:"index"` + FinishReason string `json:"finish_reason"` + Delta ChatCompletionMessage `json:"delta"` // stream mode + Message ChatCompletionMessage `json:"message"` // non-stream mode +} + +// ChatCompletionResponse is the response for chat completion +type ChatCompletionResponse struct { + ID string `json:"id"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionChoice `json:"choices"` + Usage ChatCompletionUsage `json:"usage"` + WebSearch []ChatCompletionWebSearch `json:"web_search"` + // Status is the status of the chat completion, only in GLM-4-AllTools + Status string `json:"status"` +} + +// ChatCompletionStreamHandler is the handler for chat completion stream +type ChatCompletionStreamHandler func(chunk ChatCompletionResponse) error + +var ( + chatCompletionStreamPrefix = []byte("data:") + chatCompletionStreamDone = []byte("[DONE]") +) + +// chatCompletionReduceResponse reduce the chunk to the response +func chatCompletionReduceResponse(out *ChatCompletionResponse, chunk ChatCompletionResponse) { + if len(out.Choices) == 0 { + out.Choices = append(out.Choices, ChatCompletionChoice{}) + } + + // basic + out.ID = chunk.ID + out.Created = chunk.Created + out.Model = chunk.Model + + // choices + if len(chunk.Choices) != 0 { + oc := &out.Choices[0] + cc := chunk.Choices[0] + + oc.Index = cc.Index + if cc.Delta.Role != "" { + oc.Message.Role = cc.Delta.Role + } + oc.Message.Content += cc.Delta.Content + oc.Message.ToolCalls = append(oc.Message.ToolCalls, cc.Delta.ToolCalls...) + if cc.FinishReason != "" { + oc.FinishReason = cc.FinishReason + } + } + + // usage + if chunk.Usage.CompletionTokens != 0 { + out.Usage.CompletionTokens = chunk.Usage.CompletionTokens + } + if chunk.Usage.PromptTokens != 0 { + out.Usage.PromptTokens = chunk.Usage.PromptTokens + } + if chunk.Usage.TotalTokens != 0 { + out.Usage.TotalTokens = chunk.Usage.TotalTokens + } + + // web search + out.WebSearch = append(out.WebSearch, chunk.WebSearch...) +} + +// chatCompletionDecodeStream decode the sse stream of chat completion +func chatCompletionDecodeStream(r io.Reader, fn func(chunk ChatCompletionResponse) error) (err error) { + br := bufio.NewReader(r) + + for { + var line []byte + + if line, err = br.ReadBytes('\n'); err != nil { + if errors.Is(err, io.EOF) { + err = nil + } + break + } + + line = bytes.TrimSpace(line) + + if len(line) == 0 { + continue + } + + if !bytes.HasPrefix(line, chatCompletionStreamPrefix) { + continue + } + + data := bytes.TrimSpace(line[len(chatCompletionStreamPrefix):]) + + if bytes.Equal(data, chatCompletionStreamDone) { + break + } + + if len(data) == 0 { + continue + } + + var chunk ChatCompletionResponse + if err = json.Unmarshal(data, &chunk); err != nil { + return + } + if err = fn(chunk); err != nil { + return + } + } + + return +} + +// ChatCompletionStreamService is the service for chat completion stream +type ChatCompletionService struct { + client *Client + + model string + requestID *string + doSample *bool + temperature *float64 + topP *float64 + maxTokens *int + stop []string + toolChoice *string + userID *string + meta *ChatCompletionMeta + + messages []any + tools []any + + streamHandler ChatCompletionStreamHandler +} + +var ( + _ BatchSupport = &ChatCompletionService{} +) + +// NewChatCompletionService creates a new ChatCompletionService. +func NewChatCompletionService(client *Client) *ChatCompletionService { + return &ChatCompletionService{ + client: client, + } +} + +func (s *ChatCompletionService) BatchMethod() string { + return "POST" +} + +func (s *ChatCompletionService) BatchURL() string { + return BatchEndpointV4ChatCompletions +} + +func (s *ChatCompletionService) BatchBody() any { + return s.buildBody() +} + +// SetModel set the model of the chat completion +func (s *ChatCompletionService) SetModel(model string) *ChatCompletionService { + s.model = model + return s +} + +// SetMeta set the meta of the chat completion, optional +func (s *ChatCompletionService) SetMeta(meta ChatCompletionMeta) *ChatCompletionService { + s.meta = &meta + return s +} + +// SetRequestID set the request id of the chat completion, optional +func (s *ChatCompletionService) SetRequestID(requestID string) *ChatCompletionService { + s.requestID = &requestID + return s +} + +// SetTemperature set the temperature of the chat completion, optional +func (s *ChatCompletionService) SetDoSample(doSample bool) *ChatCompletionService { + s.doSample = &doSample + return s +} + +// SetTemperature set the temperature of the chat completion, optional +func (s *ChatCompletionService) SetTemperature(temperature float64) *ChatCompletionService { + s.temperature = &temperature + return s +} + +// SetTopP set the top p of the chat completion, optional +func (s *ChatCompletionService) SetTopP(topP float64) *ChatCompletionService { + s.topP = &topP + return s +} + +// SetMaxTokens set the max tokens of the chat completion, optional +func (s *ChatCompletionService) SetMaxTokens(maxTokens int) *ChatCompletionService { + s.maxTokens = &maxTokens + return s +} + +// SetStop set the stop of the chat completion, optional +func (s *ChatCompletionService) SetStop(stop ...string) *ChatCompletionService { + s.stop = stop + return s +} + +// SetToolChoice set the tool choice of the chat completion, optional +func (s *ChatCompletionService) SetToolChoice(toolChoice string) *ChatCompletionService { + s.toolChoice = &toolChoice + return s +} + +// SetUserID set the user id of the chat completion, optional +func (s *ChatCompletionService) SetUserID(userID string) *ChatCompletionService { + s.userID = &userID + return s +} + +// SetStreamHandler set the stream handler of the chat completion, optional +// this will enable the stream mode +func (s *ChatCompletionService) SetStreamHandler(handler ChatCompletionStreamHandler) *ChatCompletionService { + s.streamHandler = handler + return s +} + +// AddMessage add the message to the chat completion +func (s *ChatCompletionService) AddMessage(messages ...ChatCompletionMessageType) *ChatCompletionService { + for _, message := range messages { + s.messages = append(s.messages, message) + } + return s +} + +// AddFunction add the function to the chat completion +func (s *ChatCompletionService) AddTool(tools ...ChatCompletionTool) *ChatCompletionService { + for _, tool := range tools { + switch tool := tool.(type) { + case ChatCompletionToolFunction: + s.tools = append(s.tools, map[string]any{ + "type": ToolTypeFunction, + ToolTypeFunction: tool, + }) + case ChatCompletionToolRetrieval: + s.tools = append(s.tools, map[string]any{ + "type": ToolTypeRetrieval, + ToolTypeRetrieval: tool, + }) + case ChatCompletionToolWebSearch: + s.tools = append(s.tools, map[string]any{ + "type": ToolTypeWebSearch, + ToolTypeWebSearch: tool, + }) + case ChatCompletionToolCodeInterpreter: + s.tools = append(s.tools, map[string]any{ + "type": ToolTypeCodeInterpreter, + ToolTypeCodeInterpreter: tool, + }) + case ChatCompletionToolDrawingTool: + s.tools = append(s.tools, map[string]any{ + "type": ToolTypeDrawingTool, + ToolTypeDrawingTool: tool, + }) + case ChatCompletionToolWebBrowser: + s.tools = append(s.tools, map[string]any{ + "type": ToolTypeWebBrowser, + ToolTypeWebBrowser: tool, + }) + } + } + return s +} + +func (s *ChatCompletionService) buildBody() M { + body := map[string]any{ + "model": s.model, + "messages": s.messages, + } + if s.requestID != nil { + body["request_id"] = *s.requestID + } + if s.doSample != nil { + body["do_sample"] = *s.doSample + } + if s.temperature != nil { + body["temperature"] = *s.temperature + } + if s.topP != nil { + body["top_p"] = *s.topP + } + if s.maxTokens != nil { + body["max_tokens"] = *s.maxTokens + } + if len(s.stop) != 0 { + body["stop"] = s.stop + } + if len(s.tools) != 0 { + body["tools"] = s.tools + } + if s.toolChoice != nil { + body["tool_choice"] = *s.toolChoice + } + if s.userID != nil { + body["user_id"] = *s.userID + } + if s.meta != nil { + body["meta"] = s.meta + } + return body +} + +// Do send the request of the chat completion and return the response +func (s *ChatCompletionService) Do(ctx context.Context) (res ChatCompletionResponse, err error) { + body := s.buildBody() + + streamHandler := s.streamHandler + + if streamHandler == nil { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + //fmt.Println(u.BMagenta(u.JsonP(body)), 111) + if resp, err = s.client.request(ctx).SetBody(body).SetResult(&res).SetError(&apiError).Post("chat/completions"); err != nil { + //fmt.Println(u.BRed(err.Error()), 2221) + return + } + if resp.IsError() { + err = apiError + //fmt.Println(u.BRed(err.Error()), 2222) + return + } + //fmt.Println(u.BGreen(u.JsonP(resp.Result())), resp.Status(), resp.Status(), 333) + return + } + + // stream mode + + body["stream"] = true + + var resp *resty.Response + + if resp, err = s.client.request(ctx).SetBody(body).SetDoNotParseResponse(true).Post("chat/completions"); err != nil { + return + } + defer resp.RawBody().Close() + + if resp.IsError() { + err = errors.New(resp.Status()) + return + } + + var choice ChatCompletionChoice + + if err = chatCompletionDecodeStream(resp.RawBody(), func(chunk ChatCompletionResponse) error { + // reduce the chunk to the response + chatCompletionReduceResponse(&res, chunk) + // invoke the stream handler + return streamHandler(chunk) + }); err != nil { + return + } + + res.Choices = append(res.Choices, choice) + + return +} diff --git a/llm/zhipu/zhipu/chat_completion_test.go b/llm/zhipu/zhipu/chat_completion_test.go new file mode 100644 index 0000000..8839850 --- /dev/null +++ b/llm/zhipu/zhipu/chat_completion_test.go @@ -0,0 +1,251 @@ +package zhipu + +import ( + "context" + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestChatCompletionService(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.ChatCompletion("glm-4-flash") + s.AddMessage(ChatCompletionMessage{ + Role: RoleUser, + Content: "你好呀", + }) + res, err := s.Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res.Choices) + choice := res.Choices[0] + require.Equal(t, FinishReasonStop, choice.FinishReason) + require.NotEmpty(t, choice.Message.Content) +} + +func TestChatCompletionServiceCharGLM(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.ChatCompletion("charglm-3") + s.SetMeta( + ChatCompletionMeta{ + UserName: "啵酱", + UserInfo: "啵酱是小少爷", + BotName: "塞巴斯酱", + BotInfo: "塞巴斯酱是一个冷酷的恶魔管家", + }, + ).AddMessage(ChatCompletionMessage{ + Role: RoleUser, + Content: "早上好", + }) + res, err := s.Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res.Choices) + choice := res.Choices[0] + require.Contains(t, []string{FinishReasonStop, FinishReasonStopSequence}, choice.FinishReason) + require.NotEmpty(t, choice.Message.Content) +} + +func TestChatCompletionServiceAllToolsCodeInterpreter(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.ChatCompletion("GLM-4-AllTools") + s.AddMessage(ChatCompletionMultiMessage{ + Role: "user", + Content: []ChatCompletionMultiContent{ + { + Type: "text", + Text: "计算[5,10,20,700,99,310,978,100]的平均值和方差。", + }, + }, + }) + s.AddTool(ChatCompletionToolCodeInterpreter{ + Sandbox: Ptr(CodeInterpreterSandboxAuto), + }) + + foundInterpreterInput := false + foundInterpreterOutput := false + + s.SetStreamHandler(func(chunk ChatCompletionResponse) error { + for _, c := range chunk.Choices { + for _, tc := range c.Delta.ToolCalls { + if tc.Type == ToolTypeCodeInterpreter && tc.CodeInterpreter != nil { + if tc.CodeInterpreter.Input != "" { + foundInterpreterInput = true + } + if len(tc.CodeInterpreter.Outputs) > 0 { + foundInterpreterOutput = true + } + } + } + } + buf, _ := json.MarshalIndent(chunk, "", " ") + t.Log(string(buf)) + return nil + }) + + res, err := s.Do(context.Background()) + require.True(t, foundInterpreterInput) + require.True(t, foundInterpreterOutput) + require.NotNil(t, res) + require.NoError(t, err) +} + +func TestChatCompletionServiceAllToolsDrawingTool(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.ChatCompletion("GLM-4-AllTools") + s.AddMessage(ChatCompletionMultiMessage{ + Role: "user", + Content: []ChatCompletionMultiContent{ + { + Type: "text", + Text: "画一个正弦函数图像", + }, + }, + }) + s.AddTool(ChatCompletionToolDrawingTool{}) + + foundInput := false + foundOutput := false + outputImage := "" + + s.SetStreamHandler(func(chunk ChatCompletionResponse) error { + for _, c := range chunk.Choices { + for _, tc := range c.Delta.ToolCalls { + if tc.Type == ToolTypeDrawingTool && tc.DrawingTool != nil { + if tc.DrawingTool.Input != "" { + foundInput = true + } + if len(tc.DrawingTool.Outputs) > 0 { + foundOutput = true + } + for _, output := range tc.DrawingTool.Outputs { + if output.Image != "" { + outputImage = output.Image + } + } + } + } + } + buf, _ := json.MarshalIndent(chunk, "", " ") + t.Log(string(buf)) + return nil + }) + + res, err := s.Do(context.Background()) + require.True(t, foundInput) + require.True(t, foundOutput) + require.NotEmpty(t, outputImage) + t.Log(outputImage) + require.NotNil(t, res) + require.NoError(t, err) +} + +func TestChatCompletionServiceAllToolsWebBrowser(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.ChatCompletion("GLM-4-AllTools") + s.AddMessage(ChatCompletionMultiMessage{ + Role: "user", + Content: []ChatCompletionMultiContent{ + { + Type: "text", + Text: "搜索下本周深圳天气如何", + }, + }, + }) + s.AddTool(ChatCompletionToolWebBrowser{}) + + foundInput := false + foundOutput := false + outputContent := "" + + s.SetStreamHandler(func(chunk ChatCompletionResponse) error { + for _, c := range chunk.Choices { + for _, tc := range c.Delta.ToolCalls { + if tc.Type == ToolTypeWebBrowser && tc.WebBrowser != nil { + if tc.WebBrowser.Input != "" { + foundInput = true + } + if len(tc.WebBrowser.Outputs) > 0 { + foundOutput = true + } + for _, output := range tc.WebBrowser.Outputs { + if output.Content != "" { + outputContent = output.Content + } + } + } + } + } + buf, _ := json.MarshalIndent(chunk, "", " ") + t.Log(string(buf)) + return nil + }) + + res, err := s.Do(context.Background()) + require.True(t, foundInput) + require.True(t, foundOutput) + require.NotEmpty(t, outputContent) + t.Log(outputContent) + require.NotNil(t, res) + require.NoError(t, err) +} + +func TestChatCompletionServiceStream(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + var content string + + s := client.ChatCompletion("glm-4-flash").AddMessage(ChatCompletionMessage{ + Role: RoleUser, + Content: "你好呀", + }).SetStreamHandler(func(chunk ChatCompletionResponse) error { + content += chunk.Choices[0].Delta.Content + return nil + }) + res, err := s.Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res.Choices) + choice := res.Choices[0] + require.Equal(t, FinishReasonStop, choice.FinishReason) + require.NotEmpty(t, choice.Message.Content) + require.Equal(t, content, choice.Message.Content) +} + +func TestChatCompletionServiceVision(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.ChatCompletion("glm-4v") + s.AddMessage(ChatCompletionMultiMessage{ + Role: RoleUser, + Content: []ChatCompletionMultiContent{ + { + Type: MultiContentTypeText, + Text: "图里有什么", + }, + { + Type: MultiContentTypeImageURL, + ImageURL: &URLItem{ + URL: "https://img1.baidu.com/it/u=1369931113,3388870256&fm=253&app=138&size=w931&n=0&f=JPEG&fmt=auto?sec=1703696400&t=f3028c7a1dca43a080aeb8239f09cc2f", + }, + }, + }, + }) + res, err := s.Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res.Choices) + require.NotZero(t, res.Usage.CompletionTokens) + choice := res.Choices[0] + require.Equal(t, FinishReasonStop, choice.FinishReason) + require.NotEmpty(t, choice.Message.Content) +} diff --git a/llm/zhipu/zhipu/client.go b/llm/zhipu/zhipu/client.go new file mode 100644 index 0000000..5f2aa9b --- /dev/null +++ b/llm/zhipu/zhipu/client.go @@ -0,0 +1,291 @@ +package zhipu + +import ( + "context" + "errors" + "net/http" + "os" + "strconv" + "strings" + "time" + + "github.com/go-resty/resty/v2" + "github.com/golang-jwt/jwt/v5" +) + +const ( + envAPIKey = "ZHIPUAI_API_KEY" + envBaseURL = "ZHIPUAI_BASE_URL" + envDebug = "ZHIPUAI_DEBUG" + + defaultBaseURL = "https://open.bigmodel.cn/api/paas/v4" +) + +var ( + // ErrAPIKeyMissing is the error when the api key is missing + ErrAPIKeyMissing = errors.New("zhipu: api key is missing") + // ErrAPIKeyMalformed is the error when the api key is malformed + ErrAPIKeyMalformed = errors.New("zhipu: api key is malformed") +) + +type clientOptions struct { + baseURL string + apiKey string + client *http.Client + debug *bool +} + +// ClientOption is a function that configures the client +type ClientOption func(opts *clientOptions) + +// WithAPIKey set the api key of the client +func WithAPIKey(apiKey string) ClientOption { + return func(opts *clientOptions) { + opts.apiKey = apiKey + } +} + +// WithBaseURL set the base url of the client +func WithBaseURL(baseURL string) ClientOption { + return func(opts *clientOptions) { + opts.baseURL = baseURL + } +} + +// WithHTTPClient set the http client of the client +func WithHTTPClient(client *http.Client) ClientOption { + return func(opts *clientOptions) { + opts.client = client + } +} + +// WithDebug set the debug mode of the client +func WithDebug(debug bool) ClientOption { + return func(opts *clientOptions) { + opts.debug = new(bool) + *opts.debug = debug + } +} + +// Client is the client for zhipu ai platform +type Client struct { + client *resty.Client + debug bool + keyID string + keySecret []byte +} + +func (c *Client) createJWT() string { + timestamp := time.Now().UnixMilli() + exp := timestamp + time.Hour.Milliseconds()*24*7 + + t := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "api_key": c.keyID, + "timestamp": timestamp, + "exp": exp, + }) + t.Header = map[string]interface{}{ + "alg": "HS256", + "sign_type": "SIGN", + } + + token, err := t.SignedString(c.keySecret) + if err != nil { + panic(err) + } + return token +} + +// request creates a new resty request with the jwt token and context +func (c *Client) request(ctx context.Context) *resty.Request { + return c.client.R().SetContext(ctx).SetHeader("Authorization", c.createJWT()) +} + +// NewClient creates a new client +// It will read the api key from the environment variable ZHIPUAI_API_KEY +// It will read the base url from the environment variable ZHIPUAI_BASE_URL +func NewClient(optFns ...ClientOption) (client *Client, err error) { + var opts clientOptions + for _, optFn := range optFns { + optFn(&opts) + } + // base url + if opts.baseURL == "" { + opts.baseURL = strings.TrimSpace(os.Getenv(envBaseURL)) + } + if opts.baseURL == "" { + opts.baseURL = defaultBaseURL + } + // api key + if opts.apiKey == "" { + opts.apiKey = strings.TrimSpace(os.Getenv(envAPIKey)) + } + if opts.apiKey == "" { + err = ErrAPIKeyMissing + return + } + // debug + if opts.debug == nil { + if debugStr := strings.TrimSpace(os.Getenv(envDebug)); debugStr != "" { + if debug, err1 := strconv.ParseBool(debugStr); err1 == nil { + opts.debug = &debug + } + } + } + + keyComponents := strings.SplitN(opts.apiKey, ".", 2) + + if len(keyComponents) != 2 { + err = ErrAPIKeyMalformed + return + } + + client = &Client{ + keyID: keyComponents[0], + keySecret: []byte(keyComponents[1]), + } + + if opts.client == nil { + client.client = resty.New() + } else { + client.client = resty.NewWithClient(opts.client) + } + + client.client = client.client.SetBaseURL(opts.baseURL) + + if opts.debug != nil { + client.client.SetDebug(*opts.debug) + client.debug = *opts.debug + } + return +} + +// BatchCreate creates a new BatchCreateService. +func (c *Client) BatchCreate() *BatchCreateService { + return NewBatchCreateService(c) +} + +// BatchGet creates a new BatchGetService. +func (c *Client) BatchGet(batchID string) *BatchGetService { + return NewBatchGetService(c).SetBatchID(batchID) +} + +// BatchCancel creates a new BatchCancelService. +func (c *Client) BatchCancel(batchID string) *BatchCancelService { + return NewBatchCancelService(c).SetBatchID(batchID) +} + +// BatchList creates a new BatchListService. +func (c *Client) BatchList() *BatchListService { + return NewBatchListService(c) +} + +// ChatCompletion creates a new ChatCompletionService. +func (c *Client) ChatCompletion(model string) *ChatCompletionService { + return NewChatCompletionService(c).SetModel(model) +} + +// Embedding embeds a list of text into a vector space. +func (c *Client) Embedding(model string) *EmbeddingService { + return NewEmbeddingService(c).SetModel(model) +} + +// FileCreate creates a new FileCreateService. +func (c *Client) FileCreate(purpose string) *FileCreateService { + return NewFileCreateService(c).SetPurpose(purpose) +} + +// FileEditService creates a new FileEditService. +func (c *Client) FileEdit(documentID string) *FileEditService { + return NewFileEditService(c).SetDocumentID(documentID) +} + +// FileList creates a new FileListService. +func (c *Client) FileList(purpose string) *FileListService { + return NewFileListService(c).SetPurpose(purpose) +} + +// FileDeleteService creates a new FileDeleteService. +func (c *Client) FileDelete(documentID string) *FileDeleteService { + return NewFileDeleteService(c).SetDocumentID(documentID) +} + +// FileGetService creates a new FileGetService. +func (c *Client) FileGet(documentID string) *FileGetService { + return NewFileGetService(c).SetDocumentID(documentID) +} + +// FileDownload creates a new FileDownloadService. +func (c *Client) FileDownload(fileID string) *FileDownloadService { + return NewFileDownloadService(c).SetFileID(fileID) +} + +// FineTuneCreate creates a new fine tune create service +func (c *Client) FineTuneCreate(model string) *FineTuneCreateService { + return NewFineTuneCreateService(c).SetModel(model) +} + +// FineTuneEventList creates a new fine tune event list service +func (c *Client) FineTuneEventList(jobID string) *FineTuneEventListService { + return NewFineTuneEventListService(c).SetJobID(jobID) +} + +// FineTuneGet creates a new fine tune get service +func (c *Client) FineTuneGet(jobID string) *FineTuneGetService { + return NewFineTuneGetService(c).SetJobID(jobID) +} + +// FineTuneList creates a new fine tune list service +func (c *Client) FineTuneList() *FineTuneListService { + return NewFineTuneListService(c) +} + +// FineTuneDelete creates a new fine tune delete service +func (c *Client) FineTuneDelete(jobID string) *FineTuneDeleteService { + return NewFineTuneDeleteService(c).SetJobID(jobID) +} + +// FineTuneCancel creates a new fine tune cancel service +func (c *Client) FineTuneCancel(jobID string) *FineTuneCancelService { + return NewFineTuneCancelService(c).SetJobID(jobID) +} + +// ImageGeneration creates a new image generation service +func (c *Client) ImageGeneration(model string) *ImageGenerationService { + return NewImageGenerationService(c).SetModel(model) +} + +// KnowledgeCreate creates a new knowledge create service +func (c *Client) KnowledgeCreate() *KnowledgeCreateService { + return NewKnowledgeCreateService(c) +} + +// KnowledgeEdit creates a new knowledge edit service +func (c *Client) KnowledgeEdit(knowledgeID string) *KnowledgeEditService { + return NewKnowledgeEditService(c).SetKnowledgeID(knowledgeID) +} + +// KnowledgeList list all the knowledge +func (c *Client) KnowledgeList() *KnowledgeListService { + return NewKnowledgeListService(c) +} + +// KnowledgeDelete creates a new knowledge delete service +func (c *Client) KnowledgeDelete(knowledgeID string) *KnowledgeDeleteService { + return NewKnowledgeDeleteService(c).SetKnowledgeID(knowledgeID) +} + +// KnowledgeGet creates a new knowledge get service +func (c *Client) KnowledgeCapacity() *KnowledgeCapacityService { + return NewKnowledgeCapacityService(c) +} + +// VideoGeneration creates a new video generation service +func (c *Client) VideoGeneration(model string) *VideoGenerationService { + return NewVideoGenerationService(c).SetModel(model) +} + +// AsyncResult creates a new async result get service +func (c *Client) AsyncResult(id string) *AsyncResultService { + return NewAsyncResultService(c).SetID(id) +} diff --git a/llm/zhipu/zhipu/client_test.go b/llm/zhipu/zhipu/client_test.go new file mode 100644 index 0000000..dd4800d --- /dev/null +++ b/llm/zhipu/zhipu/client_test.go @@ -0,0 +1,17 @@ +package zhipu + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestClientR(t *testing.T) { + c, err := NewClient() + require.NoError(t, err) + // the only free api is to list fine-tuning jobs + res, err := c.request(context.Background()).Get("fine_tuning/jobs") + require.NoError(t, err) + require.True(t, res.IsSuccess()) +} diff --git a/llm/zhipu/zhipu/cog.toml b/llm/zhipu/zhipu/cog.toml new file mode 100644 index 0000000..ee92ed8 --- /dev/null +++ b/llm/zhipu/zhipu/cog.toml @@ -0,0 +1,25 @@ +from_latest_tag = false +ignore_merge_commits = false +disable_changelog = false +disable_bump_commit = false +generate_mono_repository_global_tag = true +branch_whitelist = [] +skip_ci = "[skip ci]" +skip_untracked = false +pre_bump_hooks = [] +post_bump_hooks = [] +pre_package_bump_hooks = [] +post_package_bump_hooks = [] +tag_prefix = "v" + +[git_hooks] + +[commit_types] + +[changelog] +path = "CHANGELOG.md" +authors = [] + +[bump_profiles] + +[packages] diff --git a/llm/zhipu/zhipu/embedding.go b/llm/zhipu/zhipu/embedding.go new file mode 100644 index 0000000..45c672b --- /dev/null +++ b/llm/zhipu/zhipu/embedding.go @@ -0,0 +1,87 @@ +package zhipu + +import ( + "context" + + "github.com/go-resty/resty/v2" +) + +// EmbeddingData is the data for each embedding. +type EmbeddingData struct { + Embedding []float64 `json:"embedding"` + Index int `json:"index"` + Object string `json:"object"` +} + +// EmbeddingResponse is the response from the embedding service. +type EmbeddingResponse struct { + Model string `json:"model"` + Data []EmbeddingData `json:"data"` + Object string `json:"object"` + Usage ChatCompletionUsage `json:"usage"` +} + +// EmbeddingService embeds a list of text into a vector space. +type EmbeddingService struct { + client *Client + + model string + input string +} + +var ( + _ BatchSupport = &EmbeddingService{} +) + +// NewEmbeddingService creates a new EmbeddingService. +func NewEmbeddingService(client *Client) *EmbeddingService { + return &EmbeddingService{client: client} +} + +func (s *EmbeddingService) BatchMethod() string { + return "POST" +} + +func (s *EmbeddingService) BatchURL() string { + return BatchEndpointV4Embeddings +} + +func (s *EmbeddingService) BatchBody() any { + return s.buildBody() +} + +// SetModel sets the model to use for the embedding. +func (s *EmbeddingService) SetModel(model string) *EmbeddingService { + s.model = model + return s +} + +// SetInput sets the input text to embed. +func (s *EmbeddingService) SetInput(input string) *EmbeddingService { + s.input = input + return s +} + +func (s *EmbeddingService) buildBody() M { + return M{"model": s.model, "input": s.input} +} + +func (s *EmbeddingService) Do(ctx context.Context) (res EmbeddingResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + if resp, err = s.client.request(ctx). + SetBody(s.buildBody()). + SetResult(&res). + SetError(&apiError). + Post("embeddings"); err != nil { + return + } + if resp.IsError() { + err = apiError + return + } + return +} diff --git a/llm/zhipu/zhipu/embedding_test.go b/llm/zhipu/zhipu/embedding_test.go new file mode 100644 index 0000000..46f4aeb --- /dev/null +++ b/llm/zhipu/zhipu/embedding_test.go @@ -0,0 +1,21 @@ +package zhipu + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestEmbeddingService(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + service := client.Embedding("embedding-2") + + resp, err := service.SetInput("你好").Do(context.Background()) + require.NoError(t, err) + require.NotZero(t, resp.Usage.TotalTokens) + require.NotEmpty(t, resp.Data) + require.NotEmpty(t, resp.Data[0].Embedding) +} diff --git a/llm/zhipu/zhipu/error.go b/llm/zhipu/zhipu/error.go new file mode 100644 index 0000000..ea07ad7 --- /dev/null +++ b/llm/zhipu/zhipu/error.go @@ -0,0 +1,58 @@ +package zhipu + +type APIError struct { + Code string `json:"code"` + Message string `json:"message"` +} + +func (e APIError) Error() string { + return e.Message +} + +type APIErrorResponse struct { + APIError `json:"error"` +} + +func (e APIErrorResponse) Error() string { + return e.APIError.Error() +} + +// GetAPIErrorCode returns the error code of an API error. +func GetAPIErrorCode(err error) string { + if err == nil { + return "" + } + if e, ok := err.(APIError); ok { + return e.Code + } + if e, ok := err.(APIErrorResponse); ok { + return e.Code + } + if e, ok := err.(*APIError); ok && e != nil { + return e.Code + } + if e, ok := err.(*APIErrorResponse); ok && e != nil { + return e.Code + } + return "" +} + +// GetAPIErrorMessage returns the error message of an API error. +func GetAPIErrorMessage(err error) string { + if err == nil { + return "" + } + if e, ok := err.(APIError); ok { + return e.Message + } + if e, ok := err.(APIErrorResponse); ok { + return e.Message + } + if e, ok := err.(*APIError); ok && e != nil { + return e.Message + } + if e, ok := err.(*APIErrorResponse); ok && e != nil { + return e.Message + } + return err.Error() +} diff --git a/llm/zhipu/zhipu/error_test.go b/llm/zhipu/zhipu/error_test.go new file mode 100644 index 0000000..3f2fa08 --- /dev/null +++ b/llm/zhipu/zhipu/error_test.go @@ -0,0 +1,38 @@ +package zhipu + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAPIError(t *testing.T) { + err := APIError{ + Code: "code", + Message: "message", + } + require.Equal(t, "message", err.Error()) + require.Equal(t, "code", GetAPIErrorCode(err)) + require.Equal(t, "message", GetAPIErrorMessage(err)) +} + +func TestAPIErrorResponse(t *testing.T) { + err := APIErrorResponse{ + APIError: APIError{ + Code: "code", + Message: "message", + }, + } + require.Equal(t, "message", err.Error()) + require.Equal(t, "code", GetAPIErrorCode(err)) + require.Equal(t, "message", GetAPIErrorMessage(err)) +} + +func TestAPIErrorResponseFromDoc(t *testing.T) { + var res APIErrorResponse + err := json.Unmarshal([]byte(`{"error":{"code":"1002","message":"Authorization Token非法,请确认Authorization Token正确传递。"}}`), &res) + require.NoError(t, err) + require.Equal(t, "1002", res.Code) + require.Equal(t, "1002", GetAPIErrorCode(res)) +} diff --git a/llm/zhipu/zhipu/file.go b/llm/zhipu/zhipu/file.go new file mode 100644 index 0000000..e8b66af --- /dev/null +++ b/llm/zhipu/zhipu/file.go @@ -0,0 +1,541 @@ +package zhipu + +import ( + "context" + "errors" + "io" + "os" + "path/filepath" + "strconv" + + "github.com/go-resty/resty/v2" +) + +const ( + FilePurposeFineTune = "fine-tune" + FilePurposeRetrieval = "retrieval" + FilePurposeBatch = "batch" + + KnowledgeTypeArticle = 1 + KnowledgeTypeQADocument = 2 + KnowledgeTypeQASpreadsheet = 3 + KnowledgeTypeProductDatabaseSpreadsheet = 4 + KnowledgeTypeCustom = 5 +) + +// FileCreateService is a service to create a file. +type FileCreateService struct { + client *Client + + purpose string + + localFile string + file io.Reader + filename string + + customSeparator *string + sentenceSize *int + knowledgeID *string +} + +// FileCreateKnowledgeSuccessInfo is the success info of the FileCreateKnowledgeResponse. +type FileCreateKnowledgeSuccessInfo struct { + Filename string `json:"fileName"` + DocumentID string `json:"documentId"` +} + +// FileCreateKnowledgeFailedInfo is the failed info of the FileCreateKnowledgeResponse. +type FileCreateKnowledgeFailedInfo struct { + Filename string `json:"fileName"` + FailReason string `json:"failReason"` +} + +// FileCreateKnowledgeResponse is the response of the FileCreateService. +type FileCreateKnowledgeResponse struct { + SuccessInfos []FileCreateKnowledgeSuccessInfo `json:"successInfos"` + FailedInfos []FileCreateKnowledgeFailedInfo `json:"failedInfos"` +} + +// FileCreateFineTuneResponse is the response of the FileCreateService. +type FileCreateFineTuneResponse struct { + Bytes int64 `json:"bytes"` + CreatedAt int64 `json:"created_at"` + Filename string `json:"filename"` + Object string `json:"object"` + Purpose string `json:"purpose"` + ID string `json:"id"` +} + +// FileCreateResponse is the response of the FileCreateService. +type FileCreateResponse struct { + FileCreateFineTuneResponse + FileCreateKnowledgeResponse +} + +// NewFileCreateService creates a new FileCreateService. +func NewFileCreateService(client *Client) *FileCreateService { + return &FileCreateService{client: client} +} + +// SetLocalFile sets the local_file parameter of the FileCreateService. +func (s *FileCreateService) SetLocalFile(localFile string) *FileCreateService { + s.localFile = localFile + return s +} + +// SetFile sets the file parameter of the FileCreateService. +func (s *FileCreateService) SetFile(file io.Reader, filename string) *FileCreateService { + s.file = file + s.filename = filename + return s +} + +// SetPurpose sets the purpose parameter of the FileCreateService. +func (s *FileCreateService) SetPurpose(purpose string) *FileCreateService { + s.purpose = purpose + return s +} + +// SetCustomSeparator sets the custom_separator parameter of the FileCreateService. +func (s *FileCreateService) SetCustomSeparator(customSeparator string) *FileCreateService { + s.customSeparator = &customSeparator + return s +} + +// SetSentenceSize sets the sentence_size parameter of the FileCreateService. +func (s *FileCreateService) SetSentenceSize(sentenceSize int) *FileCreateService { + s.sentenceSize = &sentenceSize + return s +} + +// SetKnowledgeID sets the knowledge_id parameter of the FileCreateService. +func (s *FileCreateService) SetKnowledgeID(knowledgeID string) *FileCreateService { + s.knowledgeID = &knowledgeID + return s +} + +// Do makes the request. +func (s *FileCreateService) Do(ctx context.Context) (res FileCreateResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + body := map[string]string{"purpose": s.purpose} + + if s.customSeparator != nil { + body["custom_separator"] = *s.customSeparator + } + if s.sentenceSize != nil { + body["sentence_size"] = strconv.Itoa(*s.sentenceSize) + } + if s.knowledgeID != nil { + body["knowledge_id"] = *s.knowledgeID + } + + file, filename := s.file, s.filename + + if file == nil && s.localFile != "" { + var f *os.File + if f, err = os.Open(s.localFile); err != nil { + return + } + defer f.Close() + + file = f + filename = filepath.Base(s.localFile) + } + + if file == nil { + err = errors.New("no file specified") + return + } + + if resp, err = s.client.request(ctx). + SetFileReader("file", filename, file). + SetMultipartFormData(body). + SetResult(&res). + SetError(&apiError). + Post("files"); err != nil { + return + } + + if resp.IsError() { + err = apiError + return + } + + return +} + +// FileEditService is a service to edit a file. +type FileEditService struct { + client *Client + + documentID string + + knowledgeType *int + customSeparator []string + sentenceSize *int +} + +// NewFileEditService creates a new FileEditService. +func NewFileEditService(client *Client) *FileEditService { + return &FileEditService{client: client} +} + +// SetDocumentID sets the document_id parameter of the FileEditService. +func (s *FileEditService) SetDocumentID(documentID string) *FileEditService { + s.documentID = documentID + return s +} + +// SetKnowledgeType sets the knowledge_type parameter of the FileEditService. +func (s *FileEditService) SetKnowledgeType(knowledgeType int) *FileEditService { + s.knowledgeType = &knowledgeType + return s +} + +// SetSentenceSize sets the sentence_size parameter of the FileEditService. +func (s *FileEditService) SetCustomSeparator(customSeparator ...string) *FileEditService { + s.customSeparator = customSeparator + return s +} + +// SetSentenceSize sets the sentence_size parameter of the FileEditService. +func (s *FileEditService) SetSentenceSize(sentenceSize int) *FileEditService { + s.sentenceSize = &sentenceSize + return s +} + +// Do makes the request. +func (s *FileEditService) Do(ctx context.Context) (err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + body := M{} + + if s.knowledgeType != nil { + body["knowledge_type"] = strconv.Itoa(*s.knowledgeType) + } + if len(s.customSeparator) > 0 { + body["custom_separator"] = s.customSeparator + } + if s.sentenceSize != nil { + body["sentence_size"] = strconv.Itoa(*s.sentenceSize) + } + + if resp, err = s.client.request(ctx). + SetPathParam("document_id", s.documentID). + SetBody(body). + SetError(&apiError). + Put("document/{document_id}"); err != nil { + return + } + + if resp.IsError() { + err = apiError + return + } + + return +} + +// FileListService is a service to list files. +type FileListService struct { + client *Client + + purpose string + + knowledgeID *string + page *int + limit *int + after *string + orderAsc *bool +} + +// FileFailInfo is the failed info of the FileListKnowledgeItem. +type FileFailInfo struct { + EmbeddingCode int `json:"embedding_code"` + EmbeddingMsg string `json:"embedding_msg"` +} + +// FileListKnowledgeItem is the item of the FileListKnowledgeResponse. +type FileListKnowledgeItem struct { + ID string `json:"id"` + Name string `json:"name"` + URL string `json:"url"` + Length int64 `json:"length"` + SentenceSize int64 `json:"sentence_size"` + CustomSeparator []string `json:"custom_separator"` + EmbeddingStat int `json:"embedding_stat"` + FailInfo *FileFailInfo `json:"failInfo"` + WordNum int64 `json:"word_num"` + ParseImage int `json:"parse_image"` +} + +// FileListKnowledgeResponse is the response of the FileListService. +type FileListKnowledgeResponse struct { + Total int `json:"total"` + List []FileListKnowledgeItem `json:"list"` +} + +// FileListFineTuneItem is the item of the FileListFineTuneResponse. +type FileListFineTuneItem struct { + Bytes int64 `json:"bytes"` + CreatedAt int64 `json:"created_at"` + Filename string `json:"filename"` + ID string `json:"id"` + Object string `json:"object"` + Purpose string `json:"purpose"` +} + +// FileListFineTuneResponse is the response of the FileListService. +type FileListFineTuneResponse struct { + Object string `json:"object"` + Data []FileListFineTuneItem `json:"data"` +} + +// FileListResponse is the response of the FileListService. +type FileListResponse struct { + FileListKnowledgeResponse + FileListFineTuneResponse +} + +// NewFileListService creates a new FileListService. +func NewFileListService(client *Client) *FileListService { + return &FileListService{client: client} +} + +// SetPurpose sets the purpose parameter of the FileListService. +func (s *FileListService) SetPurpose(purpose string) *FileListService { + s.purpose = purpose + return s +} + +// SetKnowledgeID sets the knowledge_id parameter of the FileListService. +func (s *FileListService) SetKnowledgeID(knowledgeID string) *FileListService { + s.knowledgeID = &knowledgeID + return s +} + +// SetPage sets the page parameter of the FileListService. +func (s *FileListService) SetPage(page int) *FileListService { + s.page = &page + return s +} + +// SetLimit sets the limit parameter of the FileListService. +func (s *FileListService) SetLimit(limit int) *FileListService { + s.limit = &limit + return s +} + +// SetAfter sets the after parameter of the FileListService. +func (s *FileListService) SetAfter(after string) *FileListService { + s.after = &after + return s +} + +// SetOrder sets the order parameter of the FileListService. +func (s *FileListService) SetOrder(asc bool) *FileListService { + s.orderAsc = &asc + return s +} + +// Do makes the request. +func (s *FileListService) Do(ctx context.Context) (res FileListResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + m := map[string]string{ + "purpose": s.purpose, + } + + if s.knowledgeID != nil { + m["knowledge_id"] = *s.knowledgeID + } + if s.page != nil { + m["page"] = strconv.Itoa(*s.page) + } + if s.limit != nil { + m["limit"] = strconv.Itoa(*s.limit) + } + if s.after != nil { + m["after"] = *s.after + } + if s.orderAsc != nil { + if *s.orderAsc { + m["order"] = "asc" + } else { + m["order"] = "desc" + } + } + + if resp, err = s.client.request(ctx). + SetQueryParams(m). + SetResult(&res). + SetError(&apiError). + Get("files"); err != nil { + return + } + + if resp.IsError() { + err = apiError + return + } + + return +} + +// FileDeleteService is a service to delete a file. +type FileDeleteService struct { + client *Client + documentID string +} + +// NewFileDeleteService creates a new FileDeleteService. +func NewFileDeleteService(client *Client) *FileDeleteService { + return &FileDeleteService{client: client} +} + +// SetDocumentID sets the document_id parameter of the FileDeleteService. +func (s *FileDeleteService) SetDocumentID(documentID string) *FileDeleteService { + s.documentID = documentID + return s +} + +// Do makes the request. +func (s *FileDeleteService) Do(ctx context.Context) (err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + if resp, err = s.client.request(ctx). + SetPathParam("document_id", s.documentID). + SetError(&apiError). + Delete("document/{document_id}"); err != nil { + return + } + + if resp.IsError() { + err = apiError + return + } + + return +} + +// FileGetService is a service to get a file. +type FileGetService struct { + client *Client + documentID string +} + +// FileGetResponse is the response of the FileGetService. +type FileGetResponse = FileListKnowledgeItem + +// NewFileGetService creates a new FileGetService. +func NewFileGetService(client *Client) *FileGetService { + return &FileGetService{client: client} +} + +// SetDocumentID sets the document_id parameter of the FileGetService. +func (s *FileGetService) SetDocumentID(documentID string) *FileGetService { + s.documentID = documentID + return s +} + +// Do makes the request. +func (s *FileGetService) Do(ctx context.Context) (res FileGetResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + if resp, err = s.client.request(ctx). + SetPathParam("document_id", s.documentID). + SetResult(&res). + SetError(&apiError). + Get("document/{document_id}"); err != nil { + return + } + + if resp.IsError() { + err = apiError + return + } + + return +} + +// FileDownloadService is a service to download a file. +type FileDownloadService struct { + client *Client + + fileID string + + writer io.Writer + filename string +} + +// NewFileDownloadService creates a new FileDownloadService. +func NewFileDownloadService(client *Client) *FileDownloadService { + return &FileDownloadService{client: client} +} + +// SetFileID sets the file_id parameter of the FileDownloadService. +func (s *FileDownloadService) SetFileID(fileID string) *FileDownloadService { + s.fileID = fileID + return s +} + +// SetOutput sets the output parameter of the FileDownloadService. +func (s *FileDownloadService) SetOutput(w io.Writer) *FileDownloadService { + s.writer = w + return s +} + +// SetOutputFile sets the output_file parameter of the FileDownloadService. +func (s *FileDownloadService) SetOutputFile(filename string) *FileDownloadService { + s.filename = filename + return s +} + +// Do makes the request. +func (s *FileDownloadService) Do(ctx context.Context) (err error) { + var resp *resty.Response + + writer := s.writer + + if writer == nil && s.filename != "" { + var f *os.File + if f, err = os.Create(s.filename); err != nil { + return + } + defer f.Close() + + writer = f + } + + if writer == nil { + return errors.New("no output specified") + } + + if resp, err = s.client.request(ctx). + SetDoNotParseResponse(true). + SetPathParam("file_id", s.fileID). + Get("files/{file_id}/content"); err != nil { + return + } + defer resp.RawBody().Close() + + _, err = io.Copy(writer, resp.RawBody()) + + return +} diff --git a/llm/zhipu/zhipu/file_test.go b/llm/zhipu/zhipu/file_test.go new file mode 100644 index 0000000..3d035ae --- /dev/null +++ b/llm/zhipu/zhipu/file_test.go @@ -0,0 +1,71 @@ +package zhipu + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFileServiceFineTune(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.FileCreate(FilePurposeFineTune) + s.SetLocalFile(filepath.Join("testdata", "test-file.jsonl")) + + res, err := s.Do(context.Background()) + require.NoError(t, err) + require.NotZero(t, res.Bytes) + require.NotZero(t, res.CreatedAt) + require.NotEmpty(t, res.ID) +} + +func TestFileServiceKnowledge(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.FileCreate(FilePurposeRetrieval) + s.SetKnowledgeID(os.Getenv("TEST_KNOWLEDGE_ID")) + s.SetLocalFile(filepath.Join("testdata", "test-file.txt")) + + res, err := s.Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res.SuccessInfos) + require.NotEmpty(t, res.SuccessInfos[0].DocumentID) + require.NotEmpty(t, res.SuccessInfos[0].Filename) + + documentID := res.SuccessInfos[0].DocumentID + + res2, err := client.FileGet(documentID).Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res2.ID) + + err = client.FileEdit(documentID).SetKnowledgeType(KnowledgeTypeCustom).Do(context.Background()) + require.True(t, err == nil || GetAPIErrorCode(err) == "10019") + + err = client.FileDelete(res.SuccessInfos[0].DocumentID).Do(context.Background()) + require.True(t, err == nil || GetAPIErrorCode(err) == "10019") +} + +func TestFileListServiceKnowledge(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.FileList(FilePurposeRetrieval).SetKnowledgeID(os.Getenv("TEST_KNOWLEDGE_ID")) + res, err := s.Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res.List) +} + +func TestFileListServiceFineTune(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.FileList(FilePurposeFineTune) + res, err := s.Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res.Data) +} diff --git a/llm/zhipu/zhipu/fine_tune.go b/llm/zhipu/zhipu/fine_tune.go new file mode 100644 index 0000000..f2121ea --- /dev/null +++ b/llm/zhipu/zhipu/fine_tune.go @@ -0,0 +1,456 @@ +package zhipu + +import ( + "context" + "strconv" + + "github.com/go-resty/resty/v2" +) + +const ( + HyperParameterAuto = "auto" + + FineTuneStatusCreate = "create" + FineTuneStatusValidatingFiles = "validating_files" + FineTuneStatusQueued = "queued" + FineTuneStatusRunning = "running" + FineTuneStatusSucceeded = "succeeded" + FineTuneStatusFailed = "failed" + FineTuneStatusCancelled = "cancelled" +) + +// FineTuneItem is the item of the FineTune +type FineTuneItem struct { + ID string `json:"id"` + RequestID string `json:"request_id"` + FineTunedModel string `json:"fine_tuned_model"` + Status string `json:"status"` + Object string `json:"object"` + TrainingFile string `json:"training_file"` + ValidationFile string `json:"validation_file"` + Error APIError `json:"error"` +} + +// FineTuneCreateService creates a new fine tune +type FineTuneCreateService struct { + client *Client + + model string + trainingFile string + validationFile *string + + learningRateMultiplier *StringOr[float64] + batchSize *StringOr[int] + nEpochs *StringOr[int] + + suffix *string + requestID *string +} + +// FineTuneCreateResponse is the response of the FineTuneCreateService +type FineTuneCreateResponse = FineTuneItem + +// NewFineTuneCreateService creates a new FineTuneCreateService +func NewFineTuneCreateService(client *Client) *FineTuneCreateService { + return &FineTuneCreateService{ + client: client, + } +} + +// SetModel sets the model parameter +func (s *FineTuneCreateService) SetModel(model string) *FineTuneCreateService { + s.model = model + return s +} + +// SetTrainingFile sets the trainingFile parameter +func (s *FineTuneCreateService) SetTrainingFile(trainingFile string) *FineTuneCreateService { + s.trainingFile = trainingFile + return s +} + +// SetValidationFile sets the validationFile parameter +func (s *FineTuneCreateService) SetValidationFile(validationFile string) *FineTuneCreateService { + s.validationFile = &validationFile + return s +} + +// SetLearningRateMultiplier sets the learningRateMultiplier parameter +func (s *FineTuneCreateService) SetLearningRateMultiplier(learningRateMultiplier float64) *FineTuneCreateService { + s.learningRateMultiplier = &StringOr[float64]{} + s.learningRateMultiplier.SetValue(learningRateMultiplier) + return s +} + +// SetLearningRateMultiplierAuto sets the learningRateMultiplier parameter to auto +func (s *FineTuneCreateService) SetLearningRateMultiplierAuto() *FineTuneCreateService { + s.learningRateMultiplier = &StringOr[float64]{} + s.learningRateMultiplier.SetString(HyperParameterAuto) + return s +} + +// SetBatchSize sets the batchSize parameter +func (s *FineTuneCreateService) SetBatchSize(batchSize int) *FineTuneCreateService { + s.batchSize = &StringOr[int]{} + s.batchSize.SetValue(batchSize) + return s +} + +// SetBatchSizeAuto sets the batchSize parameter to auto +func (s *FineTuneCreateService) SetBatchSizeAuto() *FineTuneCreateService { + s.batchSize = &StringOr[int]{} + s.batchSize.SetString(HyperParameterAuto) + return s +} + +// SetNEpochs sets the nEpochs parameter +func (s *FineTuneCreateService) SetNEpochs(nEpochs int) *FineTuneCreateService { + s.nEpochs = &StringOr[int]{} + s.nEpochs.SetValue(nEpochs) + return s +} + +// SetNEpochsAuto sets the nEpochs parameter to auto +func (s *FineTuneCreateService) SetNEpochsAuto() *FineTuneCreateService { + s.nEpochs = &StringOr[int]{} + s.nEpochs.SetString(HyperParameterAuto) + return s +} + +// SetSuffix sets the suffix parameter +func (s *FineTuneCreateService) SetSuffix(suffix string) *FineTuneCreateService { + s.suffix = &suffix + return s +} + +// SetRequestID sets the requestID parameter +func (s *FineTuneCreateService) SetRequestID(requestID string) *FineTuneCreateService { + s.requestID = &requestID + return s +} + +// Do makes the request +func (s *FineTuneCreateService) Do(ctx context.Context) (res FineTuneCreateResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + body := M{ + "model": s.model, + "training_file": s.trainingFile, + } + + if s.validationFile != nil { + body["validation_file"] = *s.validationFile + } + if s.suffix != nil { + body["suffix"] = *s.suffix + } + if s.requestID != nil { + body["request_id"] = *s.requestID + } + if s.learningRateMultiplier != nil || s.batchSize != nil || s.nEpochs != nil { + hp := M{} + if s.learningRateMultiplier != nil { + hp["learning_rate_multiplier"] = s.learningRateMultiplier + } + if s.batchSize != nil { + hp["batch_size"] = s.batchSize + } + if s.nEpochs != nil { + hp["n_epochs"] = s.nEpochs + } + body["hyperparameters"] = hp + } + + if resp, err = s.client.request(ctx). + SetBody(body). + SetResult(&res). + SetError(&apiError). + Post("fine_tuning/jobs"); err != nil { + return + } + if resp.IsError() { + err = apiError + return + } + return +} + +// FineTuneEventListService creates a new fine tune event list +type FineTuneEventListService struct { + client *Client + + jobID string + + limit *int + after *string +} + +// FineTuneEventData is the data of the FineTuneEventItem +type FineTuneEventData struct { + Acc float64 `json:"acc"` + Loss float64 `json:"loss"` + CurrentSteps int64 `json:"current_steps"` + RemainingTime string `json:"remaining_time"` + ElapsedTime string `json:"elapsed_time"` + TotalSteps int64 `json:"total_steps"` + Epoch int64 `json:"epoch"` + TrainedTokens int64 `json:"trained_tokens"` + LearningRate float64 `json:"learning_rate"` +} + +// FineTuneEventItem is the item of the FineTuneEventListResponse +type FineTuneEventItem struct { + ID string `json:"id"` + Type string `json:"type"` + Level string `json:"level"` + Message string `json:"message"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Data FineTuneEventData `json:"data"` +} + +// FineTuneEventListResponse is the response of the FineTuneEventListService +type FineTuneEventListResponse struct { + Data []FineTuneEventItem `json:"data"` + HasMore bool `json:"has_more"` + Object string `json:"object"` +} + +// NewFineTuneEventListService creates a new FineTuneEventListService +func NewFineTuneEventListService(client *Client) *FineTuneEventListService { + return &FineTuneEventListService{ + client: client, + } +} + +// SetJobID sets the jobID parameter +func (s *FineTuneEventListService) SetJobID(jobID string) *FineTuneEventListService { + s.jobID = jobID + return s +} + +// SetLimit sets the limit parameter +func (s *FineTuneEventListService) SetLimit(limit int) *FineTuneEventListService { + s.limit = &limit + return s +} + +// SetAfter sets the after parameter +func (s *FineTuneEventListService) SetAfter(after string) *FineTuneEventListService { + s.after = &after + return s +} + +// Do makes the request +func (s *FineTuneEventListService) Do(ctx context.Context) (res FineTuneEventListResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + req := s.client.request(ctx) + + if s.limit != nil { + req.SetQueryParam("limit", strconv.Itoa(*s.limit)) + } + if s.after != nil { + req.SetQueryParam("after", *s.after) + } + + if resp, err = req. + SetPathParam("job_id", s.jobID). + SetResult(&res). + SetError(&apiError). + Get("fine_tuning/jobs/{job_id}/events"); err != nil { + return + } + if resp.IsError() { + err = apiError + return + } + return +} + +// FineTuneGetService creates a new fine tune get +type FineTuneGetService struct { + client *Client + jobID string +} + +// NewFineTuneGetService creates a new FineTuneGetService +func NewFineTuneGetService(client *Client) *FineTuneGetService { + return &FineTuneGetService{ + client: client, + } +} + +// SetJobID sets the jobID parameter +func (s *FineTuneGetService) SetJobID(jobID string) *FineTuneGetService { + s.jobID = jobID + return s +} + +// Do makes the request +func (s *FineTuneGetService) Do(ctx context.Context) (res FineTuneItem, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + if resp, err = s.client.request(ctx). + SetPathParam("job_id", s.jobID). + SetResult(&res). + SetError(&apiError). + Get("fine_tuning/jobs/{job_id}"); err != nil { + return + } + if resp.IsError() { + err = apiError + return + } + return +} + +// FineTuneListService creates a new fine tune list +type FineTuneListService struct { + client *Client + + limit *int + after *string +} + +// FineTuneListResponse is the response of the FineTuneListService +type FineTuneListResponse struct { + Data []FineTuneItem `json:"data"` + Object string `json:"object"` +} + +// NewFineTuneListService creates a new FineTuneListService +func NewFineTuneListService(client *Client) *FineTuneListService { + return &FineTuneListService{ + client: client, + } +} + +// SetLimit sets the limit parameter +func (s *FineTuneListService) SetLimit(limit int) *FineTuneListService { + s.limit = &limit + return s +} + +// SetAfter sets the after parameter +func (s *FineTuneListService) SetAfter(after string) *FineTuneListService { + s.after = &after + return s +} + +// Do makes the request +func (s *FineTuneListService) Do(ctx context.Context) (res FineTuneListResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + req := s.client.request(ctx) + if s.limit != nil { + req.SetQueryParam("limit", strconv.Itoa(*s.limit)) + } + if s.after != nil { + req.SetQueryParam("after", *s.after) + } + + if resp, err = req. + SetResult(&res). + SetError(&apiError). + Get("fine_tuning/jobs"); err != nil { + return + } + if resp.IsError() { + err = apiError + return + } + return +} + +// FineTuneDeleteService creates a new fine tune delete +type FineTuneDeleteService struct { + client *Client + jobID string +} + +// NewFineTuneDeleteService creates a new FineTuneDeleteService +func NewFineTuneDeleteService(client *Client) *FineTuneDeleteService { + return &FineTuneDeleteService{ + client: client, + } +} + +// SetJobID sets the jobID parameter +func (s *FineTuneDeleteService) SetJobID(jobID string) *FineTuneDeleteService { + s.jobID = jobID + return s +} + +// Do makes the request +func (s *FineTuneDeleteService) Do(ctx context.Context) (res FineTuneItem, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + if resp, err = s.client.request(ctx). + SetPathParam("job_id", s.jobID). + SetResult(&res). + SetError(&apiError). + Delete("fine_tuning/jobs/{job_id}"); err != nil { + return + } + if resp.IsError() { + err = apiError + return + } + return +} + +// FineTuneCancelService creates a new fine tune cancel +type FineTuneCancelService struct { + client *Client + jobID string +} + +// NewFineTuneCancelService creates a new FineTuneCancelService +func NewFineTuneCancelService(client *Client) *FineTuneCancelService { + return &FineTuneCancelService{ + client: client, + } +} + +// SetJobID sets the jobID parameter +func (s *FineTuneCancelService) SetJobID(jobID string) *FineTuneCancelService { + s.jobID = jobID + return s +} + +// Do makes the request +func (s *FineTuneCancelService) Do(ctx context.Context) (res FineTuneItem, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + if resp, err = s.client.request(ctx). + SetPathParam("job_id", s.jobID). + SetResult(&res). + SetError(&apiError). + Post("fine_tuning/jobs/{job_id}/cancel"); err != nil { + return + } + if resp.IsError() { + err = apiError + return + } + return +} diff --git a/llm/zhipu/zhipu/fine_tune_test.go b/llm/zhipu/zhipu/fine_tune_test.go new file mode 100644 index 0000000..2adb1fa --- /dev/null +++ b/llm/zhipu/zhipu/fine_tune_test.go @@ -0,0 +1,3 @@ +package zhipu + +// tests not available since lack of budget to test it diff --git a/llm/zhipu/zhipu/image_generation.go b/llm/zhipu/zhipu/image_generation.go new file mode 100644 index 0000000..32486a9 --- /dev/null +++ b/llm/zhipu/zhipu/image_generation.go @@ -0,0 +1,110 @@ +package zhipu + +import ( + "context" + + "github.com/go-resty/resty/v2" +) + +// ImageGenerationService creates a new image generation +type ImageGenerationService struct { + client *Client + + model string + prompt string + size string + userID string +} + +var ( + _ BatchSupport = &ImageGenerationService{} +) + +// ImageGenerationResponse is the response of the ImageGenerationService +type ImageGenerationResponse struct { + Created int64 `json:"created"` + Data []URLItem `json:"data"` +} + +// NewImageGenerationService creates a new ImageGenerationService +func NewImageGenerationService(client *Client) *ImageGenerationService { + return &ImageGenerationService{ + client: client, + } +} + +func (s *ImageGenerationService) BatchMethod() string { + return "POST" +} + +func (s *ImageGenerationService) BatchURL() string { + return BatchEndpointV4ImagesGenerations +} + +func (s *ImageGenerationService) BatchBody() any { + return s.buildBody() +} + +// SetModel sets the model parameter +func (s *ImageGenerationService) SetModel(model string) *ImageGenerationService { + s.model = model + return s +} + +// SetPrompt sets the prompt parameter +func (s *ImageGenerationService) SetPrompt(prompt string) *ImageGenerationService { + s.prompt = prompt + return s +} + +func (s *ImageGenerationService) SetSize(size string) *ImageGenerationService { + s.size = size + return s +} + +// SetUserID sets the userID parameter +func (s *ImageGenerationService) SetUserID(userID string) *ImageGenerationService { + s.userID = userID + return s +} + +func (s *ImageGenerationService) buildBody() M { + body := M{ + "model": s.model, + "prompt": s.prompt, + } + + if s.userID != "" { + body["user_id"] = s.userID + } + + if s.size != "" { + body["size"] = s.size + } + + return body +} + +func (s *ImageGenerationService) Do(ctx context.Context) (res ImageGenerationResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + body := s.buildBody() + + if resp, err = s.client.request(ctx). + SetBody(body). + SetResult(&res). + SetError(&apiError). + Post("images/generations"); err != nil { + return + } + + if resp.IsError() { + err = apiError + return + } + + return +} diff --git a/llm/zhipu/zhipu/image_generation_test.go b/llm/zhipu/zhipu/image_generation_test.go new file mode 100644 index 0000000..7ab7807 --- /dev/null +++ b/llm/zhipu/zhipu/image_generation_test.go @@ -0,0 +1,21 @@ +package zhipu + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestImageGenerationService(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.ImageGeneration("cogview-3") + s.SetPrompt("一只可爱的小猫") + + res, err := s.Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res.Data) + t.Log(res.Data[0].URL) +} diff --git a/llm/zhipu/zhipu/knowledge.go b/llm/zhipu/zhipu/knowledge.go new file mode 100644 index 0000000..7f43808 --- /dev/null +++ b/llm/zhipu/zhipu/knowledge.go @@ -0,0 +1,299 @@ +package zhipu + +import ( + "context" + "strconv" + + "github.com/go-resty/resty/v2" +) + +const ( + KnowledgeEmbeddingIDEmbedding2 = 3 +) + +// KnowledgeCreateService creates a new knowledge +type KnowledgeCreateService struct { + client *Client + + embeddingID int + name string + description *string +} + +// KnowledgeCreateResponse is the response of the KnowledgeCreateService +type KnowledgeCreateResponse = IDItem + +// NewKnowledgeCreateService creates a new KnowledgeCreateService +func NewKnowledgeCreateService(client *Client) *KnowledgeCreateService { + return &KnowledgeCreateService{ + client: client, + } +} + +// SetEmbeddingID sets the embedding id of the knowledge +func (s *KnowledgeCreateService) SetEmbeddingID(embeddingID int) *KnowledgeCreateService { + s.embeddingID = embeddingID + return s +} + +// SetName sets the name of the knowledge +func (s *KnowledgeCreateService) SetName(name string) *KnowledgeCreateService { + s.name = name + return s +} + +// SetDescription sets the description of the knowledge +func (s *KnowledgeCreateService) SetDescription(description string) *KnowledgeCreateService { + s.description = &description + return s +} + +// Do creates the knowledge +func (s *KnowledgeCreateService) Do(ctx context.Context) (res KnowledgeCreateResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + body := M{ + "name": s.name, + "embedding_id": s.embeddingID, + } + if s.description != nil { + body["description"] = *s.description + } + if resp, err = s.client.request(ctx). + SetBody(body). + SetResult(&res). + SetError(&apiError). + Post("knowledge"); err != nil { + return + } + if resp.IsError() { + err = apiError + return + } + return +} + +// KnowledgeEditService edits a knowledge +type KnowledgeEditService struct { + client *Client + + knowledgeID string + + embeddingID *int + name *string + description *string +} + +// NewKnowledgeEditService creates a new KnowledgeEditService +func NewKnowledgeEditService(client *Client) *KnowledgeEditService { + return &KnowledgeEditService{ + client: client, + } +} + +// SetKnowledgeID sets the knowledge id +func (s *KnowledgeEditService) SetKnowledgeID(knowledgeID string) *KnowledgeEditService { + s.knowledgeID = knowledgeID + return s +} + +// SetName sets the name of the knowledge +func (s *KnowledgeEditService) SetName(name string) *KnowledgeEditService { + s.name = &name + return s +} + +// SetEmbeddingID sets the embedding id of the knowledge +func (s *KnowledgeEditService) SetEmbeddingID(embeddingID int) *KnowledgeEditService { + s.embeddingID = &embeddingID + return s +} + +// SetDescription sets the description of the knowledge +func (s *KnowledgeEditService) SetDescription(description string) *KnowledgeEditService { + s.description = &description + return s +} + +// Do edits the knowledge +func (s *KnowledgeEditService) Do(ctx context.Context) (err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + body := M{} + if s.name != nil { + body["name"] = *s.name + } + if s.description != nil { + body["description"] = *s.description + } + if s.embeddingID != nil { + body["embedding_id"] = *s.embeddingID + } + if resp, err = s.client.request(ctx). + SetPathParam("knowledge_id", s.knowledgeID). + SetBody(body). + SetError(&apiError). + Put("knowledge/{knowledge_id}"); err != nil { + return + } + if resp.IsError() { + err = apiError + return + } + return +} + +// KnowledgeListService lists the knowledge +type KnowledgeListService struct { + client *Client + + page *int + size *int +} + +// KnowledgeItem is an item in the knowledge list +type KnowledgeItem struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Icon string `json:"icon"` + Background string `json:"background"` + EmbeddingID int `json:"embedding_id"` + CustomIdentifier string `json:"custom_identifier"` + WordNum int64 `json:"word_num"` + Length int64 `json:"length"` + DocumentSize int64 `json:"document_size"` +} + +// KnowledgeListResponse is the response of the KnowledgeListService +type KnowledgeListResponse struct { + List []KnowledgeItem `json:"list"` + Total int `json:"total"` +} + +// NewKnowledgeListService creates a new KnowledgeListService +func NewKnowledgeListService(client *Client) *KnowledgeListService { + return &KnowledgeListService{client: client} +} + +// SetPage sets the page of the knowledge list +func (s *KnowledgeListService) SetPage(page int) *KnowledgeListService { + s.page = &page + return s +} + +// SetSize sets the size of the knowledge list +func (s *KnowledgeListService) SetSize(size int) *KnowledgeListService { + s.size = &size + return s +} + +// Do lists the knowledge +func (s *KnowledgeListService) Do(ctx context.Context) (res KnowledgeListResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + req := s.client.request(ctx) + if s.page != nil { + req.SetQueryParam("page", strconv.Itoa(*s.page)) + } + if s.size != nil { + req.SetQueryParam("size", strconv.Itoa(*s.size)) + } + if resp, err = req. + SetResult(&res). + SetError(&apiError). + Get("knowledge"); err != nil { + return + } + if resp.IsError() { + err = apiError + return + } + return +} + +// KnowledgeDeleteService deletes a knowledge +type KnowledgeDeleteService struct { + client *Client + + knowledgeID string +} + +// NewKnowledgeDeleteService creates a new KnowledgeDeleteService +func NewKnowledgeDeleteService(client *Client) *KnowledgeDeleteService { + return &KnowledgeDeleteService{ + client: client, + } +} + +// SetKnowledgeID sets the knowledge id +func (s *KnowledgeDeleteService) SetKnowledgeID(knowledgeID string) *KnowledgeDeleteService { + s.knowledgeID = knowledgeID + return s +} + +// Do deletes the knowledge +func (s *KnowledgeDeleteService) Do(ctx context.Context) (err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + if resp, err = s.client.request(ctx). + SetPathParam("knowledge_id", s.knowledgeID). + SetError(&apiError). + Delete("knowledge/{knowledge_id}"); err != nil { + return + } + if resp.IsError() { + err = apiError + return + } + return +} + +// KnowledgeCapacityService query the capacity of the knowledge +type KnowledgeCapacityService struct { + client *Client +} + +// KnowledgeCapacityItem is an item in the knowledge capacity +type KnowledgeCapacityItem struct { + WordNum int64 `json:"word_num"` + Length int64 `json:"length"` +} + +// KnowledgeCapacityResponse is the response of the KnowledgeCapacityService +type KnowledgeCapacityResponse struct { + Used KnowledgeCapacityItem `json:"used"` + Total KnowledgeCapacityItem `json:"total"` +} + +// SetKnowledgeID sets the knowledge id +func NewKnowledgeCapacityService(client *Client) *KnowledgeCapacityService { + return &KnowledgeCapacityService{client: client} +} + +// Do query the capacity of the knowledge +func (s *KnowledgeCapacityService) Do(ctx context.Context) (res KnowledgeCapacityResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + if resp, err = s.client.request(ctx). + SetResult(&res). + SetError(&apiError). + Get("knowledge/capacity"); err != nil { + return + } + if resp.IsError() { + err = apiError + return + } + return +} diff --git a/llm/zhipu/zhipu/knowledge_test.go b/llm/zhipu/zhipu/knowledge_test.go new file mode 100644 index 0000000..a330e74 --- /dev/null +++ b/llm/zhipu/zhipu/knowledge_test.go @@ -0,0 +1,50 @@ +package zhipu + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestKnowledgeCapacity(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.KnowledgeCapacity() + res, err := s.Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res.Total.Length) + require.NotEmpty(t, res.Total.WordNum) +} + +func TestKnowledgeServiceAll(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.KnowledgeCreate() + s.SetName("test") + s.SetDescription("test description") + s.SetEmbeddingID(KnowledgeEmbeddingIDEmbedding2) + + res, err := s.Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res.ID) + + s2 := client.KnowledgeList() + res2, err := s2.Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res2.List) + require.Equal(t, res.ID, res2.List[0].ID) + + s3 := client.KnowledgeEdit(res.ID) + s3.SetDescription("test description 2") + s3.SetName("test 2") + s3.SetEmbeddingID(KnowledgeEmbeddingIDEmbedding2) + err = s3.Do(context.Background()) + require.NoError(t, err) + + s4 := client.KnowledgeDelete(res.ID) + err = s4.Do(context.Background()) + require.NoError(t, err) +} diff --git a/llm/zhipu/zhipu/string_or.go b/llm/zhipu/zhipu/string_or.go new file mode 100644 index 0000000..f0a72e6 --- /dev/null +++ b/llm/zhipu/zhipu/string_or.go @@ -0,0 +1,54 @@ +package zhipu + +import ( + "bytes" + "encoding/json" +) + +// StringOr is a struct that can be either a string or a value of type T. +type StringOr[T any] struct { + String *string + Value *T +} + +var ( + _ json.Marshaler = StringOr[float64]{} + _ json.Unmarshaler = &StringOr[float64]{} +) + +// SetString sets the string value of the struct. +func (f *StringOr[T]) SetString(v string) { + f.String = &v + f.Value = nil +} + +// SetValue sets the value of the struct. +func (f *StringOr[T]) SetValue(v T) { + f.String = nil + f.Value = &v +} + +func (f StringOr[T]) MarshalJSON() ([]byte, error) { + if f.Value != nil { + return json.Marshal(f.Value) + } + return json.Marshal(f.String) +} + +func (f *StringOr[T]) UnmarshalJSON(data []byte) error { + if len(data) == 0 { + return nil + } + if bytes.Equal(data, []byte("null")) { + return nil + } + if data[0] == '"' { + f.String = new(string) + f.Value = nil + return json.Unmarshal(data, f.String) + } else { + f.Value = new(T) + f.String = nil + return json.Unmarshal(data, f.Value) + } +} diff --git a/llm/zhipu/zhipu/string_or_test.go b/llm/zhipu/zhipu/string_or_test.go new file mode 100644 index 0000000..cdc27da --- /dev/null +++ b/llm/zhipu/zhipu/string_or_test.go @@ -0,0 +1,37 @@ +package zhipu + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestStringOr(t *testing.T) { + data := struct { + Item *StringOr[int] `json:"item,omitempty"` + }{} + data.Item = &StringOr[int]{} + data.Item.SetString("test") + + b, err := json.Marshal(data) + require.NoError(t, err) + require.Equal(t, `{"item":"test"}`, string(b)) + + data.Item.SetValue(1) + b, err = json.Marshal(data) + require.NoError(t, err) + require.Equal(t, `{"item":1}`, string(b)) + + err = json.Unmarshal([]byte(`{"item":"test2"}`), &data) + require.NoError(t, err) + require.NotNil(t, data.Item.String) + require.Nil(t, data.Item.Value) + require.Equal(t, "test2", *data.Item.String) + + err = json.Unmarshal([]byte(`{"item":2}`), &data) + require.NoError(t, err) + require.Nil(t, data.Item.String) + require.NotNil(t, data.Item.Value) + require.Equal(t, 2, *data.Item.Value) +} diff --git a/llm/zhipu/zhipu/testdata/test-file.jsonl b/llm/zhipu/zhipu/testdata/test-file.jsonl new file mode 100644 index 0000000..4737cfa --- /dev/null +++ b/llm/zhipu/zhipu/testdata/test-file.jsonl @@ -0,0 +1,3 @@ +{"messages": [{"role": "system", "content": "你是一位乐于助人,知识渊博的全能AI助手。"},{"role": "user", "content": "请介绍一下AlphaGo的基本原理。"}, {"role": "assistant", "content": "AlphaGo利用深度神经网络与树搜索算法相结合,通过对弈数据进行训练,逐步提升游戏水平,能够对围棋局面进行判断并选择最佳落子位置。"}]} +{"messages": [{"role": "system", "content": "你是一位乐于助人,知识渊博的全能AI助手。"},{"role": "user", "content": "请列举一些健康饮食的重要原则。"}, {"role": "assistant", "content": "健康饮食的重要原则包括:膳食应营养均衡,合理摄入各种营养素;多吃新鲜蔬果,适量饮水;少吃糖、油、盐,忌吃过量快餐;养成规律生活作息习惯,适量运动。"}]} +{"messages": [{"role": "system", "content": "你是一位乐于助人,知识渊博的全能AI助手。"},{"role": "user", "content": "请给出一些改善睡眠质量的建议。"}, {"role": "assistant", "content": "改善睡眠质量的建议包括:保证睡眠时间,避免睡眠不足;睡前放松身心,远离刺激;养成规律作息时间,不要频繁熬夜;适量运动,但不要运动过度;睡前可以喝一杯热牛奶等温和饮料。"}]} \ No newline at end of file diff --git a/llm/zhipu/zhipu/testdata/test-file.txt b/llm/zhipu/zhipu/testdata/test-file.txt new file mode 100644 index 0000000..ff3bb63 --- /dev/null +++ b/llm/zhipu/zhipu/testdata/test-file.txt @@ -0,0 +1 @@ +The quick brown fox jumps over the lazy dog \ No newline at end of file diff --git a/llm/zhipu/zhipu/util.go b/llm/zhipu/zhipu/util.go new file mode 100644 index 0000000..4912b68 --- /dev/null +++ b/llm/zhipu/zhipu/util.go @@ -0,0 +1,22 @@ +package zhipu + +// URLItem is a struct that contains a URL. +type URLItem struct { + URL string `json:"url,omitempty"` +} + +// IDItem is a struct that contains an ID. +type IDItem struct { + ID string `json:"id,omitempty"` +} + +// Ptr returns a pointer to the value passed in. +// Example: +// +// web_search_enable = zhipu.Ptr(false) +func Ptr[T any](v T) *T { + return &v +} + +// M is a shorthand for map[string]any. +type M = map[string]any diff --git a/llm/zhipu/zhipu/util_test.go b/llm/zhipu/zhipu/util_test.go new file mode 100644 index 0000000..52b5052 --- /dev/null +++ b/llm/zhipu/zhipu/util_test.go @@ -0,0 +1,3 @@ +package zhipu + +// nothing to test diff --git a/llm/zhipu/zhipu/video_generation.go b/llm/zhipu/zhipu/video_generation.go new file mode 100644 index 0000000..3ae3279 --- /dev/null +++ b/llm/zhipu/zhipu/video_generation.go @@ -0,0 +1,125 @@ +package zhipu + +import ( + "context" + + "github.com/go-resty/resty/v2" +) + +const ( + VideoGenerationTaskStatusProcessing = "PROCESSING" + VideoGenerationTaskStatusSuccess = "SUCCESS" + VideoGenerationTaskStatusFail = "FAIL" +) + +// VideoGenerationService creates a new video generation +type VideoGenerationService struct { + client *Client + + model string + prompt string + userID string + imageURL string + requestID string +} + +var ( + _ BatchSupport = &VideoGenerationService{} +) + +// VideoGenerationResponse is the response of the VideoGenerationService +type VideoGenerationResponse struct { + RequestID string `json:"request_id"` + ID string `json:"id"` + Model string `json:"model"` + TaskStatus string `json:"task_status"` +} + +func NewVideoGenerationService(client *Client) *VideoGenerationService { + return &VideoGenerationService{ + client: client, + } +} + +func (s *VideoGenerationService) BatchMethod() string { + return "POST" +} + +func (s *VideoGenerationService) BatchURL() string { + return BatchEndpointV4VideosGenerations +} + +func (s *VideoGenerationService) BatchBody() any { + return s.buildBody() +} + +// SetModel sets the model parameter +func (s *VideoGenerationService) SetModel(model string) *VideoGenerationService { + s.model = model + return s +} + +// SetPrompt sets the prompt parameter +func (s *VideoGenerationService) SetPrompt(prompt string) *VideoGenerationService { + s.prompt = prompt + return s +} + +// SetUserID sets the userID parameter +func (s *VideoGenerationService) SetUserID(userID string) *VideoGenerationService { + s.userID = userID + return s +} + +// SetImageURL sets the imageURL parameter +func (s *VideoGenerationService) SetImageURL(imageURL string) *VideoGenerationService { + s.imageURL = imageURL + return s +} + +// SetRequestID sets the requestID parameter +func (s *VideoGenerationService) SetRequestID(requestID string) *VideoGenerationService { + s.requestID = requestID + return s +} + +func (s *VideoGenerationService) buildBody() M { + body := M{ + "model": s.model, + "prompt": s.prompt, + } + if s.userID != "" { + body["user_id"] = s.userID + } + if s.imageURL != "" { + body["image_url"] = s.imageURL + } + if s.requestID != "" { + body["request_id"] = s.requestID + } + return body +} + +func (s *VideoGenerationService) Do(ctx context.Context) (res VideoGenerationResponse, err error) { + var ( + resp *resty.Response + apiError APIErrorResponse + ) + + body := s.buildBody() + + if resp, err = s.client.request(ctx). + SetBody(body). + SetResult(&res). + SetError(&apiError). + Post("videos/generations"); err != nil { + return + } + + if resp.IsError() { + err = apiError + return + } + + return +} diff --git a/llm/zhipu/zhipu/video_generation_test.go b/llm/zhipu/zhipu/video_generation_test.go new file mode 100644 index 0000000..8dc5ab4 --- /dev/null +++ b/llm/zhipu/zhipu/video_generation_test.go @@ -0,0 +1,38 @@ +package zhipu + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestVideoGeneration(t *testing.T) { + client, err := NewClient() + require.NoError(t, err) + + s := client.VideoGeneration("cogvideox") + s.SetPrompt("一只可爱的小猫") + + res, err := s.Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res.TaskStatus) + require.NotEmpty(t, res.ID) + t.Log(res.ID) + + for { + res, err := client.AsyncResult(res.ID).Do(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, res.TaskStatus) + if res.TaskStatus == VideoGenerationTaskStatusSuccess { + require.NotEmpty(t, res.VideoResult) + t.Log(res.VideoResult[0].URL) + t.Log(res.VideoResult[0].CoverImageURL) + } + if res.TaskStatus != VideoGenerationTaskStatusProcessing { + break + } + time.Sleep(time.Second * 5) + } +} diff --git a/llm/zhipu/zhipu/wechat-donation.png b/llm/zhipu/zhipu/wechat-donation.png new file mode 100644 index 0000000..f6519f6 Binary files /dev/null and b/llm/zhipu/zhipu/wechat-donation.png differ diff --git a/tests/ask.js b/tests/ask.js new file mode 100644 index 0000000..e73fe97 --- /dev/null +++ b/tests/ask.js @@ -0,0 +1,9 @@ +import {glm, gpt} from './lib/ai' +import console from './lib/console' +import chat from './mod/chat' + +function main(...args) { + let ask = args[0] + if (!ask) ask = console.input('请输入问题:') + return chat.fast(ask) +} diff --git a/tests/chat_test.go b/tests/chat_test.go index f295896..44c171b 100644 --- a/tests/chat_test.go +++ b/tests/chat_test.go @@ -2,51 +2,48 @@ package tests import ( "apigo.cc/ai/ai" - _ "apigo.cc/ai/openai" - _ "apigo.cc/ai/zhipu" + "apigo.cc/ai/ai/js" + "apigo.cc/ai/ai/llm" "fmt" - "github.com/ssgo/config" "github.com/ssgo/u" "testing" ) -var agentConf = map[string]ai.AgentConfig{} - -func init() { - _ = config.LoadConfig("agent", &agentConf) -} - func TestAgent(t *testing.T) { - //testChat(t, "openai", "https://api.keya.pw/v1", keys.Openai) + ai.Init() + //testChat(t, "gpt") // 尚未支持 - //testCode(t, "openai", "https://api.keya.pw/v1", keys.Openai) + //testCode(t, "gpt") // keya 消耗过大,计费不合理 - //testAskWithImage(t, "openai", "https://api.keya.pw/v1", keys.Openai, "4032.jpg") + //testAskWithImage(t, "gpt", "4032.jpg") // keya 不支持 - //testMakeImage(t, "openai", "https://api.keya.pw/v1", keys.Openai, "冬天大雪纷飞,一个男人身穿军绿色棉大衣,戴着红色围巾和绿色帽子走在铺面大雪的小镇路上", "") + //testMakeImage(t, "gpt", "冬天大雪纷飞,一个男人身穿军绿色棉大衣,戴着红色围巾和绿色帽子走在铺面大雪的小镇路上", "") - testChat(t, "zhipu") - //testCode(t, "zhipu", "", keys.Zhipu) - //testSearch(t, "zhipu", "", keys.Zhipu) + //testChat(t, "glm") + //testCode(t, "glm", "", keys.Zhipu) + //testSearch(t, "glm", "", keys.Zhipu) // 测试图片识别 - //testAskWithImage(t, "zhipu", "", keys.Zhipu, "4032.jpg") + //testAskWithImage(t, "glm", "", keys.Zhipu, "4032.jpg") // 视频似乎尚不支持 glm-4v-plus - //testAskWithVideo(t, "zhipu", "", keys.Zhipu, "glm-4v", "1080.mp4") + //testAskWithVideo(t, "glm", "", keys.Zhipu, "glm-4v", "1080.mp4") - //testMakeImage(t, "zhipu", "", keys.Zhipu, "冬天大雪纷飞,一个男人身穿军绿色棉大衣,戴着红色围巾和绿色帽子走在铺面大雪的小镇路上", "") - //testMakeVideo(t, "zhipu", "", keys.Zhipu, "大雪纷飞,男人蹦蹦跳跳", "https://aigc-files.bigmodel.cn/api/cogview/20240904133130c4b7121019724aa3_0.png") + //testMakeImage(t, "glm", "", keys.Zhipu, "冬天大雪纷飞,一个男人身穿军绿色棉大衣,戴着红色围巾和绿色帽子走在铺面大雪的小镇路上", "") + //testMakeVideo(t, "glm", "", keys.Zhipu, "大雪纷飞,男人蹦蹦跳跳", "https://aigc-files.bigmodel.cn/api/cogview/20240904133130c4b7121019724aa3_0.png") + + testJS(t) + //testFile(t) } -func testChat(t *testing.T, agent string) { - ag := ai.GetAgent(agent) +func testChat(t *testing.T, llmName string) { + lm := llm.Get(llmName) - if ag == nil { + if lm == nil { t.Fatal("agent is nil") } - r, usage, err := ag.FastAsk(ai.Messages().User().Text("你是什么模型,请给出具体名称、版本号").Make(), func(text string) { + r, usage, err := lm.FastAsk(llm.Messages().User().Text("你是什么模型,请给出具体名称、版本号").Make(), func(text string) { fmt.Print(u.BCyan(text)) fmt.Print(" ") }) @@ -60,14 +57,14 @@ func testChat(t *testing.T, agent string) { fmt.Println("usage:", u.JsonP(usage)) } -func testCode(t *testing.T, agent string) { - ag := ai.GetAgent(agent) +func testCode(t *testing.T, llmName string) { + lm := llm.Get(llmName) - if ag == nil { + if lm == nil { t.Fatal("agent is nil") } - r, usage, err := ag.CodeInterpreterAsk(ai.Messages().User().Text("计算[5,10,20,700,99,310,978,100]的平均值和方差。").Make(), func(text string) { + r, usage, err := lm.CodeInterpreterAsk(llm.Messages().User().Text("计算[5,10,20,700,99,310,978,100]的平均值和方差。").Make(), func(text string) { fmt.Print(u.BCyan(text)) fmt.Print(" ") }) @@ -81,14 +78,14 @@ func testCode(t *testing.T, agent string) { fmt.Println("usage:", u.JsonP(usage)) } -func testSearch(t *testing.T, agent string) { - ag := ai.GetAgent(agent) +func testSearch(t *testing.T, llmName string) { + lm := llm.Get(llmName) - if ag == nil { + if lm == nil { t.Fatal("agent is nil") } - r, usage, err := ag.WebSearchAsk(ai.Messages().User().Text("今天上海的天气怎么样?").Make(), func(text string) { + r, usage, err := lm.WebSearchAsk(llm.Messages().User().Text("今天上海的天气怎么样?").Make(), func(text string) { fmt.Print(u.BCyan(text)) fmt.Print(" ") }) @@ -102,10 +99,10 @@ func testSearch(t *testing.T, agent string) { fmt.Println("usage:", u.JsonP(usage)) } -func testAskWithImage(t *testing.T, agent, imageFile string) { - ag := ai.GetAgent(agent) +func testAskWithImage(t *testing.T, llmName, imageFile string) { + lm := llm.Get(llmName) - if ag == nil { + if lm == nil { t.Fatal("agent is nil") } @@ -115,7 +112,7 @@ func testAskWithImage(t *testing.T, agent, imageFile string) { 3、正在用什么软件播放什么歌?谁演唱的?歌曲的大意是? 4、后面的浏览器中正在浏览什么内容?猜测一下我浏览这个网页是想干嘛? ` - r, usage, err := ag.MultiAsk(ai.Messages().User().Text(ask).Image("data:image/jpeg;base64,"+u.Base64(u.ReadFileBytesN(imageFile))).Make(), func(text string) { + r, usage, err := lm.MultiAsk(llm.Messages().User().Text(ask).Image("data:image/jpeg;base64,"+u.Base64(u.ReadFileBytesN(imageFile))).Make(), func(text string) { fmt.Print(u.BCyan(text)) fmt.Print(" ") }) @@ -129,10 +126,10 @@ func testAskWithImage(t *testing.T, agent, imageFile string) { fmt.Println("usage:", u.JsonP(usage)) } -func testAskWithVideo(t *testing.T, agent, videoFile string) { - ag := ai.GetAgent(agent) +func testAskWithVideo(t *testing.T, llmName, videoFile string) { + lm := llm.Get(llmName) - if ag == nil { + if lm == nil { t.Fatal("agent is nil") } @@ -141,7 +138,7 @@ func testAskWithVideo(t *testing.T, agent, videoFile string) { 4、后面的浏览器中正在浏览什么内容?猜测一下我浏览这个网页是想干嘛? ` - r, usage, err := ag.MultiAsk(ai.Messages().User().Text(ask).Video("data:video/mp4,"+u.Base64(u.ReadFileBytesN(videoFile))).Make(), func(text string) { + r, usage, err := lm.MultiAsk(llm.Messages().User().Text(ask).Video("data:video/mp4,"+u.Base64(u.ReadFileBytesN(videoFile))).Make(), func(text string) { fmt.Print(u.BCyan(text)) fmt.Print(" ") }) @@ -155,14 +152,16 @@ func testAskWithVideo(t *testing.T, agent, videoFile string) { fmt.Println("usage:", u.JsonP(usage)) } -func testMakeImage(t *testing.T, agent, prompt, refImage string) { - ag := ai.GetAgent(agent) +func testMakeImage(t *testing.T, llmName, prompt, refImage string) { + lm := llm.Get(llmName) - if ag == nil { + if lm == nil { t.Fatal("agent is nil") } - r, err := ag.FastMakeImage(prompt, "1024x1024", "") + r, err := lm.FastMakeImage(prompt, llm.GCConfig{ + Size: "1024x1024", + }) if err != nil { t.Fatal("发生错误", err.Error()) @@ -174,14 +173,17 @@ func testMakeImage(t *testing.T, agent, prompt, refImage string) { } } -func testMakeVideo(t *testing.T, agent, prompt, refImage string) { - ag := ai.GetAgent(agent) +func testMakeVideo(t *testing.T, llmName, prompt, refImage string) { + lm := llm.Get(llmName) - if ag == nil { + if lm == nil { t.Fatal("agent is nil") } - r, covers, err := ag.FastMakeVideo(prompt, "1280x720", refImage) + r, covers, err := lm.FastMakeVideo(prompt, llm.GCConfig{ + Size: "1280x720", + Ref: refImage, + }) if err != nil { t.Fatal("发生错误", err.Error()) @@ -192,3 +194,31 @@ func testMakeVideo(t *testing.T, agent, prompt, refImage string) { fmt.Println("result:", i, v, covers[i]) } } + +func testJS(t *testing.T) { + r1, err := js.RunFile("test.js", "1+2=4吗") + if err != nil { + t.Fatal("发生错误", err.Error()) + } + r := js.ChatResult{} + u.Convert(r1, &r) + + fmt.Println() + fmt.Println("result:", r.Result) + fmt.Println("usage:", u.JsonP(r.TokenUsage)) +} + +func testFile(t *testing.T) { + r1, err := js.Run(` +import fs from './lib/file' +import out from './lib/console' +let r = fs.read('test.js') +return r +`, "test.js") + if err != nil { + t.Fatal("发生错误", err.Error()) + } + + fmt.Println() + fmt.Println("result:", r1) +} diff --git a/tests/env.sample.yml b/tests/env.sample.yml index e19de77..dcb3d69 100644 --- a/tests/env.sample.yml +++ b/tests/env.sample.yml @@ -1,6 +1,7 @@ -agent: - openai: +llm: + gpt1: apiKey: ... -# endpoint: ... +# endpoint: https://.... # api base url +# llm: openai # registered llm name, if not specified, it will use the key name as llm name zhipu: apiKey: ... diff --git a/tests/go.mod b/tests/go.mod deleted file mode 100644 index ca6e843..0000000 --- a/tests/go.mod +++ /dev/null @@ -1,24 +0,0 @@ -module tests - -go 1.22 - -require ( - apigo.cc/ai/ai v0.0.0 - apigo.cc/ai/openai v0.0.1 - apigo.cc/ai/zhipu v0.0.1 - github.com/ssgo/config v1.7.7 - github.com/ssgo/u v1.7.7 -) - -require ( - apigo.cc/ai/agent v0.0.1 // indirect - github.com/go-resty/resty/v2 v2.14.0 // indirect - github.com/golang-jwt/jwt/v5 v5.2.1 // indirect - github.com/sashabaranov/go-openai v1.29.1 // indirect - github.com/ssgo/log v1.7.7 // indirect - github.com/ssgo/standard v1.7.7 // indirect - golang.org/x/net v0.29.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect -) - -replace apigo.cc/ai/ai v0.0.0 => ../ diff --git a/tests/mod/chat.js b/tests/mod/chat.js new file mode 100644 index 0000000..8cb9f99 --- /dev/null +++ b/tests/mod/chat.js @@ -0,0 +1,14 @@ +import {glm} from '../lib/ai' +import console from '../lib/console' + +function fast(...args) { + if(!args[0]) throw new Error('no ask') + let r = glm.fastAsk(args[0], r => { + console.print(r) + }) + console.println() + return r +} + +// exports.fast = fast +module.exports = {fast} diff --git a/tests/test.js b/tests/test.js new file mode 100644 index 0000000..b2a177f --- /dev/null +++ b/tests/test.js @@ -0,0 +1,6 @@ +import {glm, gpt} from './lib/ai' +import chat from './mod/chat' + +function main(...args) { + return chat.fast(args[0]) +}