plugins/glm/glm.go
2024-03-09 14:19:19 +08:00

143 lines
3.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package glm
import (
"bufio"
"github.com/api-go/plugin"
"github.com/golang-jwt/jwt"
"github.com/ssgo/httpclient"
"github.com/ssgo/u"
"io"
"strings"
"time"
)
var (
	// apiKey is the ZhipuAI API key in "<id>.<secret>" form. The plugin's Init
	// hook fills it from the "apiKey" config entry.
	apiKey string
)

// Option carries the optional per-request model parameters for send/sendAsync.
type Option struct {
	// Model is the model name, e.g. chatglm_lite, chatglm_std, chatglm_pro, chatglm_turbo.
	Model string
	// TopP is the nucleus-sampling parameter, in [0, 1].
	TopP float64
	// Temperature is the sampling temperature, in [0, 1].
	Temperature float64
}
// glmError is a plain string error type. It lets this plugin define constant
// errors and build error messages without importing errors/fmt.
type glmError string

// Error implements the error interface.
func (e glmError) Error() string { return string(e) }

// errInvalidAPIKey is returned when no auth token can be built from apiKey.
const errInvalidAPIKey = glmError("glm: cannot build auth token from apiKey (expected <id>.<secret>)")

// init registers the "glm" plugin, exposing the send and sendAsync methods.
func init() {
	plugin.Register(plugin.Plugin{
		Id:           "glm",
		Name:         "智谱ChatGLM",
		ConfigSample: `apiKey: <**encrypted_apiKey**> # 从bigmodel.cn获得的APIKey`,
		// Init stores the configured API key for use by later requests.
		Init: func(conf map[string]interface{}) {
			apiKey = u.String(conf["apiKey"])
		},
		Objects: map[string]interface{}{
			// send: synchronous chat completion (the comment must contain the method name)
			// * prompt the prompt text
			// * parentMessages conversation history (optional)
			// * option options (optional): model[chatglm_lite/chatglm_std/chatglm_pro/chatglm_turbo] topP[0-1] temperature[0-1]
			// send return the decoded JSON response body
			"send": func(prompt string, parentMessages *[]map[string]interface{}, option *Option) (map[string]interface{}, error) {
				// generateToken reports a malformed key (or signing failure) as "".
				// Fail fast rather than sending an unauthenticated request.
				token := generateToken(apiKey)
				if token == "" {
					return nil, errInvalidAPIKey
				}
				url, message := makeRequest(prompt, parentMessages, option)
				c := httpclient.GetClient(300 * time.Second)
				r := c.Post(url+"invoke", message, "Authorization", token)
				// Decode first, then read r.Error. In `return r.Map(), r.Error` the
				// order of the field read relative to the call is unspecified.
				result := r.Map()
				return result, r.Error
			},
			// sendAsync: streaming chat completion over server-sent events (the comment must contain the method name)
			// sendAsync callback receives each event name (add|finish) and its data
			// * prompt, parentMessages, option: same as send
			// sendAsync return meta: the JSON object from the stream's "meta:" line (empty map if none)
			"sendAsync": func(callback func(event, data string), prompt string, parentMessages *[]map[string]interface{}, option *Option) (meta map[string]interface{}, err error) {
				meta = map[string]interface{}{}
				token := generateToken(apiKey)
				if token == "" {
					return meta, errInvalidAPIKey
				}
				url, message := makeRequest(prompt, parentMessages, option)
				message["stream"] = true
				c := httpclient.GetClient(300 * time.Second)
				r := c.ManualDo("POST", url+"sse-invoke", message, "Authorization", token, "Accept", "text/event-stream")
				// The body is consumed directly below, so it must be closed on every path
				// (including the error paths) to release the connection.
				if r.Response != nil && r.Response.Body != nil {
					defer r.Response.Body.Close()
				}
				if r.Error != nil {
					return meta, r.Error
				}
				if r.Response == nil {
					return meta, glmError("glm: sse-invoke returned no response")
				}
				if code := r.Response.StatusCode; code < 200 || code > 299 {
					// Include up to 4 KiB of the body to aid diagnosis. Best effort:
					// io.ReadFull reports ErrUnexpectedEOF for shorter bodies, which is fine here.
					detail := make([]byte, 4096)
					n, _ := io.ReadFull(r.Response.Body, detail)
					return meta, glmError("glm: sse-invoke returned HTTP " + u.String(code) + ": " + strings.TrimSpace(string(detail[:n])))
				}
				return readSSEStream(r.Response.Body, callback)
			},
		},
	})
}

