unfinished llm
This commit is contained in:
parent
b8641271be
commit
ee2775d7fa
224
llm_unfinished/chat.go
Normal file
224
llm_unfinished/chat.go
Normal file
@ -0,0 +1,224 @@
|
||||
package huoshan
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"apigo.cc/ai/llm/llm"
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
)
|
||||
|
||||
func (lm *LLM) FastAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
|
||||
return lm.Ask(messages, llm.ChatConfig{
|
||||
Model: ModelDoubaoLite32k,
|
||||
}, callback)
|
||||
}
|
||||
|
||||
func (lm *LLM) LongAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
|
||||
return lm.Ask(messages, llm.ChatConfig{
|
||||
Model: ModelDoubaoPro256k,
|
||||
}, callback)
|
||||
}
|
||||
|
||||
func (lm *LLM) BatterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
|
||||
return lm.Ask(messages, llm.ChatConfig{
|
||||
Model: ModelDoubaoPro32k,
|
||||
}, callback)
|
||||
}
|
||||
|
||||
func (lm *LLM) BestAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
|
||||
return lm.Ask(messages, llm.ChatConfig{
|
||||
Model: ModelDoubaoPro256k,
|
||||
}, callback)
|
||||
}
|
||||
|
||||
func (lm *LLM) MultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
|
||||
return lm.Ask(messages, llm.ChatConfig{
|
||||
Model: ModelDoubaoLite32k,
|
||||
}, callback)
|
||||
}
|
||||
|
||||
func (lm *LLM) BestMultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
|
||||
return lm.Ask(messages, llm.ChatConfig{
|
||||
Model: ModelDoubaoPro32k,
|
||||
}, callback)
|
||||
}
|
||||
|
||||
func (lm *LLM) CodeInterpreterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
|
||||
return lm.Ask(messages, llm.ChatConfig{
|
||||
Model: ModelDoubaoPro32k,
|
||||
Tools: map[string]any{llm.ToolCodeInterpreter: nil},
|
||||
}, callback)
|
||||
}
|
||||
|
||||
func (lm *LLM) WebSearchAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
|
||||
return lm.Ask(messages, llm.ChatConfig{
|
||||
Model: ModelDoubaoPro32k,
|
||||
Tools: map[string]any{llm.ToolWebSearch: nil},
|
||||
}, callback)
|
||||
}
|
||||
|
||||
func (lm *LLM) Ask(messages []llm.ChatMessage, config llm.ChatConfig, callback func(answer string)) (string, llm.Usage, error) {
|
||||
config.SetDefault(&lm.config.ChatConfig)
|
||||
|
||||
req := model.ChatCompletionRequest{
|
||||
Model: config.GetModel(),
|
||||
}
|
||||
|
||||
req.Messages = make([]*model.ChatCompletionMessage, len(messages))
|
||||
for i, msg := range messages {
|
||||
var contents []*model.ChatCompletionMessageContentPart
|
||||
if msg.Contents != nil {
|
||||
contents = make([]*model.ChatCompletionMessageContentPart, len(msg.Contents))
|
||||
for j, inPart := range msg.Contents {
|
||||
part := model.ChatCompletionMessageContentPart{}
|
||||
part.Type = model.ChatCompletionMessageContentPartType(NameMap[inPart.Type])
|
||||
switch inPart.Type {
|
||||
case llm.TypeText:
|
||||
part.Text = inPart.Content
|
||||
case llm.TypeImage:
|
||||
part.ImageURL = &model.ChatMessageImageURL{URL: inPart.Content}
|
||||
//case llm.TypeVideo:
|
||||
// part.VideoURL = &model.URLItem{URL: inPart.Content}
|
||||
}
|
||||
contents[j] = &part
|
||||
}
|
||||
}
|
||||
if len(contents) == 1 && contents[0].Type == llm.TypeText {
|
||||
req.Messages[i] = &model.ChatCompletionMessage{
|
||||
Role: NameMap[msg.Role],
|
||||
Content: &model.ChatCompletionMessageContent{
|
||||
StringValue: &contents[0].Text,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
req.Messages[i] = &model.ChatCompletionMessage{
|
||||
Role: NameMap[msg.Role],
|
||||
Content: &model.ChatCompletionMessageContent{
|
||||
ListValue: contents,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tools := config.GetTools()
|
||||
// if len(tools) > 0 {
|
||||
// req.Tools = make([]*model.Tool, 0)
|
||||
// for name := range tools {
|
||||
// switch name {
|
||||
// case llm.ToolCodeInterpreter:
|
||||
// req.Tools = append(req.Tools, &model.Tool{
|
||||
// Type: ,
|
||||
// })
|
||||
// // cc.AddTool(zhipu.ChatCompletionToolCodeInterpreter{})
|
||||
// case llm.ToolWebSearch:
|
||||
// // cc.AddTool(zhipu.ChatCompletionToolWebBrowser{})
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
if config.GetMaxTokens() != 0 {
|
||||
req.MaxTokens = config.GetMaxTokens()
|
||||
}
|
||||
if config.GetTemperature() != 0 {
|
||||
req.Temperature = float32(config.GetTemperature())
|
||||
}
|
||||
if config.GetTopP() != 0 {
|
||||
req.TopP = float32(config.GetTopP())
|
||||
}
|
||||
|
||||
c := lm.getChatClient()
|
||||
t1 := time.Now().UnixMilli()
|
||||
if callback != nil {
|
||||
stream, err := c.CreateChatCompletionStream(context.Background(), req)
|
||||
if err != nil {
|
||||
return "", llm.Usage{}, err
|
||||
}
|
||||
out := make([]string, 0)
|
||||
var outErr error
|
||||
usage := llm.Usage{}
|
||||
for {
|
||||
recv, err := stream.Recv()
|
||||
usage.AskTokens += int64(recv.Usage.PromptTokens)
|
||||
usage.AnswerTokens += int64(recv.Usage.CompletionTokens)
|
||||
usage.TotalTokens += int64(recv.Usage.TotalTokens)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
outErr = err
|
||||
break
|
||||
}
|
||||
|
||||
if len(recv.Choices) > 0 {
|
||||
for _, ch := range recv.Choices {
|
||||
text := ch.Delta.Content
|
||||
out = append(out, text)
|
||||
callback(text)
|
||||
}
|
||||
}
|
||||
}
|
||||
stream.Close()
|
||||
usage.UsedTime = time.Now().UnixMilli() - t1
|
||||
return strings.Join(out, ""), usage, outErr
|
||||
} else {
|
||||
r, err := c.CreateChatCompletion(context.Background(), req)
|
||||
if err != nil {
|
||||
return "", llm.Usage{}, err
|
||||
}
|
||||
t2 := time.Now().UnixMilli() - t1
|
||||
results := make([]string, 0)
|
||||
if r.Choices != nil {
|
||||
for _, ch := range r.Choices {
|
||||
results = append(results, *ch.Message.Content.StringValue)
|
||||
}
|
||||
}
|
||||
return strings.Join(results, ""), llm.Usage{
|
||||
AskTokens: int64(r.Usage.PromptTokens),
|
||||
AnswerTokens: int64(r.Usage.CompletionTokens),
|
||||
TotalTokens: int64(r.Usage.TotalTokens),
|
||||
UsedTime: t2,
|
||||
}, nil
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func (lm *LLM) FastEmbedding(text string) ([]byte, llm.Usage, error) {
|
||||
return lm.Embedding(text, ModelDoubaoEmbedding)
|
||||
}
|
||||
|
||||
func (lm *LLM) BestEmbedding(text string) ([]byte, llm.Usage, error) {
|
||||
return lm.Embedding(text, ModelDoubaoEmbeddingLarge)
|
||||
}
|
||||
|
||||
func (lm *LLM) Embedding(text, modelName string) ([]byte, llm.Usage, error) {
|
||||
c := lm.getChatClient()
|
||||
// cc := c.Embedding(modelName)
|
||||
req := model.EmbeddingRequestStrings{
|
||||
Input: []string{text},
|
||||
Model: modelName,
|
||||
}
|
||||
t1 := time.Now().UnixMilli()
|
||||
if r, err := c.CreateEmbeddings(context.Background(), req); err == nil {
|
||||
t2 := time.Now().UnixMilli() - t1
|
||||
buf := new(bytes.Buffer)
|
||||
if r.Data != nil {
|
||||
for _, ch := range r.Data {
|
||||
for _, v := range ch.Embedding {
|
||||
_ = binary.Write(buf, binary.LittleEndian, float32(v))
|
||||
}
|
||||
}
|
||||
}
|
||||
return buf.Bytes(), llm.Usage{
|
||||
AskTokens: int64(r.Usage.PromptTokens),
|
||||
AnswerTokens: int64(r.Usage.CompletionTokens),
|
||||
TotalTokens: int64(r.Usage.TotalTokens),
|
||||
UsedTime: t2,
|
||||
}, nil
|
||||
} else {
|
||||
return nil, llm.Usage{}, err
|
||||
}
|
||||
}
|
96
llm_unfinished/config.go
Normal file
96
llm_unfinished/config.go
Normal file
@ -0,0 +1,96 @@
|
||||
package huoshan
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"apigo.cc/ai/llm/llm"
|
||||
"github.com/volcengine/volc-sdk-golang/service/visual"
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
)
|
||||
|
||||
// LLM implements the llm provider interface for the Huoshan (Volcano Engine /
// Doubao) platform.
type LLM struct {
	config llm.Config // provider configuration: ApiKey, Endpoint, and default chat/GC settings
}
|
||||
|
||||
// NameMap translates provider-agnostic llm content-type and role names into
// the volcengine SDK's equivalent string values. Both content types and roles
// share this one map because their key sets do not collide.
var NameMap = map[string]string{
	llm.TypeText:  string(model.ChatCompletionMessageContentPartTypeText),
	llm.TypeImage: string(model.ChatCompletionMessageContentPartTypeImageURL),
	//llm.TypeVideo: string(model.ChatCompletionMessageContentPartTypeVideoURL),
	llm.RoleSystem:    model.ChatMessageRoleSystem,
	llm.RoleUser:      model.ChatMessageRoleUser,
	llm.RoleAssistant: model.ChatMessageRoleAssistant,
	llm.RoleTool:      model.ChatMessageRoleTool,
}
|
||||
|
||||
// Model identifiers accepted by the Huoshan platform: Doubao chat/embedding
// models, text-to-image / image-to-image generation models, and a set of
// stylized image-to-image presets.
const (
	// Doubao chat models (suffix is the context window size).
	ModelDoubaoLite4k   = "Doubao-lite-4k"
	ModelDoubaoLite32k  = "Doubao-lite-32k"
	ModelDoubaoLite128k = "Doubao-lite-128k"
	ModelDoubaoPro4k    = "Doubao-pro-4k"
	ModelDoubaoPro32k   = "Doubao-pro-32k"
	ModelDoubaoPro128k  = "Doubao-pro-128k"
	ModelDoubaoPro256k  = "Doubao-pro-256k"
	// Doubao embedding models.
	ModelDoubaoEmbedding      = "Doubao-embedding"
	ModelDoubaoEmbeddingLarge = "Doubao-embedding-large"
	// Image generation models ("req_key" or "req_key:model_version" form).
	ModelT2I2L   = "high_aes_general_v20_L:general_v2.0_L"
	ModelT2I2S   = "high_aes_general_v20:general_v2.0"
	ModelT2IXL   = "t2i_xl_sft"
	ModelI2IXL   = "i2i_xl_sft"
	ModelT2I14   = "high_aes_general_v14"
	ModelI2I14IP = "high_aes_general_v14_ip_keep"
	ModelAnime13  = "high_aes:anime_v1.3"
	ModelAnime131 = "high_aes:anime_v1.3.1"
	// Stylized image-to-image presets.
	ModelPhotoverseAmericanComics   = "img2img_photoverse_american_comics"     // American comic style
	ModelPhotoverseExecutiveIDPhoto = "img2img_photoverse_executive_ID_photo"  // business ID photo
	ModelPhotoverse3dWeird          = "img2img_photoverse_3d_weird"            // 3D figurine
	ModelPhotoverseCyberpunk        = "img2img_photoverse_cyberpunk"           // cyberpunk
	ModelXiezhenGubao               = "img2img_xiezhen_gubao"                  // ancient castle
	ModelXiezhenBabiNiuzai          = "img2img_xiezhen_babi_niuzai"            // Barbie cowgirl
	ModelXiezhenBathrobe            = "img2img_xiezhen_bathrobe"               // bathrobe style
	ModelXiezhenButterflyMachin     = "img2img_xiezhen_butterfly_machin"       // mechanical butterfly
	ModelXiezhenZhichangzhengjianzhao = "img2img_xiezhen_zhichangzhengjianzhao" // workplace ID photo
	ModelXiezhenChristmas           = "img2img_xiezhen_christmas"              // Christmas
	ModelXiezhenDessert             = "img2img_xiezhen_dessert"                // American dessert chef
	ModelXiezhenOldMoney            = "img2img_xiezhen_old_money"              // old money
	ModelXiezhenSchool              = "img2img_xiezhen_school"                 // campus portrait
)
|
||||
|
||||
func (lm *LLM) Support() llm.Support {
|
||||
return llm.Support{
|
||||
Ask: true,
|
||||
AskWithImage: true,
|
||||
AskWithVideo: false,
|
||||
AskWithCodeInterpreter: false,
|
||||
AskWithWebSearch: false,
|
||||
MakeImage: true,
|
||||
MakeVideo: false,
|
||||
Models: []string{ModelDoubaoLite4k, ModelDoubaoLite32k, ModelDoubaoLite128k, ModelDoubaoPro4k, ModelDoubaoPro32k, ModelDoubaoPro128k, ModelDoubaoPro256k, ModelDoubaoEmbedding, ModelDoubaoEmbeddingLarge, ModelT2I2L, ModelT2I2S, ModelT2IXL, ModelI2IXL, ModelT2I14, ModelI2I14IP, ModelAnime13, ModelAnime131, ModelPhotoverseAmericanComics, ModelPhotoverseExecutiveIDPhoto, ModelPhotoverse3dWeird, ModelPhotoverseCyberpunk, ModelXiezhenGubao, ModelXiezhenBabiNiuzai, ModelXiezhenBathrobe, ModelXiezhenButterflyMachin, ModelXiezhenZhichangzhengjianzhao, ModelXiezhenChristmas, ModelXiezhenDessert, ModelXiezhenOldMoney, ModelXiezhenSchool},
|
||||
}
|
||||
}
|
||||
|
||||
func (lm *LLM) getChatClient() *arkruntime.Client {
|
||||
opt := make([]arkruntime.ConfigOption, 0)
|
||||
if lm.config.Endpoint != "" {
|
||||
opt = append(opt, arkruntime.WithBaseUrl(lm.config.Endpoint))
|
||||
}
|
||||
return arkruntime.NewClientWithAkSk(strings.SplitN(lm.config.ApiKey, ",", 2)[0], opt...)
|
||||
}
|
||||
|
||||
func (lm *LLM) getGCClient() *visual.Visual {
|
||||
keys := strings.SplitN(lm.config.ApiKey, ",", 2)
|
||||
if len(keys) == 1 {
|
||||
keys = append(keys, "")
|
||||
}
|
||||
vis := visual.NewInstance()
|
||||
vis.Client.SetAccessKey(keys[0])
|
||||
vis.Client.SetSecretKey(keys[1])
|
||||
return vis
|
||||
}
|
||||
|
||||
// Because configuring the Huoshan (Volcano) platform is too cumbersome (each model requires creating its own endpoint), support for the Doubao LLM is temporarily abandoned.
|
||||
// func init() {
|
||||
// llm.Register("huoshan", func(config llm.Config) llm.LLM {
|
||||
// return &LLM{config: config}
|
||||
// })
|
||||
// }
|
87
llm_unfinished/gc.go
Normal file
87
llm_unfinished/gc.go
Normal file
@ -0,0 +1,87 @@
|
||||
package huoshan
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"apigo.cc/ai/llm/llm"
|
||||
"github.com/ssgo/u"
|
||||
"github.com/volcengine/volc-sdk-golang/service/visual/model"
|
||||
)
|
||||
|
||||
func (lm *LLM) FastMakeImage(prompt string, config llm.GCConfig) ([]string, llm.Usage, error) {
|
||||
config.Model = ModelT2I14
|
||||
if config.Ref != "" {
|
||||
config.Model = ModelI2I14IP
|
||||
}
|
||||
return lm.MakeImage(prompt, config)
|
||||
}
|
||||
|
||||
func (lm *LLM) BestMakeImage(prompt string, config llm.GCConfig) ([]string, llm.Usage, error) {
|
||||
config.Model = ModelT2IXL
|
||||
if config.Ref != "" {
|
||||
config.Model = ModelI2IXL
|
||||
}
|
||||
return lm.MakeImage(prompt, config)
|
||||
}
|
||||
|
||||
func (lm *LLM) MakeImage(prompt string, config llm.GCConfig) ([]string, llm.Usage, error) {
|
||||
config.SetDefault(&lm.config.GCConfig)
|
||||
modelA := strings.SplitN(config.GetModel(), ":", 2)
|
||||
sizeA := strings.SplitN(config.GetSize(), "x", 2)
|
||||
if len(sizeA) == 1 {
|
||||
sizeA = append(sizeA, sizeA[0])
|
||||
}
|
||||
ref := config.GetRef()
|
||||
vis := lm.getGCClient()
|
||||
data := map[string]any{
|
||||
"req_key": modelA[0],
|
||||
"prompt": prompt,
|
||||
"width": u.Int(sizeA[0]),
|
||||
"height": u.Int(sizeA[1]),
|
||||
"return_url": true,
|
||||
}
|
||||
if len(modelA) > 1 {
|
||||
data["model_version"] = modelA[1]
|
||||
}
|
||||
// TODO llm 支持动态额外参数
|
||||
|
||||
t1 := time.Now().UnixMilli()
|
||||
var resp *model.VisualPubResult
|
||||
var status int
|
||||
var err error
|
||||
if ref == "" {
|
||||
resp, status, err = vis.Text2ImgXLSft(data)
|
||||
} else {
|
||||
if strings.Contains(ref, "://") {
|
||||
data["image_url"] = []string{ref}
|
||||
} else {
|
||||
data["binary_data_base64"] = []string{ref}
|
||||
}
|
||||
resp, status, err = vis.Img2ImgXLSft(data)
|
||||
}
|
||||
t2 := time.Now().UnixMilli() - t1
|
||||
|
||||
if err != nil {
|
||||
return nil, llm.Usage{}, err
|
||||
}
|
||||
if status != 200 {
|
||||
return nil, llm.Usage{}, errors.New(resp.Message)
|
||||
}
|
||||
return resp.Data.ImageUrls, llm.Usage{
|
||||
UsedTime: t2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (lm *LLM) FastMakeVideo(prompt string, config llm.GCConfig) ([]string, []string, llm.Usage, error) {
|
||||
return lm.MakeVideo(prompt, config)
|
||||
}
|
||||
|
||||
func (lm *LLM) BestMakeVideo(prompt string, config llm.GCConfig) ([]string, []string, llm.Usage, error) {
|
||||
return lm.MakeVideo(prompt, config)
|
||||
}
|
||||
|
||||
func (lm *LLM) MakeVideo(prompt string, config llm.GCConfig) ([]string, []string, llm.Usage, error) {
|
||||
return nil, nil, llm.Usage{}, errors.New("not support")
|
||||
}
|
Loading…
Reference in New Issue
Block a user