diff options
Diffstat (limited to '')
-rw-r--r-- | main.go | 298 |
1 files changed, 206 insertions, 92 deletions
@@ -6,6 +6,7 @@ import ( "crypto/tls" "encoding/json" "errors" + "expvar" "flag" "fmt" "index/suffixarray" @@ -13,6 +14,7 @@ import ( "math/rand" "net" "net/http" + _ "net/http/pprof" "net/url" "os" "os/signal" @@ -26,6 +28,7 @@ import ( "github.com/BurntSushi/toml" "github.com/alecthomas/chroma/v2/quick" + "github.com/cenkalti/backoff/v5" "github.com/google/generative-ai-go/genai" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" @@ -69,8 +72,8 @@ func addSaneDefaults(config *TomlConfig) { config.DatabaseName = "milladb" } - if config.Temp == 0 { - config.Temp = 0.5 + if config.Temperature == 0 { + config.Temperature = 0.5 } if config.RequestTimeout == 0 { @@ -101,9 +104,49 @@ func addSaneDefaults(config *TomlConfig) { config.PingTimeout = 20 } + if config.OllamaMirostatEta == 0 { + config.OllamaMirostatEta = 0.1 + } + + if config.OllamaMirostatTau == 0 { + config.OllamaMirostatTau = 5.0 + } + + if config.OllamaNumCtx == 0 { + config.OllamaNumCtx = 4096 + } + + if config.OllamaRepeatLastN == 0 { + config.OllamaRepeatLastN = 64 + } + + if config.OllamaRepeatPenalty == 0 { + config.OllamaRepeatPenalty = 1.1 + } + + if config.OllamaSeed == 0 { + config.OllamaSeed = 42 + } + + if config.OllamaNumPredict == 0 { + config.OllamaNumPredict = -1 + } + + if config.TopK == 0 { + config.TopK = 40 + } + if config.TopP == 0.0 { config.TopP = 0.9 } + + if config.OllamaMinP == 0 { + config.OllamaMinP = 0.05 + } + + if config.Temperature == 0 { + config.Temperature = 0.7 + } } func getTableFromChanName(channel, ircdName string) string { @@ -316,7 +359,7 @@ func handleCustomCommand( bigPrompt += log.Log + "\n" } - result := ChatGPTRequestProcessor(appConfig, client, event, &gptMemory, customCommand.Prompt) + result := ChatGPTRequestProcessor(appConfig, client, event, &gptMemory, customCommand.Prompt, customCommand.SystemPrompt) if result != "" { SendToIRC(client, event, result, appConfig.ChromaFormatter) } @@ -341,7 +384,7 @@ func handleCustomCommand( }) } - result := GeminiRequestProcessor(appConfig, client, event, &geminiMemory, customCommand.Prompt) + result := GeminiRequestProcessor(appConfig, client, event, &geminiMemory, customCommand.Prompt, customCommand.SystemPrompt) if result != "" { SendToIRC(client, event, result, appConfig.ChromaFormatter) } @@ -362,7 +405,7 @@ func handleCustomCommand( }) } - result := OllamaRequestProcessor(appConfig, client, event, &ollamaMemory, customCommand.Prompt) + result := OllamaRequestProcessor(appConfig, client, event, &ollamaMemory, customCommand.Prompt, customCommand.SystemPrompt) if result != "" { SendToIRC(client, event, result, appConfig.ChromaFormatter) } @@ -541,7 +584,7 @@ func runCommand( appConfig.deleteLstate(args[1]) case "remind": - if len(args) < 2 { + if len(args) < 2 { //nolint: mnd,gomnd client.Cmd.Reply(event, errNotEnoughArgs.Error()) break @@ -559,10 +602,9 @@ func runCommand( client.Cmd.ReplyTo(event, " Ping!") case "forget": - client.Cmd.Reply(event, "I no longer even know whether you're supposed to wear or drink a camel.'") case "whois": - if len(args) < 2 { + if len(args) < 2 { //nolint: mnd,gomnd client.Cmd.Reply(event, errNotEnoughArgs.Error()) break @@ -575,9 +617,8 @@ func runCommand( upperLimit := 6 if len(args) == 1 { - } else if len(args) == 2 { + } else if len(args) == 2 { //nolint: mnd,gomnd argOne, err := strconv.Atoi(args[1]) - if err != nil { client.Cmd.Reply(event, errNotEnoughArgs.Error()) @@ -585,9 +626,8 @@ func runCommand( } upperLimit = argOne - } else if len(args) == 3 { + } else if len(args) == 3 { //nolint: mnd,gomnd argOne, err := strconv.Atoi(args[1]) - if err != nil { client.Cmd.Reply(event, errNotEnoughArgs.Error()) @@ -597,7 +637,6 @@ func runCommand( lowerLimit = argOne argTwo, err := strconv.Atoi(args[2]) - if err != nil { client.Cmd.Reply(event, errNotEnoughArgs.Error()) @@ -636,7 +675,7 @@ func runCommand( func DoOllamaRequest( appConfig *TomlConfig, ollamaMemory *[]MemoryElement, - prompt string, + prompt, systemPrompt string, ) (string, error) { var jsonPayload []byte @@ -665,14 +704,25 @@ func DoOllamaRequest( KeepAlive: time.Duration(appConfig.KeepAlive), Stream: false, Messages: *ollamaMemory, + System: systemPrompt, Options: OllamaRequestOptions{ - Temperature: appConfig.Temp, + Mirostat: appConfig.OllamaMirostat, + MirostatEta: appConfig.OllamaMirostatEta, + MirostatTau: appConfig.OllamaMirostatTau, + NumCtx: appConfig.OllamaNumCtx, + RepeatLastN: appConfig.OllamaRepeatLastN, + RepeatPenalty: appConfig.OllamaRepeatPenalty, + Temperature: appConfig.Temperature, + Seed: appConfig.OllamaSeed, + NumPredict: appConfig.OllamaNumPredict, + TopK: appConfig.TopK, + TopP: appConfig.TopP, + MinP: appConfig.OllamaMinP, }, } jsonPayload, err = json.Marshal(ollamaRequest) if err != nil { - return "", err } @@ -683,7 +733,6 @@ func DoOllamaRequest( request, err := http.NewRequest(http.MethodPost, appConfig.Endpoint, bytes.NewBuffer(jsonPayload)) if err != nil { - return "", err } @@ -715,16 +764,14 @@ func DoOllamaRequest( }, } } - response, err := httpClient.Do(request) + response, err := httpClient.Do(request) if err != nil { return "", err } defer response.Body.Close() - log.Println("response body:", response.Body) - var ollamaChatResponse OllamaChatMessagesResponse err = json.NewDecoder(response.Body).Decode(&ollamaChatResponse) @@ -732,6 +779,8 @@ func DoOllamaRequest( return "", err } + log.Println("ollama chat response: ", ollamaChatResponse) + return ollamaChatResponse.Messages.Content, nil } @@ -740,9 +789,9 @@ func OllamaRequestProcessor( client *girc.Client, event girc.Event, ollamaMemory *[]MemoryElement, - prompt string, + prompt, systemPrompt string, ) string { - response, err := DoOllamaRequest(appConfig, ollamaMemory, prompt) + response, err := DoOllamaRequest(appConfig, ollamaMemory, prompt, systemPrompt) if err != nil { client.Cmd.ReplyTo(event, "error: "+err.Error()) @@ -807,7 +856,7 @@ func OllamaHandler( return } - result := OllamaRequestProcessor(appConfig, client, event, ollamaMemory, prompt) + result := OllamaRequestProcessor(appConfig, client, event, ollamaMemory, prompt, appConfig.SystemPrompt) if result != "" { SendToIRC(client, event, result, appConfig.ChromaFormatter) } @@ -822,6 +871,7 @@ func (t *ProxyRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) if err != nil { return nil, err } + transport.Proxy = http.ProxyURL(proxyURL) } @@ -841,7 +891,7 @@ func (t *ProxyRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) func DoGeminiRequest( appConfig *TomlConfig, geminiMemory *[]*genai.Content, - prompt string, + prompt, systemPrompt string, ) (string, error) { httpProxyClient := &http.Client{Transport: &ProxyRoundTripper{ APIKey: appConfig.Apikey, @@ -853,15 +903,37 @@ func DoGeminiRequest( clientGemini, err := genai.NewClient(ctx, option.WithHTTPClient(httpProxyClient)) if err != nil { - - return "", err + return "", fmt.Errorf("Could not create a genai client.", err) } defer clientGemini.Close() model := clientGemini.GenerativeModel(appConfig.Model) - model.SetTemperature(float32(appConfig.Temp)) + model.SetTemperature(float32(appConfig.Temperature)) model.SetTopK(appConfig.TopK) model.SetTopP(appConfig.TopP) + model.SystemInstruction = &genai.Content{ + Parts: []genai.Part{ + genai.Text(systemPrompt), + }, + } + model.SafetySettings = []*genai.SafetySetting{ + { + Category: genai.HarmCategoryDangerousContent, + Threshold: genai.HarmBlockNone, + }, + { + Category: genai.HarmCategoryHarassment, + Threshold: genai.HarmBlockNone, + }, + { + Category: genai.HarmCategoryHateSpeech, + Threshold: genai.HarmBlockNone, + }, + { + Category: genai.HarmCategorySexuallyExplicit, + Threshold: genai.HarmBlockNone, + }, + } cs := model.StartChat() @@ -869,8 +941,7 @@ func DoGeminiRequest( resp, err := cs.SendMessage(ctx, genai.Text(prompt)) if err != nil { - - return "", err + return "", fmt.Errorf("Gemini: Could not send message", err) } return returnGeminiResponse(resp), nil @@ -881,9 +952,9 @@ func GeminiRequestProcessor( client *girc.Client, event girc.Event, geminiMemory *[]*genai.Content, - prompt string, + prompt, systemPrompt string, ) string { - geminiResponse, err := DoGeminiRequest(appConfig, geminiMemory, prompt) + geminiResponse, err := DoGeminiRequest(appConfig, geminiMemory, prompt, systemPrompt) if err != nil { client.Cmd.ReplyTo(event, "error: "+err.Error()) @@ -969,7 +1040,7 @@ func GeminiHandler( return } - result := GeminiRequestProcessor(appConfig, client, event, geminiMemory, prompt) + result := GeminiRequestProcessor(appConfig, client, event, geminiMemory, prompt, appConfig.SystemPrompt) if result != "" { SendToIRC(client, event, result, appConfig.ChromaFormatter) @@ -980,7 +1051,7 @@ func GeminiHandler( func DoChatGPTRequest( appConfig *TomlConfig, gptMemory *[]openai.ChatCompletionMessage, - prompt string, + prompt, systemPrompt string, ) (string, error) { ctx, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second) defer cancel() @@ -1011,6 +1082,7 @@ func DoChatGPTRequest( config := openai.DefaultConfig(appConfig.Apikey) config.HTTPClient = &httpClient + if appConfig.Endpoint != "" { config.BaseURL = appConfig.Endpoint log.Print(config.BaseURL) @@ -1019,6 +1091,11 @@ func DoChatGPTRequest( gptClient := openai.NewClientWithConfig(config) *gptMemory = append(*gptMemory, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleSystem, + Content: systemPrompt, + }) + + *gptMemory = append(*gptMemory, openai.ChatCompletionMessage{ Role: openai.ChatMessageRoleUser, Content: prompt, }) @@ -1028,7 +1105,6 @@ func DoChatGPTRequest( Messages: *gptMemory, }) if err != nil { - return "", err } @@ -1040,9 +1116,9 @@ func ChatGPTRequestProcessor( client *girc.Client, event girc.Event, gptMemory *[]openai.ChatCompletionMessage, - prompt string, + prompt, systemPrompt string, ) string { - resp, err := DoChatGPTRequest(appConfig, gptMemory, prompt) + resp, err := DoChatGPTRequest(appConfig, gptMemory, prompt, systemPrompt) if err != nil { client.Cmd.ReplyTo(event, "error: "+err.Error()) @@ -1115,14 +1191,16 @@ func ChatGPTHandler( return } - result := ChatGPTRequestProcessor(appConfig, client, event, gptMemory, prompt) + result := ChatGPTRequestProcessor(appConfig, client, event, gptMemory, prompt, appConfig.SystemPrompt) if result != "" { SendToIRC(client, event, result, appConfig.ChromaFormatter) } }) } -func connectToDB(appConfig *TomlConfig, ctx *context.Context, poolChan chan *pgxpool.Pool) { +func connectToDB(appConfig *TomlConfig, ctx *context.Context, irc *girc.Client) { + var pool *pgxpool.Pool + dbURL := fmt.Sprintf( "postgres://%s:%s@%s/%s", appConfig.DatabaseUser, @@ -1134,40 +1212,54 @@ func connectToDB(appConfig *TomlConfig, ctx *context.Context, poolChan chan *pgx poolConfig, err := pgxpool.ParseConfig(dbURL) if err != nil { - LogErrorFatal(err) + LogError(err) + + return } - pool, err := pgxpool.NewWithConfig(*ctx, poolConfig) + dbConnect := func() (*pgxpool.Pool, error) { + return pgxpool.NewWithConfig(*ctx, poolConfig) + } + + pool, err = backoff.Retry(*ctx, dbConnect, backoff.WithBackOff(backoff.NewExponentialBackOff())) if err != nil { - LogErrorFatal(err) - } else { - log.Printf("%s connected to database", appConfig.IRCDName) - - for _, channel := range appConfig.ScrapeChannels { - tableName := getTableFromChanName(channel[0], appConfig.IRCDName) - query := fmt.Sprintf( - `create table if not exists %s ( - id serial primary key, - channel text not null, - log text not null, - nick text not null, - dateadded timestamp default current_timestamp - )`, tableName) - - _, err = pool.Exec(*ctx, query) - if err != nil { - LogErrorFatal(err) - } - } + LogError(err) + } + + log.Printf("%s connected to database", appConfig.IRCDName) - appConfig.pool = pool - poolChan <- pool + for _, channel := range appConfig.ScrapeChannels { + tableName := getTableFromChanName(channel[0], appConfig.IRCDName) + query := fmt.Sprintf( + `create table if not exists %s ( + id serial primary key, + channel text not null, + log text not null, + nick text not null, + dateadded timestamp default current_timestamp + )`, tableName) + + _, err := pool.Exec(*ctx, query) + if err != nil { + LogError(err) + + continue + } } + + appConfig.pool = pool } -func scrapeChannel(irc *girc.Client, poolChan chan *pgxpool.Pool, appConfig TomlConfig) { +func scrapeChannel(irc *girc.Client, appConfig *TomlConfig) { + log.Print("spawning scraper") + irc.Handlers.AddBg(girc.PRIVMSG, func(_ *girc.Client, event girc.Event) { - pool := <-poolChan + if appConfig.pool == nil { + log.Println("no db connection. cant write scrapes to db.") + + return + } + tableName := getTableFromChanName(event.Params[0], appConfig.IRCDName) query := fmt.Sprintf( "insert into %s (channel,log,nick) values ('%s','%s','%s')", @@ -1177,7 +1269,12 @@ func scrapeChannel(irc *girc.Client, poolChan chan *pgxpool.Pool, appConfig Toml event.Source.Name, ) - _, err := pool.Exec(context.Background(), query) + log.Println(query) + + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second) + defer cancel() + + _, err := appConfig.pool.Exec(ctx, query) if err != nil { LogError(err) } @@ -1202,7 +1299,6 @@ func populateWatchListWords(appConfig *TomlConfig) { appConfig.WatchLists[watchlistName] = watchlist } } - } func WatchListHandler(irc *girc.Client, appConfig TomlConfig) { @@ -1269,8 +1365,6 @@ func runIRC(appConfig TomlConfig) { var ORMemory []MemoryElement - poolChan := make(chan *pgxpool.Pool, 1) - irc := girc.New(girc.Config{ Server: appConfig.IrcServer, Port: appConfig.IrcPort, @@ -1332,7 +1426,7 @@ func runIRC(appConfig TomlConfig) { irc.Config.TLSConfig.Certificates = []tls.Certificate{cert} } - irc.Handlers.AddBg(girc.CONNECTED, func(c *girc.Client, _ girc.Event) { + irc.Handlers.AddBg(girc.CONNECTED, func(_ *girc.Client, _ girc.Event) { for _, channel := range appConfig.IrcChannels { IrcJoin(irc, channel) } @@ -1385,21 +1479,21 @@ func runIRC(appConfig TomlConfig) { context, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second) defer cancel() - go connectToDB(&appConfig, &context, poolChan) + go connectToDB(&appConfig, &context, irc) } if len(appConfig.ScrapeChannels) > 0 { - irc.Handlers.AddBg(girc.CONNECTED, func(c *girc.Client, _ girc.Event) { + irc.Handlers.AddBg(girc.CONNECTED, func(_ *girc.Client, _ girc.Event) { for _, channel := range appConfig.ScrapeChannels { IrcJoin(irc, channel) } }) - go scrapeChannel(irc, poolChan, appConfig) + go scrapeChannel(irc, &appConfig) } if len(appConfig.WatchLists) > 0 { - irc.Handlers.AddBg(girc.CONNECTED, func(client *girc.Client, _ girc.Event) { + irc.Handlers.AddBg(girc.CONNECTED, func(_ *girc.Client, _ girc.Event) { for _, watchlist := range appConfig.WatchLists { log.Print("joining ", watchlist.AlertChannel) IrcJoin(irc, watchlist.AlertChannel) @@ -1417,36 +1511,51 @@ func runIRC(appConfig TomlConfig) { if len(appConfig.Rss) > 0 { irc.Handlers.AddBg(girc.CONNECTED, func(client *girc.Client, _ girc.Event) { - go runRSS(&appConfig, irc) + for _, rss := range appConfig.Rss { + log.Print("RSS: joining ", rss.Channel) + IrcJoin(irc, rss.Channel) + } }) - } - for { - var dialer proxy.Dialer + go runRSS(&appConfig, irc) + } - if appConfig.IRCProxy != "" { - proxyURL, err := url.Parse(appConfig.IRCProxy) - if err != nil { - LogErrorFatal(err) - } + var dialer proxy.Dialer - dialer, err = proxy.FromURL(proxyURL, &net.Dialer{Timeout: time.Duration(appConfig.RequestTimeout) * time.Second}) - if err != nil { - LogErrorFatal(err) - } + if appConfig.IRCProxy != "" { + proxyURL, err := url.Parse(appConfig.IRCProxy) + if err != nil { + LogErrorFatal(err) } - if err := irc.DialerConnect(dialer); err != nil { - LogError(err) - log.Println("reconnecting in " + strconv.Itoa(appConfig.MillaReconnectDelay)) - time.Sleep(time.Duration(appConfig.MillaReconnectDelay) * time.Second) - } else { - return + dialer, err = proxy.FromURL(proxyURL, &net.Dialer{Timeout: time.Duration(appConfig.RequestTimeout) * time.Second}) + if err != nil { + LogErrorFatal(err) } } + + connectToIRC := func() (string, error) { + return "", irc.DialerConnect(dialer) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.MillaReconnectDelay)*time.Second) + defer cancel() + + _, err := backoff.Retry(ctx, connectToIRC, backoff.WithBackOff(backoff.NewExponentialBackOff())) + if err != nil { + LogError(err) + } else { + return + } +} + +func goroutines() interface{} { + return runtime.NumGoroutine() } func main() { + expvar.Publish("Goroutines", expvar.Func(goroutines)) + quitChannel := make(chan os.Signal, 1) signal.Notify(quitChannel, syscall.SIGINT, syscall.SIGTERM) @@ -1480,5 +1589,10 @@ func main() { go runIRC(v) } + go func() { + err := http.ListenAndServe(":6060", nil) + log.Println(err) + }() + <-quitChannel } |