diff options
author | terminaldweller <devi@terminaldweller.com> | 2024-05-15 14:41:39 +0000 |
---|---|---|
committer | terminaldweller <devi@terminaldweller.com> | 2024-05-15 14:41:39 +0000 |
commit | 006cb3a77d7437a60733c28abee0083e6f02ae90 (patch) | |
tree | 2cae44510125ce8eb09d4ee4e142396c7769536f /main.go | |
parent | fixing the go executable build for github actions (diff) | |
download | milla-006cb3a77d7437a60733c28abee0083e6f02ae90.tar.gz milla-006cb3a77d7437a60733c28abee0083e6f02ae90.zip |
fixes #14, fixes #16, fixes #17, fixes #18, fixes #19, fixes #20
Diffstat (limited to '')
-rw-r--r-- | main.go | 216 |
1 files changed, 200 insertions, 16 deletions
@@ -9,10 +9,13 @@ import ( "flag" "fmt" "log" + "net" "net/http" + "net/url" "os" "reflect" "regexp" + "runtime" "strconv" "strings" "time" @@ -20,6 +23,7 @@ import ( "github.com/BurntSushi/toml" "github.com/alecthomas/chroma/v2/quick" "github.com/google/generative-ai-go/genai" + "github.com/jackc/pgx/v5/pgxpool" "github.com/lrstanley/girc" openai "github.com/sashabaranov/go-openai" "golang.org/x/net/proxy" @@ -33,6 +37,7 @@ var ( errCantSet = errors.New("can't set field") errWrongDataForField = errors.New("wrong data type for field") errUnsupportedType = errors.New("unsupported type") + dbConnection *pgxpool.Pool //nolint:gochecknoglobals ) type TomlConfig struct { @@ -50,6 +55,13 @@ type TomlConfig struct { ClientCertPath string `toml:"clientCertPath"` ServerPass string `toml:"serverPass"` Bind string `toml:"bind"` + Name string `toml:"name"` + DatabaseAddress string `toml:"databaseAddress"` + DatabasePassword string `toml:"databasePassword"` + DatabaseUser string `toml:"databaseUser"` + DatabaseName string `toml:"databaseName"` + LLMProxy string `toml:"llmProxy"` + IRCProxy string `toml:"ircProxy"` Temp float64 `toml:"temp"` RequestTimeout int `toml:"requestTimeout"` MillaReconnectDelay int `toml:"millaReconnectDelay"` @@ -69,6 +81,7 @@ type TomlConfig struct { Out bool `toml:"out"` Admins []string `toml:"admins"` IrcChannels []string `toml:"ircChannels"` + ScrapeChannels []string `toml:"scrapeChannels"` } func NewTomlConfig() *TomlConfig { @@ -78,6 +91,9 @@ func NewTomlConfig() *TomlConfig { ChromaStyle: "rose-pine-moon", ChromaFormatter: "noop", Provider: "ollama", + DatabaseAddress: "postgres", + DatabaseUser: "milla", + DatabaseName: "milladb", Temp: 0.5, //nolint:gomnd RequestTimeout: 10, //nolint:gomnd MillaReconnectDelay: 30, //nolint:gomnd @@ -206,6 +222,7 @@ func getHelpString() string { helpString += "set - set a configuration value\n" helpString += "get - get a configuration value\n" helpString += "getall - returns all config options with their value\n" + helpString += "memstats - returns the memory status currently being used\n" return helpString } @@ -251,6 +268,11 @@ func setFieldByName(v reflect.Value, field string, value string) error { return nil } +func byteToMByte(bytes uint64, +) uint64 { + return bytes / 1024 / 1024 +} + func runCommand( client *girc.Client, event girc.Event, @@ -317,6 +339,13 @@ func runCommand( fieldValue := v.Field(i).Interface() client.Cmd.Reply(event, fmt.Sprintf("%s: %v", field.Name, fieldValue)) } + case "memstats": + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + + client.Cmd.Reply(event, fmt.Sprintf("Alloc: %d MiB", byteToMByte(memStats.Alloc))) + client.Cmd.Reply(event, fmt.Sprintf("TotalAlloc: %d MiB", byteToMByte(memStats.TotalAlloc))) + client.Cmd.Reply(event, fmt.Sprintf("Sys: %d MiB", byteToMByte(memStats.Sys))) default: client.Cmd.Reply(event, errUnknCmd.Error()) } @@ -385,12 +414,28 @@ func ollamaHandler( var httpClient http.Client - dialer := proxy.FromEnvironment() + var dialer proxy.Dialer - httpClient = http.Client{ - Transport: &http.Transport{ - Dial: dialer.Dial, - }, + if appConfig.LLMProxy != "" { + proxyURL, err := url.Parse(appConfig.IRCProxy) + if err != nil { + cancel() + + log.Fatal(err.Error()) + } + + dialer, err = proxy.FromURL(proxyURL, &net.Dialer{Timeout: time.Duration(appConfig.RequestTimeout) * time.Second}) + if err != nil { + cancel() + + log.Fatal(err.Error()) + } + + httpClient = http.Client{ + Transport: &http.Transport{ + Dial: dialer.Dial, + }, + } } response, err := httpClient.Do(request) @@ -474,6 +519,10 @@ func geminiHandler( // clientGemini, err := genai.NewClient(ctx, option.WithAPIKey(appConfig.Apikey), option.WithHTTPClient(&httpClient)) + if appConfig.Apikey == "" { + appConfig.Apikey = os.Getenv("MILLA_APIKEY") + } + clientGemini, err := genai.NewClient(ctx, option.WithAPIKey(appConfig.Apikey)) if err != nil { client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) @@ -559,12 +608,30 @@ func chatGPTHandler( var httpClient http.Client - dialer := proxy.FromEnvironment() + if appConfig.LLMProxy != "" { + proxyURL, err := url.Parse(appConfig.IRCProxy) + if err != nil { + cancel() - httpClient = http.Client{ - Transport: &http.Transport{ - Dial: dialer.Dial, - }, + log.Fatal(err.Error()) + } + + dialer, err := proxy.FromURL(proxyURL, &net.Dialer{Timeout: time.Duration(appConfig.RequestTimeout) * time.Second}) + if err != nil { + cancel() + + log.Fatal(err.Error()) + } + + httpClient = http.Client{ + Transport: &http.Transport{ + Dial: dialer.Dial, + }, + } + } + + if appConfig.Apikey == "" { + appConfig.Apikey = os.Getenv("MILLA_APIKEY") } config := openai.DefaultConfig(appConfig.Apikey) @@ -613,7 +680,78 @@ func chatGPTHandler( }) } -func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) { +func connectToDB(appConfig TomlConfig, context *context.Context) { + for { + if appConfig.DatabaseUser == "" { + appConfig.DatabaseUser = os.Getenv("MILLA_DB_USER") + } + + if appConfig.DatabasePassword == "" { + appConfig.DatabasePassword = os.Getenv("MILLA_DB_PASSWORD") + } + + if appConfig.DatabaseAddress == "" { + appConfig.DatabaseAddress = os.Getenv("MILLA_DB_ADDRESS") + } + + if appConfig.DatabaseName == "" { + appConfig.DatabaseName = os.Getenv("MILLA_DB_NAME") + } + + dbURL := fmt.Sprintf( + "postgres://%s:%s@%s/%s", + appConfig.DatabaseUser, + appConfig.DatabasePassword, + appConfig.DatabaseAddress, + appConfig.DatabaseName) + + conn, err := pgxpool.New(*context, dbURL) + if err != nil { + log.Println(err) + time.Sleep(time.Duration(appConfig.MillaReconnectDelay) * time.Second) + } else { + for _, channel := range appConfig.ScrapeChannels { + query := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id SERIAL PRIMARY KEY,channel TEXT NOT NULL,log TEXT NOT NULL,nick TEXT NOT NULL,dateadded TIMESTAMP DEFAULT CURRENT_TIMESTAMP)", + strings.ReplaceAll(channel, "#", "")) + + log.Println(query) + + _, err = conn.Query(*context, query) + if err != nil { + log.Println(err.Error()) + time.Sleep(time.Duration(appConfig.MillaReconnectDelay) * time.Second) + } + } + + dbConnection = conn + } + } +} + +func scrapeChannel(irc *girc.Client) { + irc.Handlers.AddBg(girc.PRIVMSG, func(client *girc.Client, event girc.Event) { + if dbConnection == nil { + log.Println("missed logging message because currently not connected to db") + + return + } + query := fmt.Sprintf("INSERT INTO %s (channel,log,nick) VALUES ('%s','%s','%s')", + strings.ReplaceAll(event.Params[0], "#", ""), + event.Params[0], + event.Last(), + event.Source.Name, + ) + log.Println(query) + + _, err := dbConnection.Query( + context.Background(), query) + if err != nil { + log.Println(err.Error()) + } + }) +} + +func runIRC(appConfig TomlConfig, ircChan chan *girc.Client, dbChan chan *pgxpool.Pool) { var OllamaMemory []MemoryElement var GeminiMemory []*genai.Content @@ -646,16 +784,29 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) { irc.Config.Out = os.Stdout } - if appConfig.ServerPass != "" { - irc.Config.ServerPass = appConfig.ServerPass + if appConfig.ServerPass == "" { + appConfig.ServerPass = os.Getenv("MILLA_SERVER_PASSWORD") } + irc.Config.ServerPass = appConfig.ServerPass + if appConfig.Bind != "" { irc.Config.Bind = appConfig.Bind } + if appConfig.Name != "" { + irc.Config.Name = appConfig.Name + } + saslUser := appConfig.IrcSaslUser - saslPass := appConfig.IrcSaslPass + + var saslPass string + + if appConfig.IrcSaslPass == "" { + saslPass = os.Getenv("MILLA_SASL_PASSWORD") + } else { + saslPass = appConfig.IrcSaslPass + } if appConfig.EnableSasl && saslUser != "" && saslPass != "" { irc.Config.SASL = &girc.SASLPlain{ @@ -690,10 +841,42 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) { chatGPTHandler(irc, &appConfig, &GPTMemory) } + context, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second) + defer cancel() + + go connectToDB(appConfig, &context) + + if len(appConfig.ScrapeChannels) > 0 { + irc.Handlers.AddBg(girc.CONNECTED, func(c *girc.Client, e girc.Event) { + for _, channel := range appConfig.ScrapeChannels { + c.Cmd.Join(channel) + } + }) + + go scrapeChannel(irc) + } ircChan <- irc for { - if err := irc.Connect(); err != nil { + var dialer proxy.Dialer + + if appConfig.IRCProxy != "" { + proxyURL, err := url.Parse(appConfig.IRCProxy) + if err != nil { + cancel() + + log.Fatal(err.Error()) + } + + dialer, err = proxy.FromURL(proxyURL, &net.Dialer{Timeout: time.Duration(appConfig.RequestTimeout) * time.Second}) + if err != nil { + cancel() + + log.Fatal(err.Error()) + } + } + + if err := irc.DialerConnect(dialer); err != nil { log.Println(err) log.Println("reconnecting in " + strconv.Itoa(appConfig.MillaReconnectDelay)) time.Sleep(time.Duration(appConfig.MillaReconnectDelay) * time.Second) @@ -723,6 +906,7 @@ func main() { log.Println(appConfig) ircChan := make(chan *girc.Client, 1) + dbConn := make(chan *pgxpool.Pool, 1) - runIRC(*appConfig, ircChan) + runIRC(*appConfig, ircChan, dbConn) } |