430 lines
13 KiB
Go
430 lines
13 KiB
Go
package ai
|
||
|
||
import (
|
||
"reflect"
|
||
"strings"
|
||
"sync"
|
||
|
||
"apigo.cc/gojs"
|
||
"apigo.cc/gojs/goja"
|
||
"github.com/ssgo/u"
|
||
)
|
||
|
||
// Agent is a provider implementation that can be stored with Register.
//
// It bundles per-capability configuration presets with the functions that
// implement each capability. The agentObj JS bindings in this file call these
// function fields unconditionally. Calling a nil func field panics, so a
// binding must not be reachable for a capability whose field is nil.
// NOTE(review): confirm the code that exposes agentObj methods checks this.
type Agent struct {
	// Named configuration presets for each capability.
	// NOTE(review): presumably keyed by model/profile name — confirm against callers.
	ChatConfigs      map[string]*ChatConfig
	EmbeddingConfigs map[string]*EmbeddingConfig
	ImageConfigs     map[string]*ImageConfig
	VideoConfigs     map[string]*VideoConfig
	EditConfigs      map[string]*map[string]any
	ScanConfigs      map[string]*map[string]any
	AsrConfigs       map[string]*AsrConfig
	TtsConfigs       map[string]*TtsConfig

	// Chat sends messages and returns the result. callback is nil when the JS
	// caller passed no function (see getAskArgs); otherwise it is called with
	// answer text (presumably streamed chunks — confirm per implementation).
	Chat func(aiConf *AIConfig, messages []ChatMessage, callback func(string), conf ChatConfig) (ChatResult, error)
	// Embedding computes an embedding for text.
	Embedding func(aiConf *AIConfig, text string, conf EmbeddingConfig) (EmbeddingResult, error)
	// MakeImage generates an image from conf.
	MakeImage func(aiConf *AIConfig, conf ImageConfig) (ImageResult, error)
	// MakeVideo starts video generation and returns a string.
	// NOTE(review): presumably the task id later passed to GetVideoResult — confirm.
	MakeVideo func(aiConf *AIConfig, conf VideoConfig) (string, error)
	// GetVideoResult fetches the result for taskId, with waitSeconds supplied by the JS caller.
	GetVideoResult func(aiConf *AIConfig, taskId string, waitSeconds int) (VideoResult, error)
	// Edit transforms from using the options in conf.
	Edit func(aiConf *AIConfig, from string, conf map[string]any) (StringResult, error)
	// Scan analyses raw image bytes using the options in conf.
	Scan func(aiConf *AIConfig, image []byte, conf map[string]any) (ScanResult, error)
	// Asr performs speech recognition on the resource at url.
	Asr func(aiConf *AIConfig, url string, conf AsrConfig) (ScanResult, error)
	// Tts synthesizes speech for text.
	Tts func(aiConf *AIConfig, text string, conf TtsConfig) (StringResult, error)
}
|
||
|
||
// agentObj is the receiver for the JS-facing binding methods in this file.
//
// It holds three things:
//   - the AIConfig passed as the first argument to every Agent call;
//   - one default config per capability, used to fill options the JS caller
//     left unset (see getAskArgs, getImageArgs, getVideoArgs, getMapArgs);
//   - the Agent that implements the calls.
//
// NOTE(review): the binding methods dereference these pointers without nil
// checks, so they are presumably always populated by the constructor — confirm.
type agentObj struct {
	config          *AIConfig
	chatConfig      *ChatConfig
	embeddingConfig *EmbeddingConfig
	imageConfig     *ImageConfig
	videoConfig     *VideoConfig
	editConfig      *map[string]any
	scanConfig      *map[string]any
	asrConfig       *AsrConfig
	ttsConfig       *TtsConfig
	agent           *Agent
}
|
||
|
||
var agents = map[string]*Agent{}
|
||
var agentsLock = sync.RWMutex{}
|
||
|
||
func Register(aiName string, agent *Agent) {
|
||
if conf := aiList[aiName]; conf == nil {
|
||
agentsLock.Lock()
|
||
defer agentsLock.Unlock()
|
||
agents[aiName] = agent
|
||
}
|
||
}
|
||
|
||
func (ag *agentObj) Chat(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||
args := gojs.MakeArgs(&argsIn, vm).Check(1)
|
||
msgs, conf, cb := ag.getAskArgs(args.This, vm, args.Arguments)
|
||
r, err := ag.agent.Chat(ag.config, msgs, cb, conf)
|
||
if err != nil {
|
||
panic(vm.NewGoError(err))
|
||
}
|
||
return vm.ToValue(gojs.MakeMap(r))
|
||
}
|
||
|
||
func (ag *agentObj) Embedding(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||
args := gojs.MakeArgs(&argsIn, vm).Check(1)
|
||
conf := EmbeddingConfig{}
|
||
if confObj := args.Any(1); confObj != nil {
|
||
u.Convert(confObj, &conf)
|
||
}
|
||
if conf.Model == "" {
|
||
conf.Model = ag.embeddingConfig.Model
|
||
}
|
||
r, err := ag.agent.Embedding(ag.config, args.Str(0), conf)
|
||
if err != nil {
|
||
panic(vm.NewGoError(err))
|
||
}
|
||
return vm.ToValue(gojs.MakeMap(r))
|
||
}
|
||
|
||
func (ag *agentObj) MakeImage(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||
args := gojs.MakeArgs(&argsIn, vm).Check(1)
|
||
conf := ag.getImageArgs(args.Arguments)
|
||
r, err := ag.agent.MakeImage(ag.config, conf)
|
||
if err != nil {
|
||
panic(vm.NewGoError(err))
|
||
}
|
||
return vm.ToValue(gojs.MakeMap(r))
|
||
}
|
||
|
||
func (ag *agentObj) MakeVideo(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||
args := gojs.MakeArgs(&argsIn, vm).Check(1)
|
||
conf := ag.getVideoArgs(args.Arguments)
|
||
r, err := ag.agent.MakeVideo(ag.config, conf)
|
||
if err != nil {
|
||
panic(vm.NewGoError(err))
|
||
}
|
||
return vm.ToValue(r)
|
||
}
|
||
|
||
func (ag *agentObj) GetVideoResult(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||
args := gojs.MakeArgs(&argsIn, vm).Check(1)
|
||
r, err := ag.agent.GetVideoResult(ag.config, args.Str(0), args.Int(1))
|
||
if err != nil {
|
||
panic(vm.NewGoError(err))
|
||
}
|
||
return vm.ToValue(gojs.MakeMap(r))
|
||
}
|
||
|
||
func (ag *agentObj) Edit(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||
args := gojs.MakeArgs(&argsIn, vm).Check(1)
|
||
from := args.Str(0)
|
||
conf := ag.getMapArgs(args.Map(1), *ag.editConfig)
|
||
r, err := ag.agent.Edit(ag.config, from, conf)
|
||
if err != nil {
|
||
panic(vm.NewGoError(err))
|
||
}
|
||
return vm.ToValue(gojs.MakeMap(r))
|
||
}
|
||
|
||
func (ag *agentObj) Scan(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||
args := gojs.MakeArgs(&argsIn, vm).Check(1)
|
||
image := args.Bytes(0)
|
||
conf := ag.getMapArgs(args.Map(1), *ag.scanConfig)
|
||
r, err := ag.agent.Scan(ag.config, image, conf)
|
||
if err != nil {
|
||
panic(vm.NewGoError(err))
|
||
}
|
||
return vm.ToValue(gojs.MakeMap(r))
|
||
}
|
||
|
||
func (ag *agentObj) Asr(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||
args := gojs.MakeArgs(&argsIn, vm).Check(1)
|
||
url := args.Str(0)
|
||
conf := AsrConfig{}
|
||
u.Convert(args.Map(1), &conf)
|
||
// if conf.Model == "" {
|
||
// conf.Model = ag.asrConfig.Model
|
||
// }
|
||
// Uid string
|
||
// Format string
|
||
// Codec string
|
||
// Rate int
|
||
// Bits int
|
||
// Channel int
|
||
// Language string
|
||
// Itn bool
|
||
// Punc bool
|
||
// Ddc bool
|
||
// Extra map[string]any
|
||
conf.Extra = ag.getMapArgs(conf.Extra, ag.asrConfig.Extra)
|
||
r, err := ag.agent.Asr(ag.config, url, conf)
|
||
if err != nil {
|
||
panic(vm.NewGoError(err))
|
||
}
|
||
return vm.ToValue(gojs.MakeMap(r))
|
||
}
|
||
|
||
func (ag *agentObj) Tts(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value {
|
||
args := gojs.MakeArgs(&argsIn, vm).Check(1)
|
||
text := args.Str(0)
|
||
conf := TtsConfig{}
|
||
u.Convert(args.Map(1), &conf)
|
||
conf.Extra = ag.getMapArgs(conf.Extra, ag.asrConfig.Extra)
|
||
r, err := ag.agent.Tts(ag.config, text, conf)
|
||
if err != nil {
|
||
panic(vm.NewGoError(err))
|
||
}
|
||
return vm.ToValue(gojs.MakeMap(r))
|
||
}
|
||
|
||
func (ag *agentObj) getAskArgs(thisArg goja.Value, vm *goja.Runtime, args []goja.Value) ([]ChatMessage, ChatConfig, func(string)) {
|
||
conf := ChatConfig{}
|
||
var callback func(answer string)
|
||
if len(args) > 0 {
|
||
for i := 1; i < len(args); i++ {
|
||
if cb, ok := goja.AssertFunction(args[i]); ok {
|
||
callback = func(answer string) {
|
||
_, _ = cb(thisArg, vm.ToValue(answer))
|
||
}
|
||
} else if args[i].ExportType() != nil {
|
||
switch args[i].ExportType().Kind() {
|
||
case reflect.Map, reflect.Struct:
|
||
u.Convert(args[i].Export(), &conf)
|
||
default:
|
||
conf.Model = u.String(args[i].Export())
|
||
}
|
||
}
|
||
}
|
||
}
|
||
if conf.Model == "" {
|
||
conf.Model = ag.chatConfig.Model
|
||
}
|
||
if conf.SystemPrompt == "" {
|
||
conf.SystemPrompt = ag.chatConfig.SystemPrompt
|
||
}
|
||
if conf.MaxTokens == 0 {
|
||
conf.MaxTokens = ag.chatConfig.MaxTokens
|
||
}
|
||
if conf.Temperature == 0 {
|
||
conf.Temperature = ag.chatConfig.Temperature
|
||
}
|
||
if conf.TopP == 0 {
|
||
conf.TopP = ag.chatConfig.TopP
|
||
}
|
||
if conf.Tools == nil {
|
||
conf.Tools = ag.chatConfig.Tools
|
||
}
|
||
return makeChatMessages(args), conf, callback
|
||
}
|
||
|
||
func (ag *agentObj) getImageArgs(args []goja.Value) ImageConfig {
|
||
conf := ImageConfig{}
|
||
u.Convert(args[0].Export(), &conf)
|
||
if conf.Model == "" {
|
||
conf.Model = ag.imageConfig.Model
|
||
}
|
||
if conf.SystemPrompt == "" {
|
||
conf.SystemPrompt = ag.imageConfig.SystemPrompt
|
||
}
|
||
if conf.NegativePrompt == "" {
|
||
conf.NegativePrompt = ag.imageConfig.NegativePrompt
|
||
}
|
||
if conf.Cref == 0 {
|
||
conf.Cref = ag.imageConfig.Cref
|
||
}
|
||
if conf.Sref == 0 {
|
||
conf.Sref = ag.imageConfig.Sref
|
||
}
|
||
if conf.Scale == 0 {
|
||
conf.Scale = ag.imageConfig.Scale
|
||
}
|
||
if conf.Steps == 0 {
|
||
conf.Steps = ag.imageConfig.Steps
|
||
}
|
||
if conf.Width == 0 {
|
||
conf.Width = ag.imageConfig.Width
|
||
}
|
||
if conf.Height == 0 {
|
||
conf.Height = ag.imageConfig.Height
|
||
}
|
||
if ag.imageConfig.Extra != nil {
|
||
extra := make(map[string]any)
|
||
for k, v := range ag.imageConfig.Extra {
|
||
extra[k] = v
|
||
}
|
||
if conf.Extra != nil {
|
||
for k, v := range conf.Extra {
|
||
extra[k] = v
|
||
}
|
||
}
|
||
conf.Extra = extra
|
||
}
|
||
return conf
|
||
}
|
||
|
||
func (ag *agentObj) getVideoArgs(args []goja.Value) VideoConfig {
|
||
conf := VideoConfig{}
|
||
u.Convert(args[0].Export(), &conf)
|
||
if conf.Model == "" {
|
||
conf.Model = ag.videoConfig.Model
|
||
}
|
||
if conf.SystemPrompt == "" {
|
||
conf.SystemPrompt = ag.videoConfig.SystemPrompt
|
||
}
|
||
if conf.NegativePrompt == "" {
|
||
conf.NegativePrompt = ag.videoConfig.NegativePrompt
|
||
}
|
||
if conf.Width == 0 {
|
||
conf.Width = ag.videoConfig.Width
|
||
}
|
||
if conf.Height == 0 {
|
||
conf.Height = ag.videoConfig.Height
|
||
}
|
||
if ag.videoConfig.Extra != nil {
|
||
extra := make(map[string]any)
|
||
for k, v := range ag.videoConfig.Extra {
|
||
extra[k] = v
|
||
}
|
||
if conf.Extra != nil {
|
||
for k, v := range conf.Extra {
|
||
extra[k] = v
|
||
}
|
||
}
|
||
conf.Extra = extra
|
||
}
|
||
return conf
|
||
}
|
||
|
||
func (ag *agentObj) getMapArgs(setConf map[string]any, defaultConf map[string]any) map[string]any {
|
||
conf := map[string]any{}
|
||
if defaultConf != nil {
|
||
for k, v := range defaultConf {
|
||
conf[k] = v
|
||
}
|
||
}
|
||
if setConf != nil {
|
||
for k, v := range setConf {
|
||
conf[k] = v
|
||
}
|
||
}
|
||
return conf
|
||
}
|
||
|
||
func makeChatMessages(args []goja.Value) []ChatMessage {
|
||
out := make([]ChatMessage, 0)
|
||
if len(args) > 0 {
|
||
v := args[0].Export()
|
||
vv := reflect.ValueOf(v)
|
||
t := args[0].ExportType()
|
||
if t != nil {
|
||
lastRoleIsUser := false
|
||
switch t.Kind() {
|
||
// 数组,根据成员类型处理
|
||
// 字符串:
|
||
// 含有媒体:单条多模态消息
|
||
// 无媒体:多条文本消息
|
||
// 数组:多条消息(第一个成员不是 role 则自动生成)
|
||
// 对象:多条消息(无 role 则自动生成)(支持 content 或 contents)
|
||
// 结构:转换为 ChatMessage
|
||
// 对象:单条消息(支持 content 或 contents)
|
||
// 结构:转换为 ChatMessage
|
||
// 字符串:单条文本消息
|
||
case reflect.Slice:
|
||
hasSub := false
|
||
hasMulti := false
|
||
for i := 0; i < vv.Len(); i++ {
|
||
vv2 := u.FinalValue(vv.Index(i))
|
||
if vv2.Kind() == reflect.Slice || vv2.Kind() == reflect.Map || vv2.Kind() == reflect.Struct {
|
||
hasSub = true
|
||
break
|
||
}
|
||
if vv2.Kind() == reflect.String {
|
||
str := vv2.String()
|
||
if strings.HasPrefix(str, "data:") || strings.HasPrefix(str, "https://") || strings.HasPrefix(str, "http://") {
|
||
hasMulti = true
|
||
}
|
||
}
|
||
}
|
||
|
||
if hasSub || !hasMulti {
|
||
// 有子对象或纯文本数组
|
||
var defaultRole string
|
||
for i := 0; i < vv.Len(); i++ {
|
||
lastRoleIsUser = !lastRoleIsUser
|
||
if lastRoleIsUser {
|
||
defaultRole = RoleUser
|
||
} else {
|
||
defaultRole = RoleAssistant
|
||
}
|
||
vv2 := u.FinalValue(vv.Index(i))
|
||
switch vv2.Kind() {
|
||
case reflect.Slice:
|
||
out = append(out, makeChatMessageFromSlice(vv2, defaultRole))
|
||
case reflect.Map:
|
||
out = append(out, makeChatMessageFromMap(vv2, defaultRole))
|
||
case reflect.Struct:
|
||
item := ChatMessage{}
|
||
u.Convert(vv2.Interface(), &item)
|
||
out = append(out, item)
|
||
default:
|
||
out = append(out, ChatMessage{Role: RoleUser, Contents: []ChatMessageContent{makeChatMessageContent(u.String(vv2.Interface()))}})
|
||
}
|
||
lastRoleIsUser = out[len(out)-1].Role != RoleUser
|
||
}
|
||
} else {
|
||
// 单条多模态消息
|
||
out = append(out, makeChatMessageFromSlice(vv, RoleUser))
|
||
}
|
||
case reflect.Map:
|
||
out = append(out, makeChatMessageFromMap(vv, RoleUser))
|
||
case reflect.Struct:
|
||
item := ChatMessage{}
|
||
u.Convert(v, &item)
|
||
out = append(out, item)
|
||
default:
|
||
out = append(out, ChatMessage{Role: RoleUser, Contents: []ChatMessageContent{makeChatMessageContent(u.String(v))}})
|
||
}
|
||
}
|
||
}
|
||
return out
|
||
}
|
||
|
||
func makeChatMessageFromSlice(vv reflect.Value, defaultRole string) ChatMessage {
|
||
role := u.String(vv.Index(0).Interface())
|
||
j := 0
|
||
if role == RoleUser || role == RoleAssistant || role == RoleSystem || role == RoleTool {
|
||
j = 1
|
||
} else {
|
||
role = defaultRole
|
||
}
|
||
contents := make([]ChatMessageContent, 0)
|
||
for ; j < vv.Len(); j++ {
|
||
contents = append(contents, makeChatMessageContent(u.String(vv.Index(j).Interface())))
|
||
}
|
||
return ChatMessage{Role: role, Contents: contents}
|
||
}
|
||
|
||
func makeChatMessageFromMap(vv reflect.Value, defaultRole string) ChatMessage {
|
||
role := defaultRole
|
||
roleValue := vv.MapIndex(reflect.ValueOf("role"))
|
||
if roleValue.IsValid() {
|
||
role = u.String(roleValue.Interface())
|
||
}
|
||
contents := make([]ChatMessageContent, 0)
|
||
content := u.String(vv.MapIndex(reflect.ValueOf("content")).Interface())
|
||
if content != "" {
|
||
contents = append(contents, makeChatMessageContent(content))
|
||
} else {
|
||
contentsV := vv.MapIndex(reflect.ValueOf("contents"))
|
||
if contentsV.IsValid() && contentsV.Kind() == reflect.Slice {
|
||
for i := 0; i < contentsV.Len(); i++ {
|
||
contents = append(contents, makeChatMessageContent(u.String(contentsV.Index(i).Interface())))
|
||
}
|
||
}
|
||
}
|
||
return ChatMessage{Role: role, Contents: contents}
|
||
}
|
||
|
||
func makeChatMessageContent(contnet string) ChatMessageContent {
|
||
if strings.HasPrefix(contnet, "data:image/") || ((strings.HasPrefix(contnet, "https://") || strings.HasPrefix(contnet, "http://")) && (strings.HasSuffix(contnet, ".png") || strings.HasSuffix(contnet, ".jpg") || strings.HasSuffix(contnet, ".jpeg") || strings.HasSuffix(contnet, ".gif") || strings.HasSuffix(contnet, ".svg"))) {
|
||
return ChatMessageContent{Type: TypeImage, Content: contnet}
|
||
} else if strings.HasPrefix(contnet, "data:video/") || ((strings.HasPrefix(contnet, "https://") || strings.HasPrefix(contnet, "http://")) && (strings.HasSuffix(contnet, ".mp4") || strings.HasSuffix(contnet, ".mov") || strings.HasSuffix(contnet, ".m4v") || strings.HasSuffix(contnet, ".avi") || strings.HasSuffix(contnet, ".wmv"))) {
|
||
return ChatMessageContent{Type: TypeVideo, Content: contnet}
|
||
}
|
||
return ChatMessageContent{Type: TypeText, Content: contnet}
|
||
}
|