242 lines
6.1 KiB
Go
242 lines
6.1 KiB
Go
package huoshan
|
|
|
|
import (
|
|
"errors"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"apigo.cc/ai"
|
|
"github.com/ssgo/u"
|
|
"github.com/volcengine/volc-sdk-golang/service/visual"
|
|
"github.com/volcengine/volc-sdk-golang/service/visual/model"
|
|
)
|
|
|
|
func getVisualClient(aiConf *ai.AIConfig) *visual.Visual {
|
|
ak, sk := getAKSK(aiConf)
|
|
vis := visual.NewInstance()
|
|
vis.Client.SetAccessKey(ak)
|
|
vis.Client.SetSecretKey(sk)
|
|
if aiConf.Extra["region"] != nil {
|
|
vis.SetRegion(u.String(aiConf.Extra["region"]))
|
|
}
|
|
if aiConf.Extra["host"] != nil {
|
|
vis.SetHost(u.String(aiConf.Extra["host"]))
|
|
}
|
|
return vis
|
|
}
|
|
|
|
func MakeImage(aiConf *ai.AIConfig, conf ai.ImageConfig) (ai.ImageResult, error) {
|
|
modelA := strings.SplitN(conf.Model, ":", 2)
|
|
data := map[string]any{
|
|
"req_key": modelA[0],
|
|
"prompt": conf.SystemPrompt + conf.Prompt,
|
|
"return_url": true,
|
|
}
|
|
if len(modelA) > 1 {
|
|
data["model_version"] = modelA[1]
|
|
}
|
|
if conf.NegativePrompt != "" {
|
|
data["negative_prompt"] = conf.NegativePrompt
|
|
}
|
|
if conf.Width > 0 {
|
|
data["width"] = conf.Width
|
|
}
|
|
if conf.Height > 0 {
|
|
data["height"] = conf.Height
|
|
}
|
|
if conf.Scale > 0 {
|
|
if conf.Scale < 1 {
|
|
// 取值 0-1 放大到 1-30
|
|
data["scale"] = conf.Scale * 30
|
|
} else {
|
|
data["scale"] = conf.Scale
|
|
}
|
|
}
|
|
if conf.Steps > 0 {
|
|
data["ddim_steps"] = conf.Steps
|
|
}
|
|
if len(conf.Ref) > 0 {
|
|
// 如果有参考图,自动切换到图生图模型
|
|
if strings.Contains(modelA[0], "t2i") {
|
|
data["req_key"] = strings.ReplaceAll(modelA[0], "t2i", "i2i")
|
|
}
|
|
// 根据参考图类型设置(url和base64只能2选1)
|
|
image_url := make([]string, 0)
|
|
binary_data_base64 := make([]string, 0)
|
|
for _, ref := range conf.Ref {
|
|
if strings.Contains(ref, "://") {
|
|
image_url = append(image_url, ref)
|
|
} else {
|
|
binary_data_base64 = append(binary_data_base64, ref)
|
|
}
|
|
}
|
|
if len(image_url) > 0 {
|
|
data["image_urls"] = conf.Ref
|
|
} else {
|
|
data["binary_data_base64"] = binary_data_base64
|
|
}
|
|
|
|
// 参考图权重设置
|
|
if conf.Cref > 0 || conf.Sref > 0 {
|
|
if strings.Contains(conf.Model, "ip_keep") {
|
|
// 人脸保持模型
|
|
if conf.Cref > 0 {
|
|
data["ref_id_weight"] = conf.Cref
|
|
}
|
|
if conf.Sref > 0 {
|
|
data["ref_ip_weight"] = conf.Sref
|
|
}
|
|
} else if strings.Contains(conf.Model, ":anime_") {
|
|
// 动漫模型
|
|
if conf.Sref > 0 {
|
|
data["strength"] = conf.Sref
|
|
}
|
|
} else {
|
|
// 图生图模型
|
|
style_reference_args := map[string]any{"binary_data_index": 0}
|
|
if conf.Cref > 0 {
|
|
style_reference_args["id_weight"] = conf.Cref
|
|
}
|
|
if conf.Sref > 0 {
|
|
style_reference_args["style_weight"] = conf.Sref
|
|
}
|
|
data["style_reference_args"] = style_reference_args
|
|
}
|
|
}
|
|
}
|
|
|
|
// 其他参数
|
|
for k, v := range conf.Extra {
|
|
data[k] = v
|
|
}
|
|
|
|
// fmt.Println(u.BMagenta(u.JsonP(data)), 111)
|
|
t1 := time.Now().UnixMilli()
|
|
c := getVisualClient(aiConf)
|
|
respMap, status, err := c.CVProcess(data)
|
|
// fmt.Println(u.BCyan(u.JsonP(respMap)), 222)
|
|
resp := &model.VisualPubResult{}
|
|
u.Convert(respMap, resp)
|
|
t2 := time.Now().UnixMilli() - t1
|
|
|
|
if err != nil {
|
|
return ai.ImageResult{}, err
|
|
}
|
|
if status != 200 {
|
|
return ai.ImageResult{}, errors.New(u.String(resp.Message))
|
|
}
|
|
return ai.ImageResult{
|
|
Results: resp.Data.ImageUrls,
|
|
UsedTime: t2,
|
|
}, nil
|
|
}
|
|
|
|
func Edit(aiConf *ai.AIConfig, from string, conf map[string]any) (ai.StringResult, error) {
|
|
action := u.String(conf["action"])
|
|
c := getVisualClient(aiConf)
|
|
t1 := time.Now().UnixMilli()
|
|
switch action {
|
|
case "EmotionPortrait":
|
|
// 修改表情
|
|
req := map[string]any{"req_key": "emotion_portrait", "return_url": true}
|
|
for k, v := range conf {
|
|
if k != "action" {
|
|
req[k] = v
|
|
}
|
|
}
|
|
if strings.Contains(from, "://") {
|
|
req["image_urls"] = []string{from}
|
|
} else {
|
|
req["binary_data_base64"] = []string{from}
|
|
}
|
|
resp, status, err := c.EmotionPortrait(req)
|
|
if err != nil {
|
|
return ai.StringResult{}, err
|
|
}
|
|
if status != 200 {
|
|
return ai.StringResult{}, errors.New(u.String(resp.Message))
|
|
}
|
|
result := resp.Data.ImageUrls[0]
|
|
return ai.StringResult{
|
|
Result: result,
|
|
UsedTime: time.Now().UnixMilli() - t1,
|
|
}, nil
|
|
case "HumanSegment":
|
|
// 抠出人像
|
|
req := url.Values{
|
|
"refine": {u.String(u.Int(conf["refine"]))},
|
|
"return_foreground_image": {u.String(u.Int(conf["return_foreground_image"]))},
|
|
}
|
|
if strings.Contains(from, "://") {
|
|
req["image_url"] = []string{from}
|
|
} else {
|
|
req["image_base64"] = []string{from}
|
|
}
|
|
resp, status, err := c.HumanSegment(req)
|
|
if err != nil {
|
|
return ai.StringResult{}, err
|
|
}
|
|
if status != 200 {
|
|
return ai.StringResult{}, errors.New(u.String(resp.Message))
|
|
}
|
|
// u.WriteFileBytes("mask.jpg", u.UnBase64(resp.Data.Mask))
|
|
// u.WriteFileBytes("body.jpg", u.UnBase64(resp.Data.ForegroundImage))
|
|
result := resp.Data.ForegroundImage
|
|
if result == "" {
|
|
result = resp.Data.Mask
|
|
}
|
|
return ai.StringResult{
|
|
Result: result,
|
|
UsedTime: time.Now().UnixMilli() - t1,
|
|
}, nil
|
|
case "FacePretty":
|
|
// 美颜
|
|
req := url.Values{
|
|
"do_risk": {u.GetUpperName(u.String(u.Bool(conf["do_risk"])))},
|
|
"multi_face": {u.String(u.Int(conf["multi_face"]))},
|
|
"beauty_level": {u.String(u.Float(conf["beauty_level"]))},
|
|
}
|
|
if strings.Contains(from, "://") {
|
|
req["image_url"] = []string{from}
|
|
} else {
|
|
req["image_base64"] = []string{from}
|
|
}
|
|
resp, status, err := c.FacePretty(req)
|
|
if err != nil {
|
|
return ai.StringResult{}, err
|
|
}
|
|
if status != 200 {
|
|
return ai.StringResult{}, errors.New(u.String(resp.Message))
|
|
}
|
|
return ai.StringResult{
|
|
Result: resp.Data.Image,
|
|
UsedTime: time.Now().UnixMilli() - t1,
|
|
}, nil
|
|
case "HairStyle":
|
|
// 改变发型
|
|
req := url.Values{
|
|
"req_key": {"hair_style"},
|
|
"return_url": {"true"},
|
|
"hair_type": {u.String(conf["hair_type"])},
|
|
}
|
|
if strings.Contains(from, "://") {
|
|
req["image_urls"] = []string{from}
|
|
} else {
|
|
req["binary_data_base64"] = []string{from}
|
|
}
|
|
resp, status, err := c.HairStyle(req)
|
|
if err != nil {
|
|
return ai.StringResult{}, err
|
|
}
|
|
if status != 200 {
|
|
return ai.StringResult{}, errors.New(u.String(resp.Message))
|
|
}
|
|
return ai.StringResult{
|
|
Result: resp.Data.Image,
|
|
UsedTime: time.Now().UnixMilli() - t1,
|
|
}, nil
|
|
}
|
|
return ai.StringResult{}, nil
|
|
}
|