package huoshan

import (
	"bytes"
	"context"
	"encoding/binary"
	"io"
	"strings"
	"time"

	"apigo.cc/ai/llm/llm"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
)
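
// This file binds the apigo.cc/ai/llm chat and embedding interface to
// Volcengine's Ark runtime (Doubao models); "huoshan" is the pinyin of
// 火山, the Chinese name of Volcengine.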

// FastAsk chats with the low-cost Doubao Lite 32k model.
func (lm *LLM) FastAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
	return lm.Ask(messages, llm.ChatConfig{
		Model: ModelDoubaoLite32k,
	}, callback)
}

// LongAsk chats with the long-context Doubao Pro 256k model.
func (lm *LLM) LongAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
	return lm.Ask(messages, llm.ChatConfig{
		Model: ModelDoubaoPro256k,
	}, callback)
}

// BatterAsk chats with the stronger Doubao Pro 32k model.
func (lm *LLM) BatterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
	return lm.Ask(messages, llm.ChatConfig{
		Model: ModelDoubaoPro32k,
	}, callback)
}

// BestAsk chats with the Doubao Pro 256k model.
func (lm *LLM) BestAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
	return lm.Ask(messages, llm.ChatConfig{
		Model: ModelDoubaoPro256k,
	}, callback)
}

// MultiAsk chats with the Doubao Lite 32k model (the same choice as FastAsk).
func (lm *LLM) MultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
	return lm.Ask(messages, llm.ChatConfig{
		Model: ModelDoubaoLite32k,
	}, callback)
}

// BestMultiAsk chats with the Doubao Pro 32k model (the same choice as BatterAsk).
func (lm *LLM) BestMultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
	return lm.Ask(messages, llm.ChatConfig{
		Model: ModelDoubaoPro32k,
	}, callback)
}

// CodeInterpreterAsk chats with Doubao Pro 32k, requesting the
// code-interpreter tool in the config.
func (lm *LLM) CodeInterpreterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
	return lm.Ask(messages, llm.ChatConfig{
		Model: ModelDoubaoPro32k,
		Tools: map[string]any{llm.ToolCodeInterpreter: nil},
	}, callback)
}

// WebSearchAsk chats with Doubao Pro 32k, requesting the web-search tool
// in the config.
func (lm *LLM) WebSearchAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
	return lm.Ask(messages, llm.ChatConfig{
		Model: ModelDoubaoPro32k,
		Tools: map[string]any{llm.ToolWebSearch: nil},
	}, callback)
}
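
// NOTE: the two tool wrappers above set Tools in the config, but Ask does
// not yet forward tools to the Ark request (see the commented-out block
// inside Ask), so for now they behave the same as BatterAsk.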

// Ask converts the generic messages and config into an Ark
// ChatCompletionRequest and executes it: streaming with per-chunk callbacks
// when callback is non-nil, or as a single blocking completion otherwise.
func (lm *LLM) Ask(messages []llm.ChatMessage, config llm.ChatConfig, callback func(answer string)) (string, llm.Usage, error) {
	config.SetDefault(&lm.config.ChatConfig)

	req := model.ChatCompletionRequest{
		Model: config.GetModel(),
	}

	req.Messages = make([]*model.ChatCompletionMessage, len(messages))
	for i, msg := range messages {
		var contents []*model.ChatCompletionMessageContentPart
		if msg.Contents != nil {
			contents = make([]*model.ChatCompletionMessageContentPart, len(msg.Contents))
			for j, inPart := range msg.Contents {
				part := model.ChatCompletionMessageContentPart{}
				part.Type = model.ChatCompletionMessageContentPartType(NameMap[inPart.Type])
				switch inPart.Type {
				case llm.TypeText:
					part.Text = inPart.Content
				case llm.TypeImage:
					part.ImageURL = &model.ChatMessageImageURL{URL: inPart.Content}
				//case llm.TypeVideo:
				//	part.VideoURL = &model.URLItem{URL: inPart.Content}
				}
				contents[j] = &part
			}
		}
		// A single text part is sent as a plain string; anything else goes
		// out as a content-part list.
		if len(contents) == 1 && contents[0].Type == llm.TypeText {
			req.Messages[i] = &model.ChatCompletionMessage{
				Role: NameMap[msg.Role],
				Content: &model.ChatCompletionMessageContent{
					StringValue: &contents[0].Text,
				},
			}
		} else {
			req.Messages[i] = &model.ChatCompletionMessage{
				Role: NameMap[msg.Role],
				Content: &model.ChatCompletionMessageContent{
					ListValue: contents,
				},
			}
		}
	}

	// TODO: tool support is not wired up yet, so the Tools entries set by
	// CodeInterpreterAsk and WebSearchAsk are silently ignored:
	// tools := config.GetTools()
	// if len(tools) > 0 {
	//	req.Tools = make([]*model.Tool, 0)
	//	for name := range tools {
	//		switch name {
	//		case llm.ToolCodeInterpreter:
	//			req.Tools = append(req.Tools, &model.Tool{
	//				Type: ,
	//			})
	//			// cc.AddTool(zhipu.ChatCompletionToolCodeInterpreter{})
	//		case llm.ToolWebSearch:
	//			// cc.AddTool(zhipu.ChatCompletionToolWebBrowser{})
	//		}
	//	}
	// }

	if config.GetMaxTokens() != 0 {
		req.MaxTokens = config.GetMaxTokens()
	}
	if config.GetTemperature() != 0 {
		req.Temperature = float32(config.GetTemperature())
	}
	if config.GetTopP() != 0 {
		req.TopP = float32(config.GetTopP())
	}

	c := lm.getChatClient()
	t1 := time.Now().UnixMilli()
	if callback != nil {
		stream, err := c.CreateChatCompletionStream(context.Background(), req)
		if err != nil {
			return "", llm.Usage{}, err
		}
		defer stream.Close()

		out := make([]string, 0)
		var outErr error
		usage := llm.Usage{}
		for {
			recv, err := stream.Recv()
			if err == io.EOF {
				break
			}
			if err != nil {
				outErr = err
				break
			}
			// Usage is typically only set on the final chunk; guard against
			// nil so intermediate chunks don't panic.
			if recv.Usage != nil {
				usage.AskTokens += int64(recv.Usage.PromptTokens)
				usage.AnswerTokens += int64(recv.Usage.CompletionTokens)
				usage.TotalTokens += int64(recv.Usage.TotalTokens)
			}

			for _, ch := range recv.Choices {
				text := ch.Delta.Content
				out = append(out, text)
				callback(text)
			}
		}
		usage.UsedTime = time.Now().UnixMilli() - t1
		return strings.Join(out, ""), usage, outErr
	}

	r, err := c.CreateChatCompletion(context.Background(), req)
	if err != nil {
		return "", llm.Usage{}, err
	}
	t2 := time.Now().UnixMilli() - t1
	results := make([]string, 0)
	for _, ch := range r.Choices {
		// StringValue may be nil (e.g. for list content), so check before
		// dereferencing.
		if ch.Message.Content != nil && ch.Message.Content.StringValue != nil {
			results = append(results, *ch.Message.Content.StringValue)
		}
	}
	return strings.Join(results, ""), llm.Usage{
		AskTokens:    int64(r.Usage.PromptTokens),
		AnswerTokens: int64(r.Usage.CompletionTokens),
		TotalTokens:  int64(r.Usage.TotalTokens),
		UsedTime:     t2,
	}, nil
}
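
// Usage sketch (illustrative; the message field names below follow how Ask
// reads llm.ChatMessage above, but the content-part type name is an
// assumption about the llm package):
//
//	msgs := []llm.ChatMessage{{Role: "user", Contents: []llm.Content{{
//		Type: llm.TypeText, Content: "Hello",
//	}}}}
//	answer, usage, err := lm.Ask(msgs, llm.ChatConfig{Model: ModelDoubaoLite32k},
//		func(chunk string) { print(chunk) }) // chunks arrive as stream deltas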

// FastEmbedding embeds text with the standard Doubao embedding model.
func (lm *LLM) FastEmbedding(text string) ([]byte, llm.Usage, error) {
	return lm.Embedding(text, ModelDoubaoEmbedding)
}

// BestEmbedding embeds text with the large Doubao embedding model.
func (lm *LLM) BestEmbedding(text string) ([]byte, llm.Usage, error) {
	return lm.Embedding(text, ModelDoubaoEmbeddingLarge)
}

// Embedding embeds text with the given model and returns the vector packed
// as little-endian float32 bytes.
func (lm *LLM) Embedding(text, modelName string) ([]byte, llm.Usage, error) {
	c := lm.getChatClient()
	req := model.EmbeddingRequestStrings{
		Input: []string{text},
		Model: modelName,
	}
	t1 := time.Now().UnixMilli()
	r, err := c.CreateEmbeddings(context.Background(), req)
	if err != nil {
		return nil, llm.Usage{}, err
	}
	t2 := time.Now().UnixMilli() - t1

	// Pack every returned embedding value as a little-endian float32.
	buf := new(bytes.Buffer)
	for _, ch := range r.Data {
		for _, v := range ch.Embedding {
			_ = binary.Write(buf, binary.LittleEndian, float32(v))
		}
	}
	return buf.Bytes(), llm.Usage{
		AskTokens:    int64(r.Usage.PromptTokens),
		AnswerTokens: int64(r.Usage.CompletionTokens),
		TotalTokens:  int64(r.Usage.TotalTokens),
		UsedTime:     t2,
	}, nil
}
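
// decodeEmbedding is an illustrative helper (not part of the original API)
// showing how to turn the little-endian float32 bytes returned by Embedding
// back into a []float32.
func decodeEmbedding(b []byte) ([]float32, error) {
	out := make([]float32, len(b)/4)
	if err := binary.Read(bytes.NewReader(b), binary.LittleEndian, out); err != nil {
		return nil, err
	}
	return out, nil
}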