199 lines
5.1 KiB
Go
199 lines
5.1 KiB
Go
|
package huoshan
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"context"
|
||
|
"encoding/binary"
|
||
|
"io"
|
||
|
"strings"
|
||
|
"time"
|
||
|
|
||
|
"apigo.cc/ai"
|
||
|
"github.com/ssgo/u"
|
||
|
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
|
||
|
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||
|
)
|
||
|
|
||
|
func getAKSK(aiConf *ai.AIConfig) (string, string) {
|
||
|
keys := strings.SplitN(aiConf.ApiKey, ",", 2)
|
||
|
if len(keys) == 1 {
|
||
|
keys = append(keys, "")
|
||
|
}
|
||
|
return keys[0], keys[1]
|
||
|
}
|
||
|
|
||
|
func getChatClient(aiConf *ai.AIConfig) *arkruntime.Client {
|
||
|
opt := make([]arkruntime.ConfigOption, 0)
|
||
|
if aiConf.Endpoint != "" {
|
||
|
opt = append(opt, arkruntime.WithBaseUrl(aiConf.Endpoint))
|
||
|
}
|
||
|
if aiConf.Extra["region"] != nil {
|
||
|
opt = append(opt, arkruntime.WithRegion(u.String(aiConf.Extra["region"])))
|
||
|
}
|
||
|
ak, sk := getAKSK(aiConf)
|
||
|
return arkruntime.NewClientWithAkSk(ak, sk, opt...)
|
||
|
}
|
||
|
|
||
|
func Chat(aiConf *ai.AIConfig, messages []ai.ChatMessage, callback func(string), conf ai.ChatConfig) (ai.ChatResult, error) {
|
||
|
req := model.ChatCompletionRequest{
|
||
|
Model: conf.Model,
|
||
|
}
|
||
|
|
||
|
req.Messages = make([]*model.ChatCompletionMessage, len(messages))
|
||
|
for i, msg := range messages {
|
||
|
var contents []*model.ChatCompletionMessageContentPart
|
||
|
if msg.Contents != nil {
|
||
|
contents = make([]*model.ChatCompletionMessageContentPart, len(msg.Contents))
|
||
|
for j, inPart := range msg.Contents {
|
||
|
part := model.ChatCompletionMessageContentPart{}
|
||
|
part.Type = model.ChatCompletionMessageContentPartType(inPart.Type)
|
||
|
switch inPart.Type {
|
||
|
case ai.TypeText:
|
||
|
part.Text = inPart.Content
|
||
|
case ai.TypeImage:
|
||
|
part.ImageURL = &model.ChatMessageImageURL{URL: inPart.Content}
|
||
|
//case ai.TypeVideo:
|
||
|
// part.VideoURL = &model.URLItem{URL: inPart.Content}
|
||
|
}
|
||
|
contents[j] = &part
|
||
|
}
|
||
|
}
|
||
|
if len(contents) == 1 && contents[0].Type == ai.TypeText {
|
||
|
req.Messages[i] = &model.ChatCompletionMessage{
|
||
|
Role: msg.Role,
|
||
|
Content: &model.ChatCompletionMessageContent{
|
||
|
StringValue: &contents[0].Text,
|
||
|
},
|
||
|
}
|
||
|
} else {
|
||
|
req.Messages[i] = &model.ChatCompletionMessage{
|
||
|
Role: msg.Role,
|
||
|
Content: &model.ChatCompletionMessageContent{
|
||
|
ListValue: contents,
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
if conf.SystemPrompt != "" {
|
||
|
req.Messages = append([]*model.ChatCompletionMessage{{
|
||
|
Role: ai.RoleSystem,
|
||
|
Content: &model.ChatCompletionMessageContent{
|
||
|
StringValue: &conf.SystemPrompt,
|
||
|
},
|
||
|
}}, req.Messages...)
|
||
|
}
|
||
|
|
||
|
tools := conf.Tools
|
||
|
if len(tools) > 0 {
|
||
|
req.Tools = make([]*model.Tool, 0)
|
||
|
for name, toolConf := range tools {
|
||
|
switch name {
|
||
|
case ai.ToolFunction:
|
||
|
conf := model.FunctionDefinition{}
|
||
|
u.Convert(toolConf, &conf)
|
||
|
req.Tools = append(req.Tools, &model.Tool{
|
||
|
Type: model.ToolTypeFunction,
|
||
|
Function: &conf,
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
if conf.MaxTokens != 0 {
|
||
|
req.MaxTokens = conf.MaxTokens
|
||
|
}
|
||
|
if conf.Temperature != 0 {
|
||
|
req.Temperature = float32(conf.Temperature)
|
||
|
}
|
||
|
if conf.TopP != 0 {
|
||
|
req.TopP = float32(conf.TopP)
|
||
|
}
|
||
|
|
||
|
c := getChatClient(aiConf)
|
||
|
t1 := time.Now().UnixMilli()
|
||
|
if callback != nil {
|
||
|
stream, err := c.CreateChatCompletionStream(context.Background(), req)
|
||
|
if err != nil {
|
||
|
return ai.ChatResult{}, err
|
||
|
}
|
||
|
results := make([]string, 0)
|
||
|
var outErr error
|
||
|
out := ai.ChatResult{}
|
||
|
for {
|
||
|
recv, err := stream.Recv()
|
||
|
if recv.Usage != nil {
|
||
|
out.AskTokens += int64(recv.Usage.PromptTokens)
|
||
|
out.AnswerTokens += int64(recv.Usage.CompletionTokens)
|
||
|
out.TotalTokens += int64(recv.Usage.TotalTokens)
|
||
|
}
|
||
|
if err == io.EOF {
|
||
|
break
|
||
|
}
|
||
|
if err != nil {
|
||
|
outErr = err
|
||
|
break
|
||
|
}
|
||
|
|
||
|
if len(recv.Choices) > 0 {
|
||
|
for _, ch := range recv.Choices {
|
||
|
text := ch.Delta.Content
|
||
|
results = append(results, text)
|
||
|
callback(text)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
stream.Close()
|
||
|
out.UsedTime = time.Now().UnixMilli() - t1
|
||
|
out.Result = strings.Join(results, "")
|
||
|
return out, outErr
|
||
|
} else {
|
||
|
r, err := c.CreateChatCompletion(context.Background(), req)
|
||
|
if err != nil {
|
||
|
return ai.ChatResult{}, err
|
||
|
}
|
||
|
t2 := time.Now().UnixMilli() - t1
|
||
|
results := make([]string, 0)
|
||
|
if r.Choices != nil {
|
||
|
for _, ch := range r.Choices {
|
||
|
results = append(results, *ch.Message.Content.StringValue)
|
||
|
}
|
||
|
}
|
||
|
return ai.ChatResult{
|
||
|
Result: strings.Join(results, ""),
|
||
|
AskTokens: int64(r.Usage.PromptTokens),
|
||
|
AnswerTokens: int64(r.Usage.CompletionTokens),
|
||
|
TotalTokens: int64(r.Usage.TotalTokens),
|
||
|
UsedTime: t2,
|
||
|
}, nil
|
||
|
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func Embedding(aiConf *ai.AIConfig, text string, embeddingConf ai.EmbeddingConfig) (ai.EmbeddingResult, error) {
|
||
|
c := getChatClient(aiConf)
|
||
|
req := model.EmbeddingRequestStrings{
|
||
|
Input: []string{text},
|
||
|
Model: embeddingConf.Model,
|
||
|
}
|
||
|
t1 := time.Now().UnixMilli()
|
||
|
if r, err := c.CreateEmbeddings(context.Background(), req); err == nil {
|
||
|
t2 := time.Now().UnixMilli() - t1
|
||
|
buf := new(bytes.Buffer)
|
||
|
if r.Data != nil {
|
||
|
for _, ch := range r.Data {
|
||
|
for _, v := range ch.Embedding {
|
||
|
_ = binary.Write(buf, binary.LittleEndian, float32(v))
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return ai.EmbeddingResult{
|
||
|
Result: buf.Bytes(),
|
||
|
AskTokens: int64(r.Usage.PromptTokens),
|
||
|
AnswerTokens: int64(r.Usage.CompletionTokens),
|
||
|
TotalTokens: int64(r.Usage.TotalTokens),
|
||
|
UsedTime: t2,
|
||
|
}, nil
|
||
|
} else {
|
||
|
return ai.EmbeddingResult{}, err
|
||
|
}
|
||
|
}
|