diff --git a/glm/glm.go b/glm/glm.go new file mode 100644 index 0000000..608a485 --- /dev/null +++ b/glm/glm.go @@ -0,0 +1,142 @@ +package glm + +import ( + "bufio" + "github.com/api-go/plugin" + "github.com/golang-jwt/jwt" + "github.com/ssgo/httpclient" + "github.com/ssgo/u" + "io" + "strings" + "time" +) + +var apiKey string + +type Option struct { + Model string + TopP float64 + Temperature float64 +} + +func init() { + plugin.Register(plugin.Plugin{ + Id: "glm", + Name: "智谱ChatGLM", + ConfigSample: `apiKey: <**encrypted_apiKey**> # 从bigmodel.cn获得的APIKey`, + Init: func(conf map[string]interface{}) { + apiKey = u.String(conf["apiKey"]) + }, + Objects: map[string]interface{}{ + // send 方法说明(注释中需包含方法名称) + // * prompt 提示词 + // * parentMessages 历史消息(可选) + // * option 选项(可选,model[chatglm_lite/chatglm_std/chatglm_pro/chatglm_turbo],topP[0-1],temperature[0-1]) + // send return 回答内容 + "send": func(prompt string, parentMessages *[]map[string]interface{}, option *Option) (map[string]interface{}, error) { + url, message := makeRequest(prompt, parentMessages, option) + c := httpclient.GetClient(300 * time.Second) + r := c.Post(url+"invoke", message, "Authorization", generateToken(apiKey)) + return r.Map(), r.Error + }, + // sendAsync 方法说明(注释中需包含方法名称) + // sendAsync callback 回调函数,返回参数为event(add|finish)和data + // sendAsync meta 回答内容 + "sendAsync": func(callback func(event, data string), prompt string, parentMessages *[]map[string]interface{}, option *Option) (meta map[string]interface{}, err error) { + url, message := makeRequest(prompt, parentMessages, option) + message["stream"] = true + c := httpclient.GetClient(300 * time.Second) + r := c.ManualDo("POST", url+"sse-invoke", message, "Authorization", generateToken(apiKey), "Accept", "text/event-stream") + reader := bufio.NewReader(r.Response.Body) + lastEvent := "" + //lastId := "" + meta = map[string]interface{}{} + for { + buf, _, err := reader.ReadLine() + if err == io.EOF { + break + } + if err != nil { + return meta, err + } + line := string(buf) + pos := strings.IndexByte(line, ':') + if pos != -1 { + k := line[0:pos] + v := line[pos+1:] + switch k { + case "id": + //lastId = v + case "event": + lastEvent = v + case "data": + if v == "" { + v = "\n" + } + callback(lastEvent, v) + case "meta": + meta = u.UnJsonMap(v) + } + } + } + return meta, nil + }, + }, + }) +} + +func makeRequest(prompt string, parentMessages *[]map[string]interface{}, option *Option) (url string, message map[string]interface{}) { + if option == nil { + option = &Option{} + if option.Model == "" { + option.Model = "chatglm_std" + } + if option.TopP == 0 { + option.TopP = 0.7 + } + if option.Temperature == 0 { + option.Temperature = 0.9 + } + } + + if parentMessages == nil { + parentMessages = &[]map[string]interface{}{} + } + messages := append(*parentMessages, map[string]interface{}{ + "role": "user", + "content": prompt, + }) + + return "https://open.bigmodel.cn/api/paas/v3/model-api/" + option.Model + "/", map[string]interface{}{ + "model": option.Model, + "topP": option.TopP, + "temperature": option.Temperature, + "prompt": messages, + } +} + +func generateToken(apikey string) string { + parts := strings.Split(apikey, ".") + if len(parts) != 2 { + return "" + } + id := parts[0] + secret := parts[1] + + payload := jwt.MapClaims{ + "api_key": id, + "exp": time.Now().Add(time.Second * 180).Unix(), + "timestamp": time.Now().Unix(), + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload) + token.Header["alg"] = "HS256" + token.Header["sign_type"] = "SIGN" + + signedToken, err := token.SignedString([]byte(secret)) + if err != nil { + return "" + } + + return signedToken +}