diff options
author | terminaldweller <devi@terminaldweller.com> | 2024-05-21 01:52:20 +0000 |
---|---|---|
committer | terminaldweller <devi@terminaldweller.com> | 2024-05-21 01:52:20 +0000 |
commit | 7f42eece2cc4472ad3470120d00ca47581be3211 (patch) | |
tree | fbd1af02c499e2378e5360c7aaf585f28f76d668 /main.go | |
parent | updated readme (diff) | |
download | milla-7f42eece2cc4472ad3470120d00ca47581be3211.tar.gz milla-7f42eece2cc4472ad3470120d00ca47581be3211.zip |
code clean up
Diffstat (limited to '')
-rw-r--r-- | main.go | 546 |
1 files changed, 323 insertions, 223 deletions
@@ -41,6 +41,16 @@ var ( errUnsupportedType = errors.New("unsupported type") ) +type CustomCommand struct { + SQL string `json:"sql"` + Limit int `json:"limit"` + Prompt string `json:"prompt"` +} + +type CustomCommands struct { + CustomCommands map[string]CustomCommand `json:"customCommands"` +} + type TomlConfig struct { IrcServer string `toml:"ircServer"` IrcNick string `toml:"ircNick"` @@ -64,6 +74,7 @@ type TomlConfig struct { LLMProxy string `toml:"llmProxy"` IRCProxy string `toml:"ircProxy"` IRCDName string `toml:"ircdName"` + CommandsFile string `toml:"commandsFile"` Temp float64 `toml:"temp"` RequestTimeout int `toml:"requestTimeout"` MillaReconnectDelay int `toml:"millaReconnectDelay"` @@ -124,40 +135,44 @@ func addSaneDefaults(config *TomlConfig) { config.DatabaseName = "milladb" } + if config.CommandsFile == "" { + config.CommandsFile = "./commands.json" + } + if config.Temp == 0 { config.Temp = 0.5 //nollint:gomnd } if config.RequestTimeout == 0 { - config.RequestTimeout = 10 //nolint:gomnd + config.RequestTimeout = 10 } if config.MillaReconnectDelay == 0 { - config.MillaReconnectDelay = 30 //nolint:gomnd + config.MillaReconnectDelay = 30 } if config.IrcPort == 0 { - config.IrcPort = 6697 //nolint:gomnd + config.IrcPort = 6697 } if config.KeepAlive == 0 { - config.KeepAlive = 600 //nolint:gomnd + config.KeepAlive = 600 } if config.MemoryLimit == 0 { - config.MemoryLimit = 20 //nolint:gomnd + config.MemoryLimit = 20 } if config.PingDelay == 0 { - config.PingDelay = 20 //nolint:gomnd + config.PingDelay = 20 } if config.PingTimeout == 0 { - config.PingTimeout = 20 //nolint:gomnd + config.PingTimeout = 20 } - if config.TopP == 0. { - config.TopP = 0.9 //nolint:gomnd + if config.TopP == 0.0 { + config.TopP = 0.9 } } @@ -275,6 +290,10 @@ func sendToIRC( chunks := chunker(message, chromaFormatter) for _, chunk := range chunks { + if len(strings.TrimSpace(chunk)) == 0 { + continue + } + client.Cmd.Reply(event, chunk) } } @@ -333,7 +352,7 @@ func setFieldByName(v reflect.Value, field string, value string) error { func byteToMByte(bytes uint64, ) uint64 { - return bytes / 1024 / 1024 + return bytes / 1024 / 1024 //nolint:gomnd } func runCommand( @@ -404,6 +423,7 @@ func runCommand( } case "memstats": var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) client.Cmd.Reply(event, fmt.Sprintf("Alloc: %d MiB", byteToMByte(memStats.Alloc))) @@ -414,149 +434,147 @@ func runCommand( } } -func ollamaHandler( - irc *girc.Client, +func doOllamaRequest( appConfig *TomlConfig, + client *girc.Client, + event girc.Event, ollamaMemory *[]MemoryElement, -) { - irc.Handlers.AddBg(girc.PRIVMSG, func(client *girc.Client, event girc.Event) { - if !strings.HasPrefix(event.Last(), appConfig.IrcNick+": ") { - return - } - if appConfig.AdminOnly { - byAdmin := false - for _, admin := range appConfig.Admins { - if event.Source.Name == admin { - byAdmin = true - } - } - if !byAdmin { - return - } - } - prompt := strings.TrimPrefix(event.Last(), appConfig.IrcNick+": ") - log.Println(prompt) + prompt string, +) (*http.Response, error) { + var jsonPayload []byte - if string(prompt[0]) == "/" { - runCommand(client, event, appConfig) + var err error - return - } + memoryElement := MemoryElement{ + Role: "user", + Content: prompt, + } - var jsonPayload []byte - var err error + if len(*ollamaMemory) > appConfig.MemoryLimit { + *ollamaMemory = []MemoryElement{} + } - memoryElement := MemoryElement{ - Role: "user", - Content: prompt, - } + *ollamaMemory = append(*ollamaMemory, memoryElement) - if len(*ollamaMemory) > appConfig.MemoryLimit { - *ollamaMemory = []MemoryElement{} - } - *ollamaMemory = append(*ollamaMemory, memoryElement) + ollamaRequest := OllamaChatRequest{ + Model: appConfig.Model, + Keep_alive: time.Duration(appConfig.KeepAlive), + Stream: false, + Messages: *ollamaMemory, + Options: OllamaRequestOptions{ + Temperature: appConfig.Temp, + }, + } - ollamaRequest := OllamaChatRequest{ - Model: appConfig.Model, - Keep_alive: time.Duration(appConfig.KeepAlive), - Stream: false, - Messages: *ollamaMemory, - Options: OllamaRequestOptions{ - Temperature: appConfig.Temp, - }, - } - jsonPayload, err = json.Marshal(ollamaRequest) - log.Printf("json payload: %s", string(jsonPayload)) - if err != nil { - client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + jsonPayload, err = json.Marshal(ollamaRequest) + if err != nil { + client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) - return - } + return nil, err + } - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second) - defer cancel() + log.Printf("json payload: %s", string(jsonPayload)) + + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second) + defer cancel() + + request, err := http.NewRequest(http.MethodPost, appConfig.OllamaEndpoint, bytes.NewBuffer(jsonPayload)) + if err != nil { + client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) - request, err := http.NewRequest(http.MethodPost, appConfig.OllamaEndpoint, bytes.NewBuffer(jsonPayload)) - request = request.WithContext(ctx) + return nil, err + } + + request = request.WithContext(ctx) + request.Header.Set("Content-Type", "application/json") + + var httpClient http.Client + + var dialer proxy.Dialer + + if appConfig.LLMProxy != "" { + proxyURL, err := url.Parse(appConfig.IRCProxy) if err != nil { - client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + cancel() - return + log.Fatal(err.Error()) } - request.Header.Set("Content-Type", "application/json") + dialer, err = proxy.FromURL(proxyURL, &net.Dialer{Timeout: time.Duration(appConfig.RequestTimeout) * time.Second}) + if err != nil { + cancel() - var httpClient http.Client + log.Fatal(err.Error()) + } - var dialer proxy.Dialer + httpClient = http.Client{ + Transport: &http.Transport{ + Dial: dialer.Dial, + }, + } + } - if appConfig.LLMProxy != "" { - proxyURL, err := url.Parse(appConfig.IRCProxy) - if err != nil { - cancel() + return httpClient.Do(request) +} - log.Fatal(err.Error()) - } +func ollamaRequest( + appConfig *TomlConfig, + client *girc.Client, + event girc.Event, + ollamaMemory *[]MemoryElement, + prompt string, +) { + response, err := doOllamaRequest(appConfig, client, event, ollamaMemory, prompt) - dialer, err = proxy.FromURL(proxyURL, &net.Dialer{Timeout: time.Duration(appConfig.RequestTimeout) * time.Second}) - if err != nil { - cancel() + if response == nil { + return + } - log.Fatal(err.Error()) - } + if err != nil { + client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) - httpClient = http.Client{ - Transport: &http.Transport{ - Dial: dialer.Dial, - }, - } - } + return + } + defer response.Body.Close() - response, err := httpClient.Do(request) - if err != nil { - client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + log.Println("response body:", response.Body) - return - } - defer response.Body.Close() + var writer bytes.Buffer - log.Println("response body:", response.Body) + var ollamaChatResponse OllamaChatMessagesResponse - var writer bytes.Buffer + err = json.NewDecoder(response.Body).Decode(&ollamaChatResponse) + if err != nil { + client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + } - var ollamaChatResponse OllamaChatMessagesResponse - err = json.NewDecoder(response.Body).Decode(&ollamaChatResponse) - if err != nil { - client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) - } + assistantElement := MemoryElement{ + Role: "assistant", + Content: ollamaChatResponse.Messages.Content, + } - assistantElement := MemoryElement{ - Role: "assistant", - Content: ollamaChatResponse.Messages.Content, - } + *ollamaMemory = append(*ollamaMemory, assistantElement) - *ollamaMemory = append(*ollamaMemory, assistantElement) + log.Println(ollamaChatResponse) - log.Println(ollamaChatResponse) - err = quick.Highlight(&writer, - ollamaChatResponse.Messages.Content, - "markdown", - appConfig.ChromaFormatter, - appConfig.ChromaStyle) - if err != nil { - client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + err = quick.Highlight(&writer, + ollamaChatResponse.Messages.Content, + "markdown", + appConfig.ChromaFormatter, + appConfig.ChromaStyle) + if err != nil { + client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) - return - } + return + } - sendToIRC(client, event, writer.String(), appConfig.ChromaFormatter) - }) + sendToIRC(client, event, writer.String(), appConfig.ChromaFormatter) } -func geminiHandler( +func ollamaHandler( irc *girc.Client, appConfig *TomlConfig, - geminiMemory *[]*genai.Content, + ollamaMemory *[]MemoryElement, ) { irc.Handlers.AddBg(girc.PRIVMSG, func(client *girc.Client, event girc.Event) { if !strings.HasPrefix(event.Last(), appConfig.IrcNick+": ") { @@ -582,75 +600,96 @@ func geminiHandler( return } - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second) - defer cancel() - - clientGemini, err := genai.NewClient(ctx, option.WithAPIKey(appConfig.Apikey)) - if err != nil { - client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + ollamaRequest(appConfig, client, event, ollamaMemory, prompt) + }) +} - return - } - defer clientGemini.Close() +func doGeminiRequest( + appConfig *TomlConfig, + client *girc.Client, + event girc.Event, + geminiMemory *[]*genai.Content, + prompt string, +) string { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second) + defer cancel() - model := clientGemini.GenerativeModel(appConfig.Model) - model.SetTemperature(float32(appConfig.Temp)) - model.SetTopK(appConfig.TopK) - model.SetTopP(appConfig.TopP) + clientGemini, err := genai.NewClient(ctx, option.WithAPIKey(appConfig.Apikey)) + if err != nil { + client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) - cs := model.StartChat() + return "" + } + defer clientGemini.Close() - cs.History = *geminiMemory + model := clientGemini.GenerativeModel(appConfig.Model) + model.SetTemperature(float32(appConfig.Temp)) + model.SetTopK(appConfig.TopK) + model.SetTopP(appConfig.TopP) - resp, err := cs.SendMessage(ctx, genai.Text(prompt)) - if err != nil { - client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + cs := model.StartChat() - return - } + cs.History = *geminiMemory - geminiResponse := returnGeminiResponse(resp) - log.Println(geminiResponse) + resp, err := cs.SendMessage(ctx, genai.Text(prompt)) + if err != nil { + client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) - if len(*geminiMemory) > appConfig.MemoryLimit { - *geminiMemory = []*genai.Content{} - } + return "" + } - *geminiMemory = append(*geminiMemory, &genai.Content{ - Parts: []genai.Part{ - genai.Text(prompt), - }, - Role: "user", - }) + return returnGeminiResponse(resp) +} - *geminiMemory = append(*geminiMemory, &genai.Content{ - Parts: []genai.Part{ - genai.Text(geminiResponse), - }, - Role: "model", - }) +func geminiRequest( + appConfig *TomlConfig, + client *girc.Client, + event girc.Event, + geminiMemory *[]*genai.Content, + prompt string, +) { + geminiResponse := doGeminiRequest(appConfig, client, event, geminiMemory, prompt) + log.Println(geminiResponse) - var writer bytes.Buffer - err = quick.Highlight( - &writer, - geminiResponse, - "markdown", - appConfig.ChromaFormatter, - appConfig.ChromaStyle) - if err != nil { - client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + if len(*geminiMemory) > appConfig.MemoryLimit { + *geminiMemory = []*genai.Content{} + } - return - } + *geminiMemory = append(*geminiMemory, &genai.Content{ + Parts: []genai.Part{ + genai.Text(prompt), + }, + Role: "user", + }) - sendToIRC(client, event, writer.String(), appConfig.ChromaFormatter) + *geminiMemory = append(*geminiMemory, &genai.Content{ + Parts: []genai.Part{ + genai.Text(geminiResponse), + }, + Role: "model", }) + + var writer bytes.Buffer + + err := quick.Highlight( + &writer, + geminiResponse, + "markdown", + appConfig.ChromaFormatter, + appConfig.ChromaStyle) + if err != nil { + client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + + return + } + + sendToIRC(client, event, writer.String(), appConfig.ChromaFormatter) } -func chatGPTHandler( +func geminiHandler( irc *girc.Client, appConfig *TomlConfig, - gptMemory *[]openai.ChatCompletionMessage, + geminiMemory *[]*genai.Content, ) { irc.Handlers.AddBg(girc.PRIVMSG, func(client *girc.Client, event girc.Event) { if !strings.HasPrefix(event.Last(), appConfig.IrcNick+": ") { @@ -676,80 +715,134 @@ func chatGPTHandler( return } - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second) - defer cancel() + geminiRequest(appConfig, client, event, geminiMemory, prompt) + }) +} - var httpClient http.Client +func doChatGPTRequest( + appConfig *TomlConfig, + client *girc.Client, + event girc.Event, + gptMemory *[]openai.ChatCompletionMessage, + prompt string, +) (openai.ChatCompletionResponse, error) { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second) + defer cancel() - if appConfig.LLMProxy != "" { - proxyURL, err := url.Parse(appConfig.IRCProxy) - if err != nil { - cancel() + var httpClient http.Client - log.Fatal(err.Error()) - } + if appConfig.LLMProxy != "" { + proxyURL, err := url.Parse(appConfig.IRCProxy) + if err != nil { + cancel() + client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) - dialer, err := proxy.FromURL(proxyURL, &net.Dialer{Timeout: time.Duration(appConfig.RequestTimeout) * time.Second}) - if err != nil { - cancel() + log.Fatal(err.Error()) + } - log.Fatal(err.Error()) - } + dialer, err := proxy.FromURL(proxyURL, &net.Dialer{Timeout: time.Duration(appConfig.RequestTimeout) * time.Second}) + if err != nil { + cancel() + client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) - httpClient = http.Client{ - Transport: &http.Transport{ - Dial: dialer.Dial, - }, - } + log.Fatal(err.Error()) } - if appConfig.Apikey == "" { - appConfig.Apikey = os.Getenv("MILLA_APIKEY") + httpClient = http.Client{ + Transport: &http.Transport{ + Dial: dialer.Dial, + }, } + } - config := openai.DefaultConfig(appConfig.Apikey) - config.HTTPClient = &httpClient + config := openai.DefaultConfig(appConfig.Apikey) + config.HTTPClient = &httpClient - gptClient := openai.NewClientWithConfig(config) + gptClient := openai.NewClientWithConfig(config) - *gptMemory = append(*gptMemory, openai.ChatCompletionMessage{ - Role: openai.ChatMessageRoleUser, - Content: prompt, - }) + *gptMemory = append(*gptMemory, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleUser, + Content: prompt, + }) - resp, err := gptClient.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ - Model: appConfig.Model, - Messages: *gptMemory, - }) - if err != nil { - client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + resp, err := gptClient.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ + Model: appConfig.Model, + Messages: *gptMemory, + }) - return - } + return resp, err +} - *gptMemory = append(*gptMemory, openai.ChatCompletionMessage{ - Role: openai.ChatMessageRoleAssistant, - Content: resp.Choices[0].Message.Content, - }) +func chatGPTRequest( + appConfig *TomlConfig, + client *girc.Client, + event girc.Event, + gptMemory *[]openai.ChatCompletionMessage, + prompt string, +) { + resp, err := doChatGPTRequest(appConfig, client, event, gptMemory, prompt) + if err != nil { + client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) - if len(*gptMemory) > appConfig.MemoryLimit { - *gptMemory = []openai.ChatCompletionMessage{} + return + } + + *gptMemory = append(*gptMemory, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, + Content: resp.Choices[0].Message.Content, + }) + + if len(*gptMemory) > appConfig.MemoryLimit { + *gptMemory = []openai.ChatCompletionMessage{} + } + + var writer bytes.Buffer + + err = quick.Highlight( + &writer, + resp.Choices[0].Message.Content, + "markdown", + appConfig.ChromaFormatter, + appConfig.ChromaStyle) + if err != nil { + client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + + return + } + + sendToIRC(client, event, writer.String(), appConfig.ChromaFormatter) +} + +func chatGPTHandler( + irc *girc.Client, + appConfig *TomlConfig, + gptMemory *[]openai.ChatCompletionMessage, +) { + irc.Handlers.AddBg(girc.PRIVMSG, func(client *girc.Client, event girc.Event) { + if !strings.HasPrefix(event.Last(), appConfig.IrcNick+": ") { + return } + if appConfig.AdminOnly { + byAdmin := false + for _, admin := range appConfig.Admins { + if event.Source.Name == admin { + byAdmin = true + } + } + if !byAdmin { + return + } + } + prompt := strings.TrimPrefix(event.Last(), appConfig.IrcNick+": ") + log.Println(prompt) - var writer bytes.Buffer - err = quick.Highlight( - &writer, - resp.Choices[0].Message.Content, - "markdown", - appConfig.ChromaFormatter, - appConfig.ChromaStyle) - if err != nil { - client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + if string(prompt[0]) == "/" { + runCommand(client, event, appConfig) return } - sendToIRC(client, event, writer.String(), appConfig.ChromaFormatter) + chatGPTRequest(appConfig, client, event, gptMemory, prompt) }) } @@ -794,7 +887,14 @@ func connectToDB(appConfig TomlConfig, ctx *context.Context, poolChan chan *pgxp for _, channel := range appConfig.ScrapeChannels { tableName := getTableFromChanName(channel, appConfig.IRCDName) - query := fmt.Sprintf("create table if not exists %s (id serial primary key,channel text not null,log text not null,nick text not null,dateadded timestamp default current_timestamp)", tableName) + query := fmt.Sprintf( + `create table if not exists %s ( + id serial primary key, + channel text not null, + log text not null, + nick text not null, + dateadded timestamp default current_timestamp + )`, tableName) _, err = pool.Exec(*ctx, query) if err != nil { log.Println(err.Error()) @@ -814,7 +914,7 @@ func scrapeChannel(irc *girc.Client, poolChan chan *pgxpool.Pool, appConfig Toml query := fmt.Sprintf( "insert into %s (channel,log,nick) values ('%s','%s','%s')", tableName, - event.Params[0], + sanitizeLog(event.Params[0]), event.Last(), event.Source.Name, ) @@ -848,7 +948,7 @@ func runIRC(appConfig TomlConfig) { DisableSTSFallback: appConfig.DisableSTSFallback, GlobalFormat: true, TLSConfig: &tls.Config{ - InsecureSkipVerify: appConfig.SkipTLSVerify, // #nosec G402 + InsecureSkipVerify: appConfig.SkipTLSVerify, ServerName: appConfig.IrcServer, }, }) |