ai/agent.go

430 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package ai
import (
"reflect"
"strings"
"sync"
"apigo.cc/gojs"
"apigo.cc/gojs/goja"
"github.com/ssgo/u"
)
// Agent bundles what one AI backend contributes: a set of per-capability
// configuration tables and the function hooks that implement each capability.
// The agentObj methods of the same names (Chat, Embedding, MakeImage, ...)
// forward to these hooks, passing their own AIConfig as the first argument.
// NOTE(review): the configuration tables are keyed by string; presumably a
// configuration name — confirm against the code that fills them.
type Agent struct {
// Configuration tables, one per capability.
ChatConfigs map[string]*ChatConfig
EmbeddingConfigs map[string]*EmbeddingConfig
ImageConfigs map[string]*ImageConfig
VideoConfigs map[string]*VideoConfig
// Edit and Scan use free-form option maps instead of typed config structs.
EditConfigs map[string]*map[string]any
ScanConfigs map[string]*map[string]any
AsrConfigs map[string]*AsrConfig
TtsConfigs map[string]*TtsConfig
// Capability hooks. callback in Chat receives string pieces of the answer;
// agentObj.Chat wires it to a JS function when one is supplied.
Chat func(aiConf *AIConfig, messages []ChatMessage, callback func(string), conf ChatConfig) (ChatResult, error)
Embedding func(aiConf *AIConfig, text string, conf EmbeddingConfig) (EmbeddingResult, error)
MakeImage func(aiConf *AIConfig, conf ImageConfig) (ImageResult, error)
// MakeVideo returns a string that agentObj.MakeVideo hands back to JS as-is.
// NOTE(review): presumably the task id later given to GetVideoResult — confirm.
MakeVideo func(aiConf *AIConfig, conf VideoConfig) (string, error)
GetVideoResult func(aiConf *AIConfig, taskId string, waitSeconds int) (VideoResult, error)
Edit func(aiConf *AIConfig, from string, conf map[string]any) (StringResult, error)
Scan func(aiConf *AIConfig, image []byte, conf map[string]any) (ScanResult, error)
// Asr is declared to return a ScanResult, the same result type as Scan.
Asr func(aiConf *AIConfig, url string, conf AsrConfig) (ScanResult, error)
Tts func(aiConf *AIConfig, text string, conf TtsConfig) (StringResult, error)
}
// agentObj is the receiver of the methods that take a goja.FunctionCall and a
// *goja.Runtime (Chat, Embedding, MakeImage, ...). It carries the AIConfig that
// is passed to every hook, the default configuration for each capability, and
// the Agent whose hooks do the actual work.
type agentObj struct {
// config is passed as the first argument to every Agent hook.
config *AIConfig
// Per-capability defaults; the methods dereference them without nil checks.
chatConfig *ChatConfig
embeddingConfig *EmbeddingConfig
imageConfig *ImageConfig
videoConfig *VideoConfig
editConfig *map[string]any
scanConfig *map[string]any
asrConfig *AsrConfig
ttsConfig *TtsConfig
// agent supplies the hook functions the methods call.
agent *Agent
}
// agents maps an AI name to the Agent stored for it by Register.
var agents = map[string]*Agent{}
// agentsLock is held (write-locked) by Register while it updates agents.
var agentsLock = sync.RWMutex{}
// Register stores agent in the package-level agents table under aiName.
//
// Nothing is stored when aiList already holds a non-nil entry for aiName.
// The agents table is updated while holding agentsLock; the aiList lookup
// happens before the lock is taken.
// NOTE(review): aiList is declared outside this section — confirm what a
// non-nil entry there represents and whether it needs its own locking.
func Register(aiName string, agent *Agent) {
	if aiList[aiName] != nil {
		return
	}

	agentsLock.Lock()
	defer agentsLock.Unlock()
	agents[aiName] = agent
}
// Chat is the script-facing wrapper around the agent's Chat hook.
//
// The call arguments are decoded by getAskArgs into messages, a ChatConfig and
// an optional answer callback. On success the hook's result is converted with
// gojs.MakeMap and returned as a JS value; on failure the error is raised into
// the runtime by panicking with vm.NewGoError.
func (ag *agentObj) Chat(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value {
	call := gojs.MakeArgs(&argsIn, vm).Check(1)
	messages, chatConf, onAnswer := ag.getAskArgs(call.This, vm, call.Arguments)

	result, err := ag.agent.Chat(ag.config, messages, onAnswer, chatConf)
	if err == nil {
		return vm.ToValue(gojs.MakeMap(result))
	}
	panic(vm.NewGoError(err))
}
// Embedding is the script-facing wrapper around the agent's Embedding hook.
//
// Argument 0 is the text to embed. Argument 1, when not nil, is converted into
// an EmbeddingConfig; an empty Model falls back to the default embedding
// config's Model. Errors from the hook are raised into the runtime by
// panicking with vm.NewGoError.
func (ag *agentObj) Embedding(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value {
	call := gojs.MakeArgs(&argsIn, vm).Check(1)

	var embedConf EmbeddingConfig
	if override := call.Any(1); override != nil {
		u.Convert(override, &embedConf)
	}
	if embedConf.Model == "" {
		embedConf.Model = ag.embeddingConfig.Model
	}

	result, err := ag.agent.Embedding(ag.config, call.Str(0), embedConf)
	if err == nil {
		return vm.ToValue(gojs.MakeMap(result))
	}
	panic(vm.NewGoError(err))
}
// MakeImage is the script-facing wrapper around the agent's MakeImage hook.
//
// The ImageConfig is built by getImageArgs from the call arguments (with
// defaults filled in). The hook's result is returned via gojs.MakeMap; an
// error is raised into the runtime by panicking with vm.NewGoError.
func (ag *agentObj) MakeImage(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value {
	call := gojs.MakeArgs(&argsIn, vm).Check(1)
	imageConf := ag.getImageArgs(call.Arguments)

	result, err := ag.agent.MakeImage(ag.config, imageConf)
	if err == nil {
		return vm.ToValue(gojs.MakeMap(result))
	}
	panic(vm.NewGoError(err))
}
// MakeVideo is the script-facing wrapper around the agent's MakeVideo hook.
//
// The VideoConfig is built by getVideoArgs from the call arguments (with
// defaults filled in). The string returned by the hook is handed back to the
// script unchanged; an error is raised into the runtime by panicking with
// vm.NewGoError.
func (ag *agentObj) MakeVideo(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value {
	call := gojs.MakeArgs(&argsIn, vm).Check(1)
	videoConf := ag.getVideoArgs(call.Arguments)

	result, err := ag.agent.MakeVideo(ag.config, videoConf)
	if err == nil {
		return vm.ToValue(result)
	}
	panic(vm.NewGoError(err))
}
// GetVideoResult is the script-facing wrapper around the agent's
// GetVideoResult hook.
//
// Argument 0 is passed to the hook as the task id and argument 1 as the number
// of seconds to wait. The result is returned via gojs.MakeMap; an error is
// raised into the runtime by panicking with vm.NewGoError.
func (ag *agentObj) GetVideoResult(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value {
	call := gojs.MakeArgs(&argsIn, vm).Check(1)

	result, err := ag.agent.GetVideoResult(ag.config, call.Str(0), call.Int(1))
	if err == nil {
		return vm.ToValue(gojs.MakeMap(result))
	}
	panic(vm.NewGoError(err))
}
// Edit is the script-facing wrapper around the agent's Edit hook.
//
// Argument 0 is passed to the hook as its "from" string. Argument 1 is an
// option map that is layered on top of the default edit config by getMapArgs.
// The result is returned via gojs.MakeMap; an error is raised into the runtime
// by panicking with vm.NewGoError.
func (ag *agentObj) Edit(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value {
	call := gojs.MakeArgs(&argsIn, vm).Check(1)
	source := call.Str(0)
	options := ag.getMapArgs(call.Map(1), *ag.editConfig)

	result, err := ag.agent.Edit(ag.config, source, options)
	if err == nil {
		return vm.ToValue(gojs.MakeMap(result))
	}
	panic(vm.NewGoError(err))
}
// Scan is the script-facing wrapper around the agent's Scan hook.
//
// Argument 0 is read as the image bytes. Argument 1 is an option map that is
// layered on top of the default scan config by getMapArgs. The result is
// returned via gojs.MakeMap; an error is raised into the runtime by panicking
// with vm.NewGoError.
func (ag *agentObj) Scan(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value {
	call := gojs.MakeArgs(&argsIn, vm).Check(1)
	imageBytes := call.Bytes(0)
	options := ag.getMapArgs(call.Map(1), *ag.scanConfig)

	result, err := ag.agent.Scan(ag.config, imageBytes, options)
	if err == nil {
		return vm.ToValue(gojs.MakeMap(result))
	}
	panic(vm.NewGoError(err))
}
// Asr is the script-facing wrapper around the agent's Asr hook.
//
// Argument 0 is passed to the hook as the URL and argument 1 is converted into
// an AsrConfig. Only the Extra map is merged with the default ASR config
// (caller entries win); every other AsrConfig field is forwarded exactly as
// the caller supplied it, with no default applied. The result is returned via
// gojs.MakeMap; an error is raised into the runtime by panicking with
// vm.NewGoError.
func (ag *agentObj) Asr(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value {
	call := gojs.MakeArgs(&argsIn, vm).Check(1)
	audioURL := call.Str(0)

	var asrConf AsrConfig
	u.Convert(call.Map(1), &asrConf)
	asrConf.Extra = ag.getMapArgs(asrConf.Extra, ag.asrConfig.Extra)

	result, err := ag.agent.Asr(ag.config, audioURL, asrConf)
	if err == nil {
		return vm.ToValue(gojs.MakeMap(result))
	}
	panic(vm.NewGoError(err))
}
// Tts is the script-facing wrapper around the agent's Tts hook.
//
// Argument 0 is the text to synthesize and argument 1 is converted into a
// TtsConfig. The Extra map is merged with the default TTS config's Extra
// (caller entries win); every other field is forwarded as supplied. The
// result is returned via gojs.MakeMap; an error is raised into the runtime by
// panicking with vm.NewGoError.
func (ag *agentObj) Tts(argsIn goja.FunctionCall, vm *goja.Runtime) goja.Value {
	args := gojs.MakeArgs(&argsIn, vm).Check(1)
	text := args.Str(0)
	conf := TtsConfig{}
	u.Convert(args.Map(1), &conf)
	// Defaults must come from the TTS config; the previous code merged the ASR
	// config's Extra here, leaking speech-recognition options into synthesis
	// requests and ignoring the configured TTS defaults.
	conf.Extra = ag.getMapArgs(conf.Extra, ag.ttsConfig.Extra)
	r, err := ag.agent.Tts(ag.config, text, conf)
	if err != nil {
		panic(vm.NewGoError(err))
	}
	return vm.ToValue(gojs.MakeMap(r))
}
// getAskArgs decodes the arguments of a chat call.
//
// args[0] is turned into the message list by makeChatMessages. Every later
// argument is interpreted by kind:
//   - a JS function becomes the answer callback (invoked with thisArg as its
//     receiver; its return value and error are ignored);
//   - a map or struct is converted into the ChatConfig;
//   - any other exportable value is stringified and used as the model name.
//
// Fields still at their zero value afterwards (Model, SystemPrompt, MaxTokens,
// Temperature, TopP, Tools) are taken from the default chat config, so a
// caller cannot explicitly request a zero value for them.
func (ag *agentObj) getAskArgs(thisArg goja.Value, vm *goja.Runtime, args []goja.Value) ([]ChatMessage, ChatConfig, func(string)) {
	var (
		conf     ChatConfig
		callback func(answer string)
	)

	// Starting at index 1 also covers the empty-argument case.
	for i := 1; i < len(args); i++ {
		arg := args[i]
		if fn, isFunc := goja.AssertFunction(arg); isFunc {
			callback = func(answer string) {
				_, _ = fn(thisArg, vm.ToValue(answer))
			}
			continue
		}
		argType := arg.ExportType()
		if argType == nil {
			continue
		}
		if kind := argType.Kind(); kind == reflect.Map || kind == reflect.Struct {
			u.Convert(arg.Export(), &conf)
		} else {
			conf.Model = u.String(arg.Export())
		}
	}

	defaults := ag.chatConfig
	if conf.Model == "" {
		conf.Model = defaults.Model
	}
	if conf.SystemPrompt == "" {
		conf.SystemPrompt = defaults.SystemPrompt
	}
	if conf.MaxTokens == 0 {
		conf.MaxTokens = defaults.MaxTokens
	}
	if conf.Temperature == 0 {
		conf.Temperature = defaults.Temperature
	}
	if conf.TopP == 0 {
		conf.TopP = defaults.TopP
	}
	if conf.Tools == nil {
		conf.Tools = defaults.Tools
	}

	return makeChatMessages(args), conf, callback
}
// getImageArgs builds the ImageConfig for an image request.
//
// args[0] is exported and converted into an ImageConfig; the caller must pass
// at least one argument. Fields left at their zero value (Model, SystemPrompt,
// NegativePrompt, Cref, Sref, Scale, Steps, Width, Height) are filled from the
// default image config. When the default config has a non-nil Extra map, the
// result's Extra is a fresh map holding the defaults overlaid with the
// caller's entries; otherwise the caller's Extra is kept untouched.
func (ag *agentObj) getImageArgs(args []goja.Value) ImageConfig {
	var conf ImageConfig
	u.Convert(args[0].Export(), &conf)

	defaults := ag.imageConfig
	if conf.Model == "" {
		conf.Model = defaults.Model
	}
	if conf.SystemPrompt == "" {
		conf.SystemPrompt = defaults.SystemPrompt
	}
	if conf.NegativePrompt == "" {
		conf.NegativePrompt = defaults.NegativePrompt
	}
	if conf.Cref == 0 {
		conf.Cref = defaults.Cref
	}
	if conf.Sref == 0 {
		conf.Sref = defaults.Sref
	}
	if conf.Scale == 0 {
		conf.Scale = defaults.Scale
	}
	if conf.Steps == 0 {
		conf.Steps = defaults.Steps
	}
	if conf.Width == 0 {
		conf.Width = defaults.Width
	}
	if conf.Height == 0 {
		conf.Height = defaults.Height
	}

	if defaults.Extra == nil {
		return conf
	}

	// Defaults first, then caller entries so the caller wins on key clashes.
	merged := make(map[string]any)
	for key, val := range defaults.Extra {
		merged[key] = val
	}
	for key, val := range conf.Extra {
		merged[key] = val
	}
	conf.Extra = merged
	return conf
}
// getVideoArgs builds the VideoConfig for a video request.
//
// args[0] is exported and converted into a VideoConfig; the caller must pass
// at least one argument. Fields left at their zero value (Model, SystemPrompt,
// NegativePrompt, Width, Height) are filled from the default video config.
// When the default config has a non-nil Extra map, the result's Extra is a
// fresh map holding the defaults overlaid with the caller's entries;
// otherwise the caller's Extra is kept untouched.
func (ag *agentObj) getVideoArgs(args []goja.Value) VideoConfig {
	var conf VideoConfig
	u.Convert(args[0].Export(), &conf)

	defaults := ag.videoConfig
	if conf.Model == "" {
		conf.Model = defaults.Model
	}
	if conf.SystemPrompt == "" {
		conf.SystemPrompt = defaults.SystemPrompt
	}
	if conf.NegativePrompt == "" {
		conf.NegativePrompt = defaults.NegativePrompt
	}
	if conf.Width == 0 {
		conf.Width = defaults.Width
	}
	if conf.Height == 0 {
		conf.Height = defaults.Height
	}

	if defaults.Extra == nil {
		return conf
	}

	// Defaults first, then caller entries so the caller wins on key clashes.
	merged := make(map[string]any)
	for key, val := range defaults.Extra {
		merged[key] = val
	}
	for key, val := range conf.Extra {
		merged[key] = val
	}
	conf.Extra = merged
	return conf
}
// getMapArgs returns a new map containing every entry of defaultConf overlaid
// with every entry of setConf, so setConf wins when both define a key.
// Either argument may be nil. The result is never nil and shares no storage
// with the inputs (values themselves are copied shallowly).
func (ag *agentObj) getMapArgs(setConf map[string]any, defaultConf map[string]any) map[string]any {
	merged := make(map[string]any, len(defaultConf)+len(setConf))
	// Ranging over a nil map is a no-op, so no nil checks are needed.
	for _, layer := range []map[string]any{defaultConf, setConf} {
		for key, val := range layer {
			merged[key] = val
		}
	}
	return merged
}
// makeChatMessages converts the first call argument into a list of chat
// messages. It returns an empty (non-nil) slice when there is no argument or
// the argument has no export type.
//
// Accepted shapes of args[0]:
//
//	array, handled according to its elements:
//	  strings only:
//	    at least one media reference (data:, http://, https:// prefix):
//	      the whole array is ONE multimodal user message
//	    no media: one user text message per string
//	  array element:  one message (role auto-generated when the first member
//	                  is not a role name)
//	  object element: one message (role auto-generated when absent; supports
//	                  "content" or "contents")
//	  struct element: converted to ChatMessage
//	object: a single message (supports "content" or "contents")
//	struct: converted to ChatMessage
//	anything else: stringified into a single user message
//
// Auto-generated roles alternate with the conversation: the default for an
// entry is the assistant role when the previous message was a user message
// and the user role otherwise (including for the first entry).
func makeChatMessages(args []goja.Value) []ChatMessage {
	out := make([]ChatMessage, 0)
	if len(args) == 0 {
		return out
	}
	v := args[0].Export()
	t := args[0].ExportType()
	if t == nil {
		return out
	}
	vv := reflect.ValueOf(v)

	switch t.Kind() {
	case reflect.Slice:
		// First pass: decide between "list of messages" and "one multimodal
		// message" by looking at the element kinds.
		hasSub := false
		hasMulti := false
		for i := 0; i < vv.Len(); i++ {
			item := u.FinalValue(vv.Index(i))
			if k := item.Kind(); k == reflect.Slice || k == reflect.Map || k == reflect.Struct {
				hasSub = true
				break
			}
			if item.Kind() == reflect.String {
				str := item.String()
				if strings.HasPrefix(str, "data:") || strings.HasPrefix(str, "https://") || strings.HasPrefix(str, "http://") {
					hasMulti = true
				}
			}
		}

		if !hasSub && hasMulti {
			// Flat array containing media: a single multimodal message.
			out = append(out, makeChatMessageFromSlice(vv, RoleUser))
			return out
		}

		// Nested entries or a pure-text array: one message per entry.
		nextRole := RoleUser
		for i := 0; i < vv.Len(); i++ {
			item := u.FinalValue(vv.Index(i))
			if !item.IsValid() {
				// A null entry carries no message; calling Interface() on
				// an invalid Value would panic, so skip it.
				continue
			}
			var msg ChatMessage
			switch item.Kind() {
			case reflect.Slice:
				msg = makeChatMessageFromSlice(item, nextRole)
			case reflect.Map:
				msg = makeChatMessageFromMap(item, nextRole)
			case reflect.Struct:
				u.Convert(item.Interface(), &msg)
			default:
				msg = ChatMessage{Role: RoleUser, Contents: []ChatMessageContent{makeChatMessageContent(u.String(item.Interface()))}}
			}
			out = append(out, msg)
			// Track the actual role just emitted so the next default
			// alternates. The previous code compared with != and then
			// toggled again, which made role-less entries repeat the
			// preceding role instead of alternating.
			if msg.Role == RoleUser {
				nextRole = RoleAssistant
			} else {
				nextRole = RoleUser
			}
		}
	case reflect.Map:
		out = append(out, makeChatMessageFromMap(vv, RoleUser))
	case reflect.Struct:
		item := ChatMessage{}
		u.Convert(v, &item)
		out = append(out, item)
	default:
		out = append(out, ChatMessage{Role: RoleUser, Contents: []ChatMessageContent{makeChatMessageContent(u.String(v))}})
	}
	return out
}
// makeChatMessageFromSlice builds one ChatMessage from a slice value.
//
// If the first element, stringified, is one of the known role names
// (RoleUser, RoleAssistant, RoleSystem, RoleTool) it is used as the role and
// the remaining elements become the contents; otherwise defaultRole is used
// and every element becomes a content item. Each content item is classified
// by makeChatMessageContent.
//
// An empty slice yields a message with defaultRole and no contents instead of
// panicking on the out-of-range first-element access.
func makeChatMessageFromSlice(vv reflect.Value, defaultRole string) ChatMessage {
	n := vv.Len()
	if n == 0 {
		return ChatMessage{Role: defaultRole, Contents: make([]ChatMessageContent, 0)}
	}

	role := u.String(vv.Index(0).Interface())
	start := 0
	if role == RoleUser || role == RoleAssistant || role == RoleSystem || role == RoleTool {
		// The leading element was the role marker, not content.
		start = 1
	} else {
		role = defaultRole
	}

	contents := make([]ChatMessageContent, 0, n-start)
	for j := start; j < n; j++ {
		contents = append(contents, makeChatMessageContent(u.String(vv.Index(j).Interface())))
	}
	return ChatMessage{Role: role, Contents: contents}
}
// makeChatMessageFromMap builds one ChatMessage from a string-keyed map value.
//
// Recognized keys:
//   - "role":     the message role; defaultRole is used when the key is
//     absent or stringifies to "".
//   - "content":  a single content item, used when it stringifies to a
//     non-empty value.
//   - "contents": a list of content items, used only when "content" is absent
//     or empty.
//
// Each content item is classified by makeChatMessageContent. A map with
// neither key yields a message with no contents.
func makeChatMessageFromMap(vv reflect.Value, defaultRole string) ChatMessage {
	role := defaultRole
	if roleValue := vv.MapIndex(reflect.ValueOf("role")); roleValue.IsValid() {
		if r := u.String(roleValue.Interface()); r != "" {
			role = r
		}
	}

	contents := make([]ChatMessageContent, 0)

	// MapIndex returns the zero Value for a missing key, and calling
	// Interface() on it panics, so the lookup has to be checked first.
	content := ""
	if contentValue := vv.MapIndex(reflect.ValueOf("content")); contentValue.IsValid() {
		content = u.String(contentValue.Interface())
	}
	if content != "" {
		contents = append(contents, makeChatMessageContent(content))
		return ChatMessage{Role: role, Contents: contents}
	}

	list := vv.MapIndex(reflect.ValueOf("contents"))
	// For a map[string]any the element comes back with Kind Interface (or
	// Ptr); unwrap it to reach the underlying slice. A nil interface or nil
	// pointer unwraps to the zero Value, which ends the loop.
	for list.IsValid() && (list.Kind() == reflect.Interface || list.Kind() == reflect.Ptr) {
		list = list.Elem()
	}
	if list.IsValid() && (list.Kind() == reflect.Slice || list.Kind() == reflect.Array) {
		for i := 0; i < list.Len(); i++ {
			contents = append(contents, makeChatMessageContent(u.String(list.Index(i).Interface())))
		}
	}
	return ChatMessage{Role: role, Contents: contents}
}
// makeChatMessageContent wraps a string in a ChatMessageContent and picks its
// type from the string itself:
//
//   - TypeImage: a "data:image/" URI, or an http(s) URL ending in .png, .jpg,
//     .jpeg, .gif or .svg;
//   - TypeVideo: a "data:video/" URI, or an http(s) URL ending in .mp4, .mov,
//     .m4v, .avi or .wmv;
//   - TypeText: everything else.
//
// The extension match is case-sensitive and looks only at the very end of the
// string, so URLs with a query string or fragment are treated as text.
func makeChatMessageContent(contnet string) ChatMessageContent {
	isWebURL := strings.HasPrefix(contnet, "https://") || strings.HasPrefix(contnet, "http://")
	endsWithAny := func(suffixes ...string) bool {
		for _, suffix := range suffixes {
			if strings.HasSuffix(contnet, suffix) {
				return true
			}
		}
		return false
	}

	switch {
	case strings.HasPrefix(contnet, "data:image/"),
		isWebURL && endsWithAny(".png", ".jpg", ".jpeg", ".gif", ".svg"):
		return ChatMessageContent{Type: TypeImage, Content: contnet}
	case strings.HasPrefix(contnet, "data:video/"),
		isWebURL && endsWithAny(".mp4", ".mov", ".m4v", ".avi", ".wmv"):
		return ChatMessageContent{Type: TypeVideo, Content: contnet}
	}
	return ChatMessageContent{Type: TypeText, Content: contnet}
}