diff --git a/llm_unfinished/chat.go b/llm_unfinished/chat.go
new file mode 100644
index 0000000..6413e0f
--- /dev/null
+++ b/llm_unfinished/chat.go
@@ -0,0 +1,224 @@
+package huoshan
+
+import (
+	"bytes"
+	"context"
+	"encoding/binary"
+	"io"
+	"strings"
+	"time"
+
+	"apigo.cc/ai/llm/llm"
+	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
+)
+
+func (lm *LLM) FastAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
+	return lm.Ask(messages, llm.ChatConfig{
+		Model: ModelDoubaoLite32k,
+	}, callback)
+}
+
+func (lm *LLM) LongAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
+	return lm.Ask(messages, llm.ChatConfig{
+		Model: ModelDoubaoPro256k,
+	}, callback)
+}
+
+func (lm *LLM) BatterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
+	return lm.Ask(messages, llm.ChatConfig{
+		Model: ModelDoubaoPro32k,
+	}, callback)
+}
+
+func (lm *LLM) BestAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
+	return lm.Ask(messages, llm.ChatConfig{
+		Model: ModelDoubaoPro256k,
+	}, callback)
+}
+
+func (lm *LLM) MultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
+	return lm.Ask(messages, llm.ChatConfig{
+		Model: ModelDoubaoLite32k,
+	}, callback)
+}
+
+func (lm *LLM) BestMultiAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
+	return lm.Ask(messages, llm.ChatConfig{
+		Model: ModelDoubaoPro32k,
+	}, callback)
+}
+
+func (lm *LLM) CodeInterpreterAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
+	return lm.Ask(messages, llm.ChatConfig{
+		Model: ModelDoubaoPro32k,
+		Tools: map[string]any{llm.ToolCodeInterpreter: nil},
+	}, callback)
+}
+
+func (lm *LLM) WebSearchAsk(messages []llm.ChatMessage, callback func(answer string)) (string, llm.Usage, error) {
+	return lm.Ask(messages, llm.ChatConfig{
+		Model: ModelDoubaoPro32k,
+		Tools: map[string]any{llm.ToolWebSearch: nil},
+	}, callback)
+}
+
+func (lm *LLM) Ask(messages []llm.ChatMessage, config llm.ChatConfig, callback func(answer string)) (string, llm.Usage, error) {
+	config.SetDefault(&lm.config.ChatConfig)
+
+	req := model.ChatCompletionRequest{
+		Model: config.GetModel(),
+	}
+
+	req.Messages = make([]*model.ChatCompletionMessage, len(messages))
+	for i, msg := range messages {
+		var contents []*model.ChatCompletionMessageContentPart
+		if msg.Contents != nil {
+			contents = make([]*model.ChatCompletionMessageContentPart, len(msg.Contents))
+			for j, inPart := range msg.Contents {
+				part := model.ChatCompletionMessageContentPart{}
+				part.Type = model.ChatCompletionMessageContentPartType(NameMap[inPart.Type])
+				switch inPart.Type {
+				case llm.TypeText:
+					part.Text = inPart.Content
+				case llm.TypeImage:
+					part.ImageURL = &model.ChatMessageImageURL{URL: inPart.Content}
+				//case llm.TypeVideo:
+				//	part.VideoURL = &model.URLItem{URL: inPart.Content}
+				}
+				contents[j] = &part
+			}
+		}
+		// A single text part is sent as a plain string; anything else goes as a part list.
+		if len(contents) == 1 && contents[0].Type == model.ChatCompletionMessageContentPartTypeText {
+			req.Messages[i] = &model.ChatCompletionMessage{
+				Role: NameMap[msg.Role],
+				Content: &model.ChatCompletionMessageContent{
+					StringValue: &contents[0].Text,
+				},
+			}
+		} else {
+			req.Messages[i] = &model.ChatCompletionMessage{
+				Role: NameMap[msg.Role],
+				Content: &model.ChatCompletionMessageContent{
+					ListValue: contents,
+				},
+			}
+		}
+	}
+
+	// TODO: map config.GetTools() onto Ark request tools once the SDK tool
+	// types are settled; the two cases below are placeholders.
+	// tools := config.GetTools()
+	// if len(tools) > 0 {
+	// 	req.Tools = make([]*model.Tool, 0)
+	// 	for name := range tools {
+	// 		switch name {
+	// 		case llm.ToolCodeInterpreter:
+	// 		case llm.ToolWebSearch:
+	// 		}
+	// 	}
+	// }
+	if config.GetMaxTokens() != 0 {
+		req.MaxTokens = config.GetMaxTokens()
+	}
+	if config.GetTemperature() != 0 {
+		req.Temperature = float32(config.GetTemperature())
+	}
+	if config.GetTopP() != 0 {
+		req.TopP = float32(config.GetTopP())
+	}
+
+	c := lm.getChatClient()
+	t1 := time.Now().UnixMilli()
+	if callback != nil {
+		stream, err := c.CreateChatCompletionStream(context.Background(), req)
+		if err != nil {
+			return "", llm.Usage{}, err
+		}
+		defer stream.Close()
+		out := make([]string, 0)
+		var outErr error
+		usage := llm.Usage{}
+		for {
+			recv, err := stream.Recv()
+			if err == io.EOF {
+				break
+			}
+			if err != nil {
+				outErr = err
+				break
+			}
+			// Usage is typically only populated on the final chunk, so check
+			// err first and guard against a nil Usage before accumulating.
+			if recv.Usage != nil {
+				usage.AskTokens += int64(recv.Usage.PromptTokens)
+				usage.AnswerTokens += int64(recv.Usage.CompletionTokens)
+				usage.TotalTokens += int64(recv.Usage.TotalTokens)
+			}
+			for _, ch := range recv.Choices {
+				text := ch.Delta.Content
+				out = append(out, text)
+				callback(text)
+			}
+		}
+		usage.UsedTime = time.Now().UnixMilli() - t1
+		return strings.Join(out, ""), usage, outErr
+	}
+
+	r, err := c.CreateChatCompletion(context.Background(), req)
+	if err != nil {
+		return "", llm.Usage{}, err
+	}
+	t2 := time.Now().UnixMilli() - t1
+	results := make([]string, 0)
+	for _, ch := range r.Choices {
+		if ch.Message.Content != nil && ch.Message.Content.StringValue != nil {
+			results = append(results, *ch.Message.Content.StringValue)
+		}
+	}
+	return strings.Join(results, ""), llm.Usage{
+		AskTokens:    int64(r.Usage.PromptTokens),
+		AnswerTokens: int64(r.Usage.CompletionTokens),
+		TotalTokens:  int64(r.Usage.TotalTokens),
+		UsedTime:     t2,
+	}, nil
+}
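+
+// Minimal streaming usage sketch (comment-only; the content element type name
+// and fields are assumed from how Ask reads them: Role, Contents[].Type,
+// Contents[].Content):
+//
+//	answer, usage, err := lm.Ask(
+//		[]llm.ChatMessage{{Role: llm.RoleUser, Contents: []llm.ChatMessageContent{{
+//			Type: llm.TypeText, Content: "Introduce Doubao in one sentence",
+//		}}}},
+//		llm.ChatConfig{Model: ModelDoubaoPro32k},
+//		func(delta string) { fmt.Print(delta) }, // invoked once per streamed delta
+//	)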
+
+func (lm *LLM) FastEmbedding(text string) ([]byte, llm.Usage, error) {
+	return lm.Embedding(text, ModelDoubaoEmbedding)
+}
+
+func (lm *LLM) BestEmbedding(text string) ([]byte, llm.Usage, error) {
+	return lm.Embedding(text, ModelDoubaoEmbeddingLarge)
+}
+
+// Embedding returns the embedding vector of text as little-endian float32
+// values concatenated into a single byte slice.
+func (lm *LLM) Embedding(text, modelName string) ([]byte, llm.Usage, error) {
+	c := lm.getChatClient()
+	req := model.EmbeddingRequestStrings{
+		Input: []string{text},
+		Model: modelName,
+	}
+	t1 := time.Now().UnixMilli()
+	r, err := c.CreateEmbeddings(context.Background(), req)
+	if err != nil {
+		return nil, llm.Usage{}, err
+	}
+	t2 := time.Now().UnixMilli() - t1
+	buf := new(bytes.Buffer)
+	for _, ch := range r.Data {
+		for _, v := range ch.Embedding {
+			_ = binary.Write(buf, binary.LittleEndian, float32(v))
+		}
+	}
+	return buf.Bytes(), llm.Usage{
+		AskTokens:    int64(r.Usage.PromptTokens),
+		AnswerTokens: int64(r.Usage.CompletionTokens),
+		TotalTokens:  int64(r.Usage.TotalTokens),
+		UsedTime:     t2,
+	}, nil
+}
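+
+// Decoding sketch for the byte slice returned above, assuming the same
+// little-endian float32 layout that Embedding writes:
+//
+//	func decodeEmbedding(b []byte) []float32 {
+//		out := make([]float32, len(b)/4)
+//		_ = binary.Read(bytes.NewReader(b), binary.LittleEndian, out)
+//		return out
+//	}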
diff --git a/llm_unfinished/config.go b/llm_unfinished/config.go
new file mode 100644
index 0000000..0b6e583
--- /dev/null
+++ b/llm_unfinished/config.go
@@ -0,0 +1,96 @@
+package huoshan
+
+import (
+	"strings"
+
+	"apigo.cc/ai/llm/llm"
+	"github.com/volcengine/volc-sdk-golang/service/visual"
+	"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
+	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
+)
+
+type LLM struct {
+	config llm.Config
+}
+
+var NameMap = map[string]string{
+	llm.TypeText:      string(model.ChatCompletionMessageContentPartTypeText),
+	llm.TypeImage:     string(model.ChatCompletionMessageContentPartTypeImageURL),
+	//llm.TypeVideo:   string(model.ChatCompletionMessageContentPartTypeVideoURL),
+	llm.RoleSystem:    model.ChatMessageRoleSystem,
+	llm.RoleUser:      model.ChatMessageRoleUser,
+	llm.RoleAssistant: model.ChatMessageRoleAssistant,
+	llm.RoleTool:      model.ChatMessageRoleTool,
+}
+
+const (
+	ModelDoubaoLite4k                 = "Doubao-lite-4k"
+	ModelDoubaoLite32k                = "Doubao-lite-32k"
+	ModelDoubaoLite128k               = "Doubao-lite-128k"
+	ModelDoubaoPro4k                  = "Doubao-pro-4k"
+	ModelDoubaoPro32k                 = "Doubao-pro-32k"
+	ModelDoubaoPro128k                = "Doubao-pro-128k"
+	ModelDoubaoPro256k                = "Doubao-pro-256k"
+	ModelDoubaoEmbedding              = "Doubao-embedding"
+	ModelDoubaoEmbeddingLarge         = "Doubao-embedding-large"
+	ModelT2I2L                        = "high_aes_general_v20_L:general_v2.0_L"
+	ModelT2I2S                        = "high_aes_general_v20:general_v2.0"
+	ModelT2IXL                        = "t2i_xl_sft"
+	ModelI2IXL                        = "i2i_xl_sft"
+	ModelT2I14                        = "high_aes_general_v14"
+	ModelI2I14IP                      = "high_aes_general_v14_ip_keep"
+	ModelAnime13                      = "high_aes:anime_v1.3"
+	ModelAnime131                     = "high_aes:anime_v1.3.1"
+	ModelPhotoverseAmericanComics     = "img2img_photoverse_american_comics"    // American comics style
+	ModelPhotoverseExecutiveIDPhoto   = "img2img_photoverse_executive_ID_photo" // business ID photo
+	ModelPhotoverse3dWeird            = "img2img_photoverse_3d_weird"           // 3D figurine
+	ModelPhotoverseCyberpunk          = "img2img_photoverse_cyberpunk"          // cyberpunk
+	ModelXiezhenGubao                 = "img2img_xiezhen_gubao"                 // castle
+	ModelXiezhenBabiNiuzai            = "img2img_xiezhen_babi_niuzai"           // Barbie cowgirl
+	ModelXiezhenBathrobe              = "img2img_xiezhen_bathrobe"              // bathrobe style
+	ModelXiezhenButterflyMachin       = "img2img_xiezhen_butterfly_machin"      // mechanical butterfly
+	ModelXiezhenZhichangzhengjianzhao = "img2img_xiezhen_zhichangzhengjianzhao" // workplace ID photo
+	ModelXiezhenChristmas             = "img2img_xiezhen_christmas"             // Christmas
+	ModelXiezhenDessert               = "img2img_xiezhen_dessert"               // American pastry chef
+	ModelXiezhenOldMoney              = "img2img_xiezhen_old_money"             // old money
+	ModelXiezhenSchool                = "img2img_xiezhen_school"                // campus portrait
+)
+
+func (lm *LLM) Support() llm.Support {
+	return llm.Support{
+		Ask:                    true,
+		AskWithImage:           true,
+		AskWithVideo:           false,
+		AskWithCodeInterpreter: false,
+		AskWithWebSearch:       false,
+		MakeImage:              true,
+		MakeVideo:              false,
+		Models:                 []string{ModelDoubaoLite4k, ModelDoubaoLite32k, ModelDoubaoLite128k, ModelDoubaoPro4k, ModelDoubaoPro32k, ModelDoubaoPro128k, ModelDoubaoPro256k, ModelDoubaoEmbedding, ModelDoubaoEmbeddingLarge, ModelT2I2L, ModelT2I2S, ModelT2IXL, ModelI2IXL, ModelT2I14, ModelI2I14IP, ModelAnime13, ModelAnime131, ModelPhotoverseAmericanComics, ModelPhotoverseExecutiveIDPhoto, ModelPhotoverse3dWeird, ModelPhotoverseCyberpunk, ModelXiezhenGubao, ModelXiezhenBabiNiuzai, ModelXiezhenBathrobe, ModelXiezhenButterflyMachin, ModelXiezhenZhichangzhengjianzhao, ModelXiezhenChristmas, ModelXiezhenDessert, ModelXiezhenOldMoney, ModelXiezhenSchool},
+	}
+}
+
+func (lm *LLM) getChatClient() *arkruntime.Client {
+	opt := make([]arkruntime.ConfigOption, 0)
+	if lm.config.Endpoint != "" {
+		opt = append(opt, arkruntime.WithBaseUrl(lm.config.Endpoint))
+	}
+	// ApiKey holds "AccessKey,SecretKey" (see getGCClient); pass both parts.
+	keys := strings.SplitN(lm.config.ApiKey, ",", 2)
+	if len(keys) == 1 {
+		keys = append(keys, "")
+	}
+	return arkruntime.NewClientWithAkSk(keys[0], keys[1], opt...)
+}
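+
+// Credential sketch: both the Ark chat client and the visual client derive
+// their keys from the single comma-separated llm.Config.ApiKey (the values
+// and endpoint below are illustrative, not real):
+//
+//	llm.Config{ApiKey: "<access-key>,<secret-key>", Endpoint: "https://ark.cn-beijing.volces.com/api/v3"}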
+
+func (lm *LLM) getGCClient() *visual.Visual {
+	keys := strings.SplitN(lm.config.ApiKey, ",", 2)
+	if len(keys) == 1 {
+		keys = append(keys, "")
+	}
+	vis := visual.NewInstance()
+	vis.Client.SetAccessKey(keys[0])
+	vis.Client.SetSecretKey(keys[1])
+	return vis
+}
+
+// Configuration on the Volcano (Huoshan) platform is overly cumbersome (a
+// separate endpoint must be created for every model), so support for the
+// Doubao models is shelved for now and this provider is not registered.
+// func init() {
+// 	llm.Register("huoshan", func(config llm.Config) llm.LLM {
+// 		return &LLM{config: config}
+// 	})
+// }
diff --git a/llm_unfinished/gc.go b/llm_unfinished/gc.go
new file mode 100644
index 0000000..71cd8b2
--- /dev/null
+++ b/llm_unfinished/gc.go
@@ -0,0 +1,87 @@
+package huoshan
+
+import (
+	"errors"
+	"strings"
+	"time"
+
+	"apigo.cc/ai/llm/llm"
+	"github.com/ssgo/u"
+	"github.com/volcengine/volc-sdk-golang/service/visual/model"
+)
+
+func (lm *LLM) FastMakeImage(prompt string, config llm.GCConfig) ([]string, llm.Usage, error) {
+	config.Model = ModelT2I14
+	if config.Ref != "" {
+		config.Model = ModelI2I14IP
+	}
+	return lm.MakeImage(prompt, config)
+}
+
+func (lm *LLM) BestMakeImage(prompt string, config llm.GCConfig) ([]string, llm.Usage, error) {
+	config.Model = ModelT2IXL
+	if config.Ref != "" {
+		config.Model = ModelI2IXL
+	}
+	return lm.MakeImage(prompt, config)
+}
+
+func (lm *LLM) MakeImage(prompt string, config llm.GCConfig) ([]string, llm.Usage, error) {
+	config.SetDefault(&lm.config.GCConfig)
+	// Model names may carry a version suffix after ":", e.g. "high_aes:anime_v1.3".
+	modelA := strings.SplitN(config.GetModel(), ":", 2)
+	// Size is "WIDTHxHEIGHT"; a single number means a square image.
+	sizeA := strings.SplitN(config.GetSize(), "x", 2)
+	if len(sizeA) == 1 {
+		sizeA = append(sizeA, sizeA[0])
+	}
+	ref := config.GetRef()
+	vis := lm.getGCClient()
+	data := map[string]any{
+		"req_key":    modelA[0],
+		"prompt":     prompt,
+		"width":      u.Int(sizeA[0]),
+		"height":     u.Int(sizeA[1]),
+		"return_url": true,
+	}
+	if len(modelA) > 1 {
+		data["model_version"] = modelA[1]
+	}
+	// TODO: llm should support passing extra, provider-specific parameters.
+
+	t1 := time.Now().UnixMilli()
+	var resp *model.VisualPubResult
+	var status int
+	var err error
+	if ref == "" {
+		resp, status, err = vis.Text2ImgXLSft(data)
+	} else {
+		// A reference image may be given as a URL or as base64-encoded binary data.
+		if strings.Contains(ref, "://") {
+			data["image_url"] = []string{ref}
+		} else {
+			data["binary_data_base64"] = []string{ref}
+		}
+		resp, status, err = vis.Img2ImgXLSft(data)
+	}
+	t2 := time.Now().UnixMilli() - t1
+
+	if err != nil {
+		return nil, llm.Usage{}, err
+	}
+	if status != 200 {
+		return nil, llm.Usage{}, errors.New(resp.Message)
+	}
+	return resp.Data.ImageUrls, llm.Usage{
+		UsedTime: t2,
+	}, nil
+}
+
+func (lm *LLM) FastMakeVideo(prompt string, config llm.GCConfig) ([]string, []string, llm.Usage, error) {
+	return lm.MakeVideo(prompt, config)
+}
+
+func (lm *LLM) BestMakeVideo(prompt string, config llm.GCConfig) ([]string, []string, llm.Usage, error) {
+	return lm.MakeVideo(prompt, config)
+}
+
+func (lm *LLM) MakeVideo(prompt string, config llm.GCConfig) ([]string, []string, llm.Usage, error) {
+	return nil, nil, llm.Usage{}, errors.New("not supported")
+}
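+
+// Minimal usage sketch (prompt and size are illustrative):
+//
+//	urls, _, err := lm.FastMakeImage("a cat floating in space", llm.GCConfig{Size: "1024x1024"})
+//	// urls holds image URLs because the request sets "return_url": true.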