// readSSEStream reads a text/event-stream body line by line until EOF.
//
// Each line is split at its first ':' into a field name and a value:
//   - "event" sets the event name passed along with subsequent data lines;
//   - "data" invokes callback(event, value); an empty value is delivered as "\n";
//   - "meta" is decoded with u.UnJsonMap and becomes the returned map;
//   - any other field (e.g. "id") and lines without ':' are ignored.
//
// Lines of any length are handled, unlike bufio.Reader.ReadLine, which splits
// long lines into prefix fragments. The returned map starts out empty. On a
// read error, the metadata collected so far is returned together with the error.
func readSSEStream(body io.Reader, callback func(event, data string)) (map[string]interface{}, error) {
	reader := bufio.NewReader(body)
	meta := map[string]interface{}{}
	lastEvent := ""
	for {
		line, readErr := reader.ReadString('\n')
		// Strip "\n" or "\r\n", mirroring what ReadLine used to remove.
		if strings.HasSuffix(line, "\n") {
			line = strings.TrimSuffix(line[:len(line)-1], "\r")
		}
		if pos := strings.IndexByte(line, ':'); pos >= 0 {
			value := line[pos+1:]
			switch line[:pos] {
			case "event":
				lastEvent = value
			case "data":
				if value == "" {
					value = "\n"
				}
				if callback != nil {
					callback(lastEvent, value)
				}
			case "meta":
				meta = u.UnJsonMap(value)
			}
		}
		// A final unterminated line arrives together with io.EOF, so process it
		// (above) before stopping.
		if readErr == io.EOF {
			return meta, nil
		}
		if readErr != nil {
			return meta, readErr
		}
	}
}
// makeRequest builds the endpoint base URL and the JSON request body for a chat call.
//
// prompt is appended as a new {"role": "user"} message after parentMessages,
// which may be nil. option may be nil. Any zero-valued field falls back to a
// default: model "chatglm_std", topP 0.7, temperature 0.9.
//
// Neither parentMessages nor option is modified. The returned url ends with "/",
// so callers append the action ("invoke", "sse-invoke").
func makeRequest(prompt string, parentMessages *[]map[string]interface{}, option *Option) (url string, message map[string]interface{}) {
	// Work on a copy so defaults are applied even to a partially filled Option
	// and are never written back into the caller's struct.
	opt := Option{}
	if option != nil {
		opt = *option
	}
	if opt.Model == "" {
		opt.Model = "chatglm_std"
	}
	if opt.TopP == 0 {
		opt.TopP = 0.7
	}
	if opt.Temperature == 0 {
		opt.Temperature = 0.9
	}

	var history []map[string]interface{}
	if parentMessages != nil {
		history = *parentMessages
	}
	// Build a fresh slice. Appending to history directly could write into spare
	// capacity of the caller's backing array.
	messages := make([]map[string]interface{}, 0, len(history)+1)
	messages = append(messages, history...)
	messages = append(messages, map[string]interface{}{
		"role":    "user",
		"content": prompt,
	})

	url = "https://open.bigmodel.cn/api/paas/v3/model-api/" + opt.Model + "/"
	message = map[string]interface{}{
		"model":       opt.Model,
		"topP":        opt.TopP,
		"temperature": opt.Temperature,
		"prompt":      messages,
	}
	return url, message
}
// generateToken builds the short-lived HS256 JWT sent in the Authorization header.
//
// apikey must have the form "<id>.<secret>". id becomes the "api_key" claim and
// secret is the signing key. ZhipuAI's token scheme expresses "exp" and
// "timestamp" in milliseconds (not the JWT-standard seconds). The token is
// valid for 180 seconds.
//
// It returns "" if apikey is malformed (not exactly two non-empty parts) or
// signing fails.
func generateToken(apikey string) string {
	parts := strings.Split(apikey, ".")
	if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
		return ""
	}
	id, secret := parts[0], parts[1]

	// A single clock reading keeps exp and timestamp consistent.
	nowMs := time.Now().UnixNano() / int64(time.Millisecond)
	payload := jwt.MapClaims{
		"api_key":   id,
		"exp":       nowMs + int64(180*time.Second/time.Millisecond),
		"timestamp": nowMs,
	}
	// NewWithClaims already sets the "alg" header to HS256. The provider
	// additionally expects sign_type=SIGN.
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
	token.Header["sign_type"] = "SIGN"
	signed, err := token.SignedString([]byte(secret))
	if err != nil {
		return ""
	}
	return signed
}