From f481ec6bd224f41d847ceafb69c79282613419a7 Mon Sep 17 00:00:00 2001 From: terminaldweller Date: Wed, 22 May 2024 23:34:57 -0400 Subject: sql query custom commands, WIP --- main.go | 289 +++++++++++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 195 insertions(+), 94 deletions(-) (limited to 'main.go') diff --git a/main.go b/main.go index 4b896ce..f8f0739 100644 --- a/main.go +++ b/main.go @@ -25,6 +25,7 @@ import ( "github.com/BurntSushi/toml" "github.com/alecthomas/chroma/v2/quick" "github.com/google/generative-ai-go/genai" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/lrstanley/girc" openai "github.com/sashabaranov/go-openai" @@ -41,58 +42,67 @@ var ( errUnsupportedType = errors.New("unsupported type") ) -type CustomCommand struct { - SQL string `json:"sql"` - Limit int `json:"limit"` - Prompt string `json:"prompt"` +type LogModel struct { + // Id int64 `db:"id"` + // Channel string `db:"channel"` + Log string `db:"log"` + // Nick string `db:"nick"` + // DateAdded pgtype.Date `db:"dateadded"` } -type CustomCommands struct { - CustomCommands map[string]CustomCommand `json:"customCommands"` +type CustomCommand struct { + SQL string `toml:"sql"` + Limit int `toml:"limit"` + Prompt string `toml:"prompt"` } type TomlConfig struct { - IrcServer string `toml:"ircServer"` - IrcNick string `toml:"ircNick"` - IrcSaslUser string `toml:"ircSaslUser"` - IrcSaslPass string `toml:"ircSaslPass"` - OllamaEndpoint string `toml:"ollamaEndpoint"` - Model string `toml:"model"` - ChromaStyle string `toml:"chromaStyle"` - ChromaFormatter string `toml:"chromaFormatter"` - Provider string `toml:"provider"` - Apikey string `toml:"apikey"` - OllamaSystem string `toml:"ollamaSystem"` - ClientCertPath string `toml:"clientCertPath"` - ServerPass string `toml:"serverPass"` - Bind string `toml:"bind"` - Name string `toml:"name"` - DatabaseAddress string `toml:"databaseAddress"` - DatabasePassword string `toml:"databasePassword"` - DatabaseUser string `toml:"databaseUser"` - DatabaseName string `toml:"databaseName"` - LLMProxy string `toml:"llmProxy"` - IRCProxy string `toml:"ircProxy"` - IRCDName string `toml:"ircdName"` - CommandsFile string `toml:"commandsFile"` - Temp float64 `toml:"temp"` - RequestTimeout int `toml:"requestTimeout"` - MillaReconnectDelay int `toml:"millaReconnectDelay"` - IrcPort int `toml:"ircPort"` - KeepAlive int `toml:"keepAlive"` - MemoryLimit int `toml:"memoryLimit"` - PingDelay int `toml:"pingDelay"` - PingTimeout int `toml:"pingTimeout"` - TopP float32 `toml:"topP"` - TopK int32 `toml:"topK"` - EnableSasl bool `toml:"enableSasl"` - SkipTLSVerify bool `toml:"skipTLSVerify"` - UseTLS bool `toml:"useTLS"` - DisableSTSFallback bool `toml:"disableSTSFallback"` - AllowFlood bool `toml:"allowFlood"` - Debug bool `toml:"debug"` - Out bool `toml:"out"` - AdminOnly bool `toml:"adminOnly"` + IrcServer string `toml:"ircServer"` + IrcNick string `toml:"ircNick"` + IrcSaslUser string `toml:"ircSaslUser"` + IrcSaslPass string `toml:"ircSaslPass"` + OllamaEndpoint string `toml:"ollamaEndpoint"` + Model string `toml:"model"` + ChromaStyle string `toml:"chromaStyle"` + ChromaFormatter string `toml:"chromaFormatter"` + Provider string `toml:"provider"` + Apikey string `toml:"apikey"` + OllamaSystem string `toml:"ollamaSystem"` + ClientCertPath string `toml:"clientCertPath"` + ServerPass string `toml:"serverPass"` + Bind string `toml:"bind"` + Name string `toml:"name"` + DatabaseAddress string `toml:"databaseAddress"` + DatabasePassword string `toml:"databasePassword"` + DatabaseUser string `toml:"databaseUser"` + DatabaseName string `toml:"databaseName"` + LLMProxy string `toml:"llmProxy"` + IRCProxy string `toml:"ircProxy"` + IRCDName string `toml:"ircdName"` + WebIRCPassword string `toml:"webIRCPassword"` + WebIRCGateway string `toml:"webIRCGateway"` + WebIRCHostname string `toml:"webIRCHostname"` + WebIRCAddress string `toml:"webIRCAddress"` + CustomCommands map[string]CustomCommand `toml:"customCommands"` + Temp float64 `toml:"temp"` + RequestTimeout int `toml:"requestTimeout"` + MillaReconnectDelay int `toml:"millaReconnectDelay"` + IrcPort int `toml:"ircPort"` + KeepAlive int `toml:"keepAlive"` + MemoryLimit int `toml:"memoryLimit"` + PingDelay int `toml:"pingDelay"` + PingTimeout int `toml:"pingTimeout"` + TopP float32 `toml:"topP"` + TopK int32 `toml:"topK"` + EnableSasl bool `toml:"enableSasl"` + SkipTLSVerify bool `toml:"skipTLSVerify"` + UseTLS bool `toml:"useTLS"` + DisableSTSFallback bool `toml:"disableSTSFallback"` + AllowFlood bool `toml:"allowFlood"` + Debug bool `toml:"debug"` + Out bool `toml:"out"` + AdminOnly bool `toml:"adminOnly"` + pool *pgxpool.Pool Admins []string `toml:"admins"` IrcChannels []string `toml:"ircChannels"` ScrapeChannels []string `toml:"scrapeChannels"` @@ -107,10 +117,6 @@ func addSaneDefaults(config *TomlConfig) { config.IrcNick = "milla" } - if config.IrcSaslUser == "" { - config.IrcSaslUser = "milla" - } - if config.ChromaStyle == "" { config.ChromaStyle = "rose-pine-moon" } @@ -119,10 +125,6 @@ func addSaneDefaults(config *TomlConfig) { config.ChromaFormatter = "noop" } - if config.Provider == "" { - config.Provider = "ollam" - } - if config.DatabaseAddress == "" { config.DatabaseAddress = "postgres" } @@ -135,12 +137,8 @@ func addSaneDefaults(config *TomlConfig) { config.DatabaseName = "milladb" } - if config.CommandsFile == "" { - config.CommandsFile = "./commands.json" - } - if config.Temp == 0 { - config.Temp = 0.5 //nollint:gomnd + config.Temp = 0.5 } if config.RequestTimeout == 0 { @@ -190,11 +188,11 @@ type OllamaChatMessagesResponse struct { } type OllamaChatRequest struct { - Model string `json:"model"` - Stream bool `json:"stream"` - Keep_alive time.Duration `json:"keep_alive"` - Options OllamaRequestOptions `json:"options"` - Messages []MemoryElement `json:"messages"` + Model string `json:"model"` + Stream bool `json:"stream"` + KeepAlive time.Duration `json:"keep_alive"` + Options OllamaRequestOptions `json:"options"` + Messages []MemoryElement `json:"messages"` } type MemoryElement struct { @@ -303,6 +301,9 @@ func getHelpString() string { helpString += "help - show this help message\n" helpString += "set - set a configuration value\n" helpString += "get - get a configuration value\n" + helpString += "join - joins a given channel\n" + helpString += "leave - leaves a given channel\n" + helpString += "cmd - run a custom command defined in the customcommands file\n" helpString += "getall - returns all config options with their value\n" helpString += "memstats - returns the memory status currently being used\n" @@ -352,7 +353,7 @@ func setFieldByName(v reflect.Value, field string, value string) error { func byteToMByte(bytes uint64, ) uint64 { - return bytes / 1024 / 1024 //nolint:gomnd + return bytes / 1024 / 1024 } func runCommand( @@ -383,7 +384,7 @@ func runCommand( case "help": sendToIRC(client, event, getHelpString(), "noop") case "set": - if len(args) < 3 { //nolint:gomnd + if len(args) < 3 { client.Cmd.Reply(event, errNotEnoughArgs.Error()) break @@ -394,7 +395,7 @@ func runCommand( client.Cmd.Reply(event, err.Error()) } case "get": - if len(args) < 2 { //nolint:gomnd + if len(args) < 2 { client.Cmd.Reply(event, errNotEnoughArgs.Error()) break @@ -429,6 +430,81 @@ func runCommand( client.Cmd.Reply(event, fmt.Sprintf("Alloc: %d MiB", byteToMByte(memStats.Alloc))) client.Cmd.Reply(event, fmt.Sprintf("TotalAlloc: %d MiB", byteToMByte(memStats.TotalAlloc))) client.Cmd.Reply(event, fmt.Sprintf("Sys: %d MiB", byteToMByte(memStats.Sys))) + case "join": + if len(args) < 2 { + client.Cmd.Reply(event, errNotEnoughArgs.Error()) + + break + } + + client.Cmd.Join(args[1]) + case "leave": + if len(args) < 2 { + client.Cmd.Reply(event, errNotEnoughArgs.Error()) + + break + } + + client.Cmd.Part(args[1]) + case "cmd": + if len(args) < 2 { + client.Cmd.Reply(event, errNotEnoughArgs.Error()) + + break + } + + customCommand := appConfig.CustomCommands[args[1]] + + if customCommand.SQL == "" { + client.Cmd.Reply(event, "empty sql commands in the custom command") + + break + } + + if appConfig.pool == nil { + client.Cmd.Reply(event, "no database connection") + + break + } + + log.Println(customCommand.SQL) + + rows, err := appConfig.pool.Query(context.Background(), customCommand.SQL) + defer rows.Close() + + if err != nil { + client.Cmd.Reply(event, "error: "+err.Error()) + + break + } + + var gptMemory []openai.ChatCompletionMessage + + logs, err := pgx.CollectRows(rows, pgx.RowToStructByName[LogModel]) + if err != nil { + log.Println(err.Error()) + + break + } + + log.Println(logs) + logs = logs[:customCommand.Limit] + + if err != nil { + log.Println(err.Error()) + + break + } + + for _, log := range logs { + gptMemory = append(gptMemory, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleUser, + Content: log.Log, + }) + } + + chatGPTRequest(appConfig, client, event, &gptMemory, customCommand.Prompt) + default: client.Cmd.Reply(event, errUnknCmd.Error()) } @@ -457,10 +533,10 @@ func doOllamaRequest( *ollamaMemory = append(*ollamaMemory, memoryElement) ollamaRequest := OllamaChatRequest{ - Model: appConfig.Model, - Keep_alive: time.Duration(appConfig.KeepAlive), - Stream: false, - Messages: *ollamaMemory, + Model: appConfig.Model, + KeepAlive: time.Duration(appConfig.KeepAlive), + Stream: false, + Messages: *ollamaMemory, Options: OllamaRequestOptions{ Temperature: appConfig.Temp, }, @@ -468,9 +544,9 @@ func doOllamaRequest( jsonPayload, err = json.Marshal(ollamaRequest) if err != nil { - client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + client.Cmd.ReplyTo(event, "error: "+err.Error()) - return nil, err + return nil, fmt.Errorf("could not marshal json payload: %v", err) } log.Printf("json payload: %s", string(jsonPayload)) @@ -480,9 +556,9 @@ func doOllamaRequest( request, err := http.NewRequest(http.MethodPost, appConfig.OllamaEndpoint, bytes.NewBuffer(jsonPayload)) if err != nil { - client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + client.Cmd.ReplyTo(event, "error: "+err.Error()) - return nil, err + return nil, fmt.Errorf("could not make a new http request: %v", err) } request = request.WithContext(ctx) @@ -531,10 +607,11 @@ func ollamaRequest( } if err != nil { - client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + client.Cmd.ReplyTo(event, "error: "+err.Error()) return } + defer response.Body.Close() log.Println("response body:", response.Body) @@ -545,7 +622,7 @@ func ollamaRequest( err = json.NewDecoder(response.Body).Decode(&ollamaChatResponse) if err != nil { - client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + client.Cmd.ReplyTo(event, "error: "+err.Error()) } assistantElement := MemoryElement{ @@ -563,7 +640,7 @@ func ollamaRequest( appConfig.ChromaFormatter, appConfig.ChromaStyle) if err != nil { - client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + client.Cmd.ReplyTo(event, "error: "+err.Error()) return } @@ -580,17 +657,21 @@ func ollamaHandler( if !strings.HasPrefix(event.Last(), appConfig.IrcNick+": ") { return } + if appConfig.AdminOnly { byAdmin := false + for _, admin := range appConfig.Admins { if event.Source.Name == admin { byAdmin = true } } + if !byAdmin { return } } + prompt := strings.TrimPrefix(event.Last(), appConfig.IrcNick+": ") log.Println(prompt) @@ -616,7 +697,7 @@ func doGeminiRequest( clientGemini, err := genai.NewClient(ctx, option.WithAPIKey(appConfig.Apikey)) if err != nil { - client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + client.Cmd.ReplyTo(event, "error: "+err.Error()) return "" } @@ -633,7 +714,7 @@ func doGeminiRequest( resp, err := cs.SendMessage(ctx, genai.Text(prompt)) if err != nil { - client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + client.Cmd.ReplyTo(event, "error: "+err.Error()) return "" } @@ -678,7 +759,7 @@ func geminiRequest( appConfig.ChromaFormatter, appConfig.ChromaStyle) if err != nil { - client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + client.Cmd.ReplyTo(event, "error: "+err.Error()) return } @@ -695,17 +776,21 @@ func geminiHandler( if !strings.HasPrefix(event.Last(), appConfig.IrcNick+": ") { return } + if appConfig.AdminOnly { byAdmin := false + for _, admin := range appConfig.Admins { if event.Source.Name == admin { byAdmin = true } } + if !byAdmin { return } } + prompt := strings.TrimPrefix(event.Last(), appConfig.IrcNick+": ") log.Println(prompt) @@ -735,7 +820,7 @@ func doChatGPTRequest( proxyURL, err := url.Parse(appConfig.IRCProxy) if err != nil { cancel() - client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + client.Cmd.ReplyTo(event, "error: "+err.Error()) log.Fatal(err.Error()) } @@ -743,7 +828,7 @@ func doChatGPTRequest( dialer, err := proxy.FromURL(proxyURL, &net.Dialer{Timeout: time.Duration(appConfig.RequestTimeout) * time.Second}) if err != nil { cancel() - client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + client.Cmd.ReplyTo(event, "error: "+err.Error()) log.Fatal(err.Error()) } @@ -782,7 +867,7 @@ func chatGPTRequest( ) { resp, err := doChatGPTRequest(appConfig, client, event, gptMemory, prompt) if err != nil { - client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + client.Cmd.ReplyTo(event, "error: "+err.Error()) return } @@ -805,7 +890,7 @@ func chatGPTRequest( appConfig.ChromaFormatter, appConfig.ChromaStyle) if err != nil { - client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + client.Cmd.ReplyTo(event, "error: "+err.Error()) return } @@ -822,17 +907,21 @@ func chatGPTHandler( if !strings.HasPrefix(event.Last(), appConfig.IrcNick+": ") { return } + if appConfig.AdminOnly { byAdmin := false + for _, admin := range appConfig.Admins { if event.Source.Name == admin { byAdmin = true } } + if !byAdmin { return } } + prompt := strings.TrimPrefix(event.Last(), appConfig.IrcNick+": ") log.Println(prompt) @@ -846,7 +935,7 @@ func chatGPTHandler( }) } -func connectToDB(appConfig TomlConfig, ctx *context.Context, poolChan chan *pgxpool.Pool) { +func connectToDB(appConfig *TomlConfig, ctx *context.Context, poolChan chan *pgxpool.Pool) { for { if appConfig.DatabaseUser == "" { appConfig.DatabaseUser = os.Getenv("MILLA_DB_USER") @@ -895,6 +984,7 @@ func connectToDB(appConfig TomlConfig, ctx *context.Context, poolChan chan *pgxp nick text not null, dateadded timestamp default current_timestamp )`, tableName) + _, err = pool.Exec(*ctx, query) if err != nil { log.Println(err.Error()) @@ -902,13 +992,14 @@ func connectToDB(appConfig TomlConfig, ctx *context.Context, poolChan chan *pgxp } } + appConfig.pool = pool poolChan <- pool } } } func scrapeChannel(irc *girc.Client, poolChan chan *pgxpool.Pool, appConfig TomlConfig) { - irc.Handlers.AddBg(girc.PRIVMSG, func(client *girc.Client, event girc.Event) { + irc.Handlers.AddBg(girc.PRIVMSG, func(_ *girc.Client, event girc.Event) { pool := <-poolChan tableName := getTableFromChanName(event.Params[0], appConfig.IRCDName) query := fmt.Sprintf( @@ -953,6 +1044,13 @@ func runIRC(appConfig TomlConfig) { }, }) + if appConfig.WebIRCGateway != "" { + irc.Config.WebIRC.Address = appConfig.WebIRCAddress + irc.Config.WebIRC.Gateway = appConfig.WebIRCGateway + irc.Config.WebIRC.Hostname = appConfig.WebIRCHostname + irc.Config.WebIRC.Password = appConfig.WebIRCPassword + } + if appConfig.Debug { irc.Config.Debug = os.Stdout } @@ -1003,7 +1101,7 @@ func runIRC(appConfig TomlConfig) { irc.Config.TLSConfig.Certificates = []tls.Certificate{cert} } - irc.Handlers.AddBg(girc.CONNECTED, func(c *girc.Client, e girc.Event) { + irc.Handlers.AddBg(girc.CONNECTED, func(c *girc.Client, _ girc.Event) { for _, channel := range appConfig.IrcChannels { c.Cmd.Join(channel) } @@ -1022,11 +1120,11 @@ func runIRC(appConfig TomlConfig) { context, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second) defer cancel() - go connectToDB(appConfig, &context, poolChan) + go connectToDB(&appConfig, &context, poolChan) } if len(appConfig.ScrapeChannels) > 0 { - irc.Handlers.AddBg(girc.CONNECTED, func(c *girc.Client, e girc.Event) { + irc.Handlers.AddBg(girc.CONNECTED, func(c *girc.Client, _ girc.Event) { for _, channel := range appConfig.ScrapeChannels { c.Cmd.Join(channel) } @@ -1080,10 +1178,13 @@ func main() { log.Fatal(err) } + for key, value := range config.Ircd { + addSaneDefaults(&value) + value.IRCDName = key + config.Ircd[key] = value + } + for k, v := range config.Ircd { - addSaneDefaults(&v) - v.IRCDName = k - config.Ircd[k] = v log.Println(k, v) } -- cgit v1.2.3