143 lines
3.8 KiB
Go
143 lines
3.8 KiB
Go
package glm
|
||
|
||
import (
	"bufio"
	"fmt"
	"io"
	"strings"
	"time"

	"github.com/api-go/plugin"
	"github.com/golang-jwt/jwt"
	"github.com/ssgo/httpclient"
	"github.com/ssgo/u"
)
|
||
|
||
// apiKey is the "<id>.<secret>" credential read from the plugin's "apiKey"
// config entry in Init; every request is authorized with a token that
// generateToken derives from it.
var apiKey string

// Option carries the optional model settings for send and sendAsync.
// Passing a nil *Option selects the default settings.
type Option struct {
	Model             string  // model name: chatglm_lite, chatglm_std, chatglm_pro or chatglm_turbo
	TopP, Temperature float64 // sampling parameters, expected in [0, 1]
}
|
||
|
||
func init() {
|
||
plugin.Register(plugin.Plugin{
|
||
Id: "glm",
|
||
Name: "智谱ChatGLM",
|
||
ConfigSample: `apiKey: <**encrypted_apiKey**> # 从bigmodel.cn获得的APIKey`,
|
||
Init: func(conf map[string]interface{}) {
|
||
apiKey = u.String(conf["apiKey"])
|
||
},
|
||
Objects: map[string]interface{}{
|
||
// send 方法说明(注释中需包含方法名称)
|
||
// * prompt 提示词
|
||
// * parentMessages 历史消息(可选)
|
||
// * option 选项(可选,model[chatglm_lite/chatglm_std/chatglm_pro/chatglm_turbo],topP[0-1],temperature[0-1])
|
||
// send return 回答内容
|
||
"send": func(prompt string, parentMessages *[]map[string]interface{}, option *Option) (map[string]interface{}, error) {
|
||
url, message := makeRequest(prompt, parentMessages, option)
|
||
c := httpclient.GetClient(300 * time.Second)
|
||
r := c.Post(url+"invoke", message, "Authorization", generateToken(apiKey))
|
||
return r.Map(), r.Error
|
||
},
|
||
// sendAsync 方法说明(注释中需包含方法名称)
|
||
// sendAsync callback 回调函数,返回参数为event(add|finish)和data
|
||
// sendAsync meta 回答内容
|
||
"sendAsync": func(callback func(event, data string), prompt string, parentMessages *[]map[string]interface{}, option *Option) (meta map[string]interface{}, err error) {
|
||
url, message := makeRequest(prompt, parentMessages, option)
|
||
message["stream"] = true
|
||
c := httpclient.GetClient(300 * time.Second)
|
||
r := c.ManualDo("POST", url+"sse-invoke", message, "Authorization", generateToken(apiKey), "Accept", "text/event-stream")
|
||
reader := bufio.NewReader(r.Response.Body)
|
||
lastEvent := ""
|
||
//lastId := ""
|
||
meta = map[string]interface{}{}
|
||
for {
|
||
buf, _, err := reader.ReadLine()
|
||
if err == io.EOF {
|
||
break
|
||
}
|
||
if err != nil {
|
||
return meta, err
|
||
}
|
||
line := string(buf)
|
||
pos := strings.IndexByte(line, ':')
|
||
if pos != -1 {
|
||
k := line[0:pos]
|
||
v := line[pos+1:]
|
||
switch k {
|
||
case "id":
|
||
//lastId = v
|
||
case "event":
|
||
lastEvent = v
|
||
case "data":
|
||
if v == "" {
|
||
v = "\n"
|
||
}
|
||
callback(lastEvent, v)
|
||
case "meta":
|
||
meta = u.UnJsonMap(v)
|
||
}
|
||
}
|
||
}
|
||
return meta, nil
|
||
},
|
||
},
|
||
})
|
||
}
|
||
|
||
func makeRequest(prompt string, parentMessages *[]map[string]interface{}, option *Option) (url string, message map[string]interface{}) {
|
||
if option == nil {
|
||
option = &Option{}
|
||
if option.Model == "" {
|
||
option.Model = "chatglm_std"
|
||
}
|
||
if option.TopP == 0 {
|
||
option.TopP = 0.7
|
||
}
|
||
if option.Temperature == 0 {
|
||
option.Temperature = 0.9
|
||
}
|
||
}
|
||
|
||
if parentMessages == nil {
|
||
parentMessages = &[]map[string]interface{}{}
|
||
}
|
||
messages := append(*parentMessages, map[string]interface{}{
|
||
"role": "user",
|
||
"content": prompt,
|
||
})
|
||
|
||
return "https://open.bigmodel.cn/api/paas/v3/model-api/" + option.Model + "/", map[string]interface{}{
|
||
"model": option.Model,
|
||
"topP": option.TopP,
|
||
"temperature": option.Temperature,
|
||
"prompt": messages,
|
||
}
|
||
}
|
||
|
||
func generateToken(apikey string) string {
|
||
parts := strings.Split(apikey, ".")
|
||
if len(parts) != 2 {
|
||
return ""
|
||
}
|
||
id := parts[0]
|
||
secret := parts[1]
|
||
|
||
payload := jwt.MapClaims{
|
||
"api_key": id,
|
||
"exp": time.Now().Add(time.Second * 180).Unix(),
|
||
"timestamp": time.Now().Unix(),
|
||
}
|
||
|
||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
|
||
token.Header["alg"] = "HS256"
|
||
token.Header["sign_type"] = "SIGN"
|
||
|
||
signedToken, err := token.SignedString([]byte(secret))
|
||
if err != nil {
|
||
return ""
|
||
}
|
||
|
||
return signedToken
|
||
}
|