// huoshan/gc.go
// 2024-10-31 15:31:37 +08:00
//
// 128 lines
// 3.1 KiB
// Go

package huoshan
import (
"errors"
"strings"
"time"
"apigo.cc/ai"
"github.com/ssgo/u"
"github.com/volcengine/volc-sdk-golang/service/visual"
"github.com/volcengine/volc-sdk-golang/service/visual/model"
)
// getVisualClient builds a Volcengine visual-service client authenticated with
// the access/secret key pair obtained from getAKSK(aiConf).
//
// Optional overrides are read from aiConf.Extra: a non-nil "region" entry is
// applied via SetRegion and a non-nil "host" entry via SetHost, each converted
// with u.String.
func getVisualClient(aiConf *ai.AIConfig) *visual.Visual {
	accessKey, secretKey := getAKSK(aiConf)

	client := visual.NewInstance()
	client.Client.SetAccessKey(accessKey)
	client.Client.SetSecretKey(secretKey)

	if region := aiConf.Extra["region"]; region != nil {
		client.SetRegion(u.String(region))
	}
	if host := aiConf.Extra["host"]; host != nil {
		client.SetHost(u.String(host))
	}
	return client
}
// MakeImage generates images through the Volcengine (huoshan) visual CV
// service and returns the resulting image URLs.
//
// conf.Model has the form "req_key" or "req_key:model_version". When
// reference images are supplied (conf.Ref), a req_key containing "t2i" is
// switched to its "i2i" counterpart. Reference entries containing "://" are
// sent as image URLs; all other entries are sent as base64 image data. Only
// one of the two forms is sent: if any URL is present, the URL entries are
// sent and base64 entries are not.
//
// Keys in conf.Extra are copied into the request last, so they override any
// field set from the other ImageConfig fields.
//
// It returns the client error if the request fails, or an error carrying the
// service message (or the HTTP status when no message is available) when the
// HTTP status is not 200. UsedTime is the elapsed wall time in milliseconds,
// measured from just before client construction.
func MakeImage(aiConf *ai.AIConfig, conf ai.ImageConfig) (ai.ImageResult, error) {
	reqKey, modelVersion, hasVersion := strings.Cut(conf.Model, ":")
	data := map[string]any{
		"req_key":    reqKey,
		"prompt":     conf.SystemPrompt + conf.Prompt,
		"return_url": true,
	}
	if hasVersion {
		data["model_version"] = modelVersion
	}
	if conf.NegativePrompt != "" {
		data["negative_prompt"] = conf.NegativePrompt
	}
	if conf.Width > 0 {
		data["width"] = conf.Width
	}
	if conf.Height > 0 {
		data["height"] = conf.Height
	}
	if conf.Scale > 0 {
		if conf.Scale < 1 {
			// A value in (0, 1) is treated as a fraction and scaled by 30
			// (intended to map onto the service's 1-30 range).
			data["scale"] = conf.Scale * 30
		} else {
			data["scale"] = conf.Scale
		}
	}
	if conf.Steps > 0 {
		data["ddim_steps"] = conf.Steps
	}

	if len(conf.Ref) > 0 {
		// Reference images imply image-to-image: switch a text-to-image req_key.
		if strings.Contains(reqKey, "t2i") {
			data["req_key"] = strings.ReplaceAll(reqKey, "t2i", "i2i")
		}

		// URL and base64 references are mutually exclusive, so split them
		// and send exactly one group.
		var imageURLs, base64Data []string
		for _, ref := range conf.Ref {
			if strings.Contains(ref, "://") {
				imageURLs = append(imageURLs, ref)
			} else {
				base64Data = append(base64Data, ref)
			}
		}
		if len(imageURLs) > 0 {
			// Send only the URL entries; base64 strings must not end up in image_urls.
			data["image_urls"] = imageURLs
		} else {
			data["binary_data_base64"] = base64Data
		}

		// Reference-image weights.
		if conf.Cref > 0 || conf.Sref > 0 {
			if strings.Contains(conf.Model, "ip_keep") {
				// Face/identity-preserving model: weights are top-level fields.
				if conf.Cref > 0 {
					data["ref_id_weight"] = conf.Cref
				}
				if conf.Sref > 0 {
					data["ref_ip_weight"] = conf.Sref
				}
			} else {
				// Image-to-image model: weights go in style_reference_args,
				// which points at the first reference (index 0).
				styleRefArgs := map[string]any{"binary_data_index": 0}
				if conf.Cref > 0 {
					styleRefArgs["id_weight"] = conf.Cref
				}
				if conf.Sref > 0 {
					styleRefArgs["style_weight"] = conf.Sref
				}
				data["style_reference_args"] = styleRefArgs
			}
		}
	}

	// Any extra parameters are passed through verbatim and take precedence.
	for k, v := range conf.Extra {
		data[k] = v
	}

	start := time.Now()
	client := getVisualClient(aiConf)
	respMap, status, err := client.CVProcess(data)
	if err != nil {
		return ai.ImageResult{}, err
	}
	resp := &model.VisualPubResult{}
	u.Convert(respMap, resp)
	usedTime := time.Since(start).Milliseconds()

	if status != 200 {
		msg := u.String(resp.Message)
		if msg == "" {
			// Avoid returning an error with an empty message.
			msg = "huoshan visual request failed with HTTP status " + u.String(status)
		}
		return ai.ImageResult{}, errors.New(msg)
	}
	return ai.ImageResult{
		Results:  resp.Data.ImageUrls,
		UsedTime: usedTime,
	}, nil
}