package main
import (
"context"
"log"
"net/http"
"reflect"
"github.com/ailncode/gluaxmlpath"
"github.com/cjoudrey/gluahttp"
"github.com/google/generative-ai-go/genai"
"github.com/kohkimakimoto/gluayaml"
"github.com/lrstanley/girc"
openai "github.com/sashabaranov/go-openai"
lua "github.com/yuin/gopher-lua"
"gitlab.com/megalithic-llc/gluasocket"
)
// registerStructAsLuaMetaTable exposes the Go struct type T to Lua.
//
// It creates the type metatable named metaTableName and stores it in
// luaLTable under that same name. The metatable gets a "new" constructor that
// builds instances from structType. Its "__index" is a table holding one
// getter/setter method per field that reflect.VisibleFields reports for T.
// checkStruct is used by those methods to extract the *T from their receiver.
func registerStructAsLuaMetaTable[T any](
	luaState *lua.LState,
	luaLTable *lua.LTable,
	checkStruct func(luaState *lua.LState) *T,
	structType T,
	metaTableName string,
) {
	mt := luaState.NewTypeMetatable(metaTableName)
	luaState.SetField(luaLTable, metaTableName, mt)

	constructor := luaState.NewFunction(newStructFunctionFactory(structType, metaTableName))
	luaState.SetField(mt, "new", constructor)

	// Reflect on the zero value of T (not on structType) to get the static type.
	fieldMethods := luaTableGenFactory(reflect.TypeOf(*new(T)), checkStruct)
	indexTable := luaState.SetFuncs(luaState.NewTable(), fieldMethods)
	luaState.SetField(mt, "__index", indexTable)
}
// newStructFunctionFactory returns the Lua "new" constructor for the struct
// type registered under metaTableName.
//
// Each call allocates a fresh *T holding a copy of structType. It wraps that
// pointer in userdata, attaches the metatable named metaTableName and returns
// the userdata to Lua (one return value). The copy is shallow: any slices,
// maps or pointers inside structType would still be shared between instances.
func newStructFunctionFactory[T any](structType T, metaTableName string) func(*lua.LState) int {
	return func(luaState *lua.LState) int {
		// Allocate per call. Taking &structType directly would hand every Lua
		// object a pointer to the same captured variable, so all instances
		// created by "new" would alias and overwrite each other's fields.
		structInstance := new(T)
		*structInstance = structType

		ud := luaState.NewUserData()
		ud.Value = structInstance
		luaState.SetMetatable(ud, luaState.GetTypeMetatable(metaTableName))
		luaState.Push(ud)

		return 1
	}
}
// checkStruct returns the *T stored in the userdata at Lua argument 1.
//
// CheckUserData rejects a non-userdata argument. If the userdata holds
// something other than a *T, an argument error "got wrong struct" is raised
// for argument 1.
func checkStruct[T any](luaState *lua.LState) *T {
	ud := luaState.CheckUserData(1)

	ptr, isT := ud.Value.(*T)
	if !isT {
		luaState.ArgError(1, "got wrong struct")
		return nil
	}

	return ptr
}
// getterSetterFactory returns the Lua method that reads or writes the struct
// field fieldName of the *T obtained through checkStruct.
//
// Called from Lua as obj:Field(v) (two stack values), it assigns v to the
// field and returns nothing. Any other call, such as obj:Field(), returns the
// field's current value.
//
// Conversions exist for string, bool, every signed/unsigned integer kind, and
// float32/float64. For func, pointer, struct, slice, array and map fields, the
// setter is a no-op and the getter returns nil. Any other kind also logs
// "unsupported type".
//
// Integers travel through lua.LNumber (a float64), so values beyond 2^53 lose
// precision. Assigning to an unexported field makes reflect panic.
func getterSetterFactory[T any](
	fieldName string,
	fieldType reflect.Type,
	checkStruct func(luaState *lua.LState) *T,
) func(*lua.LState) int {
	return func(luaState *lua.LState) int {
		genericStruct := checkStruct(luaState)
		fieldValue := reflect.ValueOf(genericStruct).Elem().FieldByName(fieldName)

		// Setter: argument 1 is the object itself, argument 2 the new value.
		if luaState.GetTop() == 2 { //nolint: mnd,gomnd
			switch fieldType.Kind() {
			case reflect.String:
				fieldValue.SetString(luaState.CheckString(2)) //nolint: mnd,gomnd
			case reflect.Float32, reflect.Float64:
				fieldValue.SetFloat(float64(luaState.CheckNumber(2))) //nolint: mnd,gomnd
			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
				// CheckInt64 avoids truncating through int on 32-bit platforms.
				fieldValue.SetInt(luaState.CheckInt64(2)) //nolint: mnd,gomnd
			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
				fieldValue.SetUint(uint64(luaState.CheckInt64(2))) //nolint: mnd,gomnd
			case reflect.Bool:
				fieldValue.SetBool(luaState.CheckBool(2)) //nolint: mnd,gomnd
			case reflect.Func, reflect.Ptr, reflect.Struct, reflect.Slice, reflect.Array, reflect.Map:
				// Composite kinds are not writable from Lua.
			default:
				log.Print("unsupported type")
			}

			return 0
		}

		// Getter. Use the kind-specific accessors instead of
		// Interface().(string) / Interface().(float64). Those type assertions
		// panic for named types (e.g. `type Mode string`), and Interface()
		// itself panics on unexported fields.
		var result lua.LValue = lua.LNil

		switch fieldType.Kind() {
		case reflect.String:
			result = lua.LString(fieldValue.String())
		case reflect.Float32, reflect.Float64:
			result = lua.LNumber(fieldValue.Float())
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			result = lua.LNumber(fieldValue.Int())
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
			result = lua.LNumber(fieldValue.Uint())
		case reflect.Bool:
			result = lua.LBool(fieldValue.Bool())
		case reflect.Func, reflect.Ptr, reflect.Struct, reflect.Slice, reflect.Array, reflect.Map:
			// Not representable as a Lua scalar; return nil.
		default:
			log.Print("unsupported type")
		}

		// Always push exactly one value. Returning 1 without pushing would
		// hand Lua a stale stack value instead of nil.
		luaState.Push(result)

		return 1
	}
}
// luaTableGenFactory builds the method set for a Lua struct wrapper.
//
// The result maps every field name reported by reflect.VisibleFields(structType)
// to a getter/setter produced by getterSetterFactory. Each of those methods
// operates on the *T that checkStructType extracts from the Lua receiver.
func luaTableGenFactory[T any](
	structType reflect.Type,
	checkStructType func(luaState *lua.LState) *T) map[string]lua.LGFunction {
	fields := reflect.VisibleFields(structType)

	methods := make(map[string]lua.LGFunction, len(fields))
	for i := range fields {
		f := &fields[i]
		methods[f.Name] = getterSetterFactory(f.Name, f.Type, checkStructType)
	}

	return methods
}
// sendMessageClosure returns the Lua function send_message(message, target).
//
// The function passes both string arguments to client.Cmd.Message (target
// first) and returns nothing to Lua. The outer luaState parameter is unused.
// The returned function works on the state it is called with.
func sendMessageClosure(luaState *lua.LState, client *girc.Client) func(*lua.LState) int {
	return func(ls *lua.LState) int {
		msg, tgt := ls.CheckString(1), ls.CheckString(2) //nolint: mnd,gomnd
		client.Cmd.Message(tgt, msg)

		return 0
	}
}
// ircJoinChannelClosure returns the Lua function join_channel(channel).
//
// The function passes its string argument to client.Cmd.Join and returns
// nothing to Lua. The outer luaState parameter is unused.
func ircJoinChannelClosure(luaState *lua.LState, client *girc.Client) func(*lua.LState) int {
	return func(ls *lua.LState) int {
		client.Cmd.Join(ls.CheckString(1))

		return 0
	}
}
// ircPartChannelClosure returns the Lua function part_channel(channel).
//
// The function passes its string argument to client.Cmd.Part and returns
// nothing to Lua. The outer luaState parameter is unused.
func ircPartChannelClosure(luaState *lua.LState, client *girc.Client) func(*lua.LState) int {
	return func(ls *lua.LState) int {
		client.Cmd.Part(ls.CheckString(1))

		return 0
	}
}
// ollamaRequestClosure returns the Lua function send_ollama_request(prompt).
//
// The function calls DoOllamaRequest with appConfig, an empty MemoryElement
// slice and the prompt. It returns the response text as a single Lua string.
// A request error is only logged; the (possibly empty) result is still
// returned.
func ollamaRequestClosure(luaState *lua.LState, appConfig *TomlConfig) func(*lua.LState) int {
	return func(ls *lua.LState) int {
		userPrompt := ls.CheckString(1)
		noHistory := &[]MemoryElement{}

		reply, reqErr := DoOllamaRequest(appConfig, noHistory, userPrompt)
		if reqErr != nil {
			log.Print(reqErr)
		}

		ls.Push(lua.LString(reply))

		return 1
	}
}
// geminiRequestClosure returns the Lua function send_gemini_request(prompt).
//
// The function calls DoGeminiRequest with appConfig, an empty genai.Content
// history slice and the prompt. It returns the response text as a single Lua
// string. A request error is only logged; the (possibly empty) result is
// still returned.
func geminiRequestClosure(luaState *lua.LState, appConfig *TomlConfig) func(*lua.LState) int {
	return func(ls *lua.LState) int {
		userPrompt := ls.CheckString(1)
		noHistory := &[]*genai.Content{}

		reply, reqErr := DoGeminiRequest(appConfig, noHistory, userPrompt)
		if reqErr != nil {
			log.Print(reqErr)
		}

		ls.Push(lua.LString(reply))

		return 1
	}
}
// chatGPTRequestClosure returns the Lua function send_chatgpt_request(prompt).
//
// The function calls DoChatGPTRequest with appConfig, an empty
// ChatCompletionMessage history slice and the prompt. It returns the response
// text as a single Lua string. A request error is only logged; the (possibly
// empty) result is still returned.
func chatGPTRequestClosure(luaState *lua.LState, appConfig *TomlConfig) func(*lua.LState) int {
	return func(ls *lua.LState) int {
		userPrompt := ls.CheckString(1)
		noHistory := &[]openai.ChatCompletionMessage{}

		reply, reqErr := DoChatGPTRequest(appConfig, noHistory, userPrompt)
		if reqErr != nil {
			log.Print(reqErr)
		}

		ls.Push(lua.LString(reply))

		return 1
	}
}
// millaModuleLoaderClosure returns the loader for the "milla" Lua module.
//
// When run, the loader builds a table containing:
//   - six functions: send_message, join_channel, part_channel,
//     send_ollama_request, send_gemini_request and send_chatgpt_request;
//   - the struct metatables toml_config, custom_command and log_model.
//
// It also stores that table as the global "milla" and returns it as the
// single module value. The outer luaState parameter is unused.
func millaModuleLoaderClosure(luaState *lua.LState, client *girc.Client, appConfig *TomlConfig) func(*lua.LState) int {
	return func(ls *lua.LState) int {
		module := ls.SetFuncs(ls.NewTable(), map[string]lua.LGFunction{
			"send_message":         sendMessageClosure(ls, client),
			"join_channel":         ircJoinChannelClosure(ls, client),
			"part_channel":         ircPartChannelClosure(ls, client),
			"send_ollama_request":  ollamaRequestClosure(ls, appConfig),
			"send_gemini_request":  geminiRequestClosure(ls, appConfig),
			"send_chatgpt_request": chatGPTRequestClosure(ls, appConfig),
		})

		registerStructAsLuaMetaTable(ls, module, checkStruct[TomlConfig], TomlConfig{}, "toml_config")
		registerStructAsLuaMetaTable(ls, module, checkStruct[CustomCommand], CustomCommand{}, "custom_command")
		registerStructAsLuaMetaTable(ls, module, checkStruct[LogModel], LogModel{}, "log_model")

		ls.SetGlobal("milla", module)
		ls.Push(module)

		return 1
	}
}
// RunScript executes the Lua script at scriptPath in a fresh Lua state and
// blocks until the script returns.
//
// The state is bound to a cancellable context. The state and the context's
// cancel function are registered via appConfig.insertLState. Before running
// the script, RunScript preloads the "milla", "http" and "yaml" modules plus
// the gluasocket and gluaxmlpath modules.
//
// Errors from the script are logged, not returned. When RunScript returns,
// the context is cancelled and the Lua state is closed.
func RunScript(scriptPath string, client *girc.Client, appConfig *TomlConfig) {
	luaState := lua.NewState()
	defer luaState.Close()

	ctx, cancel := context.WithCancel(context.Background())
	// Release the context once the script is done. A CancelFunc does nothing
	// after its first call, so this is safe even if cancel was already
	// invoked elsewhere.
	defer cancel()

	luaState.SetContext(ctx)
	appConfig.insertLState(scriptPath, luaState, cancel)

	luaState.PreloadModule("milla", millaModuleLoaderClosure(luaState, client, appConfig))
	gluasocket.Preload(luaState)
	gluaxmlpath.Preload(luaState)
	// NOTE(review): a zero http.Client has no timeout, so a hung HTTP request
	// made by a script can block it indefinitely.
	luaState.PreloadModule("http", gluahttp.NewHttpModule(&http.Client{}).Loader)
	luaState.PreloadModule("yaml", gluayaml.Loader)

	log.Print("Running script: ", scriptPath)

	if err := luaState.DoFile(scriptPath); err != nil {
		log.Print(err)
	}
}
// LoadAllPlugins launches every script listed in appConfig.Plugins via
// RunScript, each on its own goroutine. It returns without waiting for the
// scripts to finish.
//
// NOTE(review): all plugin goroutines share appConfig and client; confirm
// insertLState and the girc client are safe for concurrent use.
func LoadAllPlugins(appConfig *TomlConfig, client *girc.Client) {
	start := func(path string) {
		log.Print("Loading plugin: ", path)
		go RunScript(path, client, appConfig)
	}

	for _, pluginPath := range appConfig.Plugins {
		start(pluginPath)
	}
}