huoshan/chat.go
2024-10-31 15:31:37 +08:00

199 lines
5.1 KiB
Go

package huoshan
import (
"bytes"
"context"
"encoding/binary"
"io"
"strings"
"time"
"apigo.cc/ai"
"github.com/ssgo/u"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
)
// getAKSK splits the configured ApiKey of the form "AK,SK" into its
// access-key and secret-key parts. When no comma is present the whole
// value is treated as the access key and the secret key is empty.
func getAKSK(aiConf *ai.AIConfig) (string, string) {
	ak, sk, _ := strings.Cut(aiConf.ApiKey, ",")
	return ak, sk
}
// getChatClient builds an Ark runtime client from the given configuration,
// applying the optional endpoint override and region, and authenticating
// with the AK/SK pair parsed from ApiKey.
func getChatClient(aiConf *ai.AIConfig) *arkruntime.Client {
	var options []arkruntime.ConfigOption
	if aiConf.Endpoint != "" {
		options = append(options, arkruntime.WithBaseUrl(aiConf.Endpoint))
	}
	if region := aiConf.Extra["region"]; region != nil {
		options = append(options, arkruntime.WithRegion(u.String(region)))
	}
	ak, sk := getAKSK(aiConf)
	return arkruntime.NewClientWithAkSk(ak, sk, options...)
}
// Chat sends a chat-completion request to the Volcengine Ark service using
// the credentials and endpoint in aiConf. Messages consisting of a single
// text part are sent as a plain string value; multi-part messages keep the
// structured list form. When callback is non-nil the request is streamed and
// every content delta is forwarded to callback as it arrives; otherwise a
// single blocking request is issued. The returned ChatResult carries the
// concatenated answer text, token usage, and elapsed wall time in ms.
func Chat(aiConf *ai.AIConfig, messages []ai.ChatMessage, callback func(string), conf ai.ChatConfig) (ai.ChatResult, error) {
	req := model.ChatCompletionRequest{
		Model: conf.Model,
	}
	req.Messages = make([]*model.ChatCompletionMessage, len(messages))
	for i, msg := range messages {
		var contents []*model.ChatCompletionMessageContentPart
		if msg.Contents != nil {
			contents = make([]*model.ChatCompletionMessageContentPart, len(msg.Contents))
			for j, inPart := range msg.Contents {
				part := model.ChatCompletionMessageContentPart{}
				part.Type = model.ChatCompletionMessageContentPartType(inPart.Type)
				switch inPart.Type {
				case ai.TypeText:
					part.Text = inPart.Content
				case ai.TypeImage:
					part.ImageURL = &model.ChatMessageImageURL{URL: inPart.Content}
					//case ai.TypeVideo:
					//	part.VideoURL = &model.URLItem{URL: inPart.Content}
				}
				contents[j] = &part
			}
		}
		if len(contents) == 1 && contents[0].Type == ai.TypeText {
			// A single text part is sent as a plain string for compatibility.
			req.Messages[i] = &model.ChatCompletionMessage{
				Role: msg.Role,
				Content: &model.ChatCompletionMessageContent{
					StringValue: &contents[0].Text,
				},
			}
		} else {
			req.Messages[i] = &model.ChatCompletionMessage{
				Role: msg.Role,
				Content: &model.ChatCompletionMessageContent{
					ListValue: contents,
				},
			}
		}
	}
	if conf.SystemPrompt != "" {
		// The system prompt always becomes the first message.
		req.Messages = append([]*model.ChatCompletionMessage{{
			Role: ai.RoleSystem,
			Content: &model.ChatCompletionMessageContent{
				StringValue: &conf.SystemPrompt,
			},
		}}, req.Messages...)
	}
	if len(conf.Tools) > 0 {
		req.Tools = make([]*model.Tool, 0, len(conf.Tools))
		for name, toolConf := range conf.Tools {
			switch name {
			case ai.ToolFunction:
				// Named fnDef to avoid shadowing the conf parameter.
				fnDef := model.FunctionDefinition{}
				u.Convert(toolConf, &fnDef)
				req.Tools = append(req.Tools, &model.Tool{
					Type:     model.ToolTypeFunction,
					Function: &fnDef,
				})
			}
		}
	}
	if conf.MaxTokens != 0 {
		req.MaxTokens = conf.MaxTokens
	}
	if conf.Temperature != 0 {
		req.Temperature = float32(conf.Temperature)
	}
	if conf.TopP != 0 {
		req.TopP = float32(conf.TopP)
	}
	c := getChatClient(aiConf)
	t1 := time.Now().UnixMilli()
	if callback == nil {
		// Blocking (non-streaming) request.
		r, err := c.CreateChatCompletion(context.Background(), req)
		if err != nil {
			return ai.ChatResult{}, err
		}
		usedTime := time.Now().UnixMilli() - t1
		results := make([]string, 0, len(r.Choices))
		for _, ch := range r.Choices {
			// StringValue may be nil (e.g. tool-call answers); guard the
			// dereference instead of panicking.
			if ch.Message.Content != nil && ch.Message.Content.StringValue != nil {
				results = append(results, *ch.Message.Content.StringValue)
			}
		}
		return ai.ChatResult{
			Result:       strings.Join(results, ""),
			AskTokens:    int64(r.Usage.PromptTokens),
			AnswerTokens: int64(r.Usage.CompletionTokens),
			TotalTokens:  int64(r.Usage.TotalTokens),
			UsedTime:     usedTime,
		}, nil
	}
	// Streaming request: forward each content delta to callback.
	stream, err := c.CreateChatCompletionStream(context.Background(), req)
	if err != nil {
		return ai.ChatResult{}, err
	}
	// Defer so the connection is released even if callback panics.
	defer stream.Close()
	results := make([]string, 0)
	var outErr error
	out := ai.ChatResult{}
	for {
		recv, err := stream.Recv()
		// Check the error before touching recv: on EOF/error recv is the
		// zero value and carries no data.
		if err == io.EOF {
			break
		}
		if err != nil {
			outErr = err
			break
		}
		// Usage arrives on the final data chunk before EOF.
		if recv.Usage != nil {
			out.AskTokens += int64(recv.Usage.PromptTokens)
			out.AnswerTokens += int64(recv.Usage.CompletionTokens)
			out.TotalTokens += int64(recv.Usage.TotalTokens)
		}
		for _, ch := range recv.Choices {
			text := ch.Delta.Content
			results = append(results, text)
			callback(text)
		}
	}
	out.UsedTime = time.Now().UnixMilli() - t1
	out.Result = strings.Join(results, "")
	return out, outErr
}
// Embedding requests an embedding vector for text from the Ark service and
// returns it packed as consecutive little-endian float32 bytes, along with
// token usage and the elapsed wall time in ms.
func Embedding(aiConf *ai.AIConfig, text string, embeddingConf ai.EmbeddingConfig) (ai.EmbeddingResult, error) {
	c := getChatClient(aiConf)
	req := model.EmbeddingRequestStrings{
		Input: []string{text},
		Model: embeddingConf.Model,
	}
	t1 := time.Now().UnixMilli()
	r, err := c.CreateEmbeddings(context.Background(), req)
	if err != nil {
		return ai.EmbeddingResult{}, err
	}
	usedTime := time.Now().UnixMilli() - t1
	// Pack every returned vector as little-endian float32 values.
	buf := new(bytes.Buffer)
	for _, ch := range r.Data {
		buf.Grow(len(ch.Embedding) * 4) // 4 bytes per float32
		for _, v := range ch.Embedding {
			// binary.Write on a bytes.Buffer with a fixed-size value
			// cannot fail, so the error is deliberately ignored.
			_ = binary.Write(buf, binary.LittleEndian, float32(v))
		}
	}
	return ai.EmbeddingResult{
		Result:       buf.Bytes(),
		AskTokens:    int64(r.Usage.PromptTokens),
		AnswerTokens: int64(r.Usage.CompletionTokens),
		TotalTokens:  int64(r.Usage.TotalTokens),
		UsedTime:     usedTime,
	}, nil
}