aboutsummaryrefslogtreecommitdiffstats
path: root/main.go
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--main.go298
1 files changed, 206 insertions, 92 deletions
diff --git a/main.go b/main.go
index 082586f..265c8c9 100644
--- a/main.go
+++ b/main.go
@@ -6,6 +6,7 @@ import (
"crypto/tls"
"encoding/json"
"errors"
+ "expvar"
"flag"
"fmt"
"index/suffixarray"
@@ -13,6 +14,7 @@ import (
"math/rand"
"net"
"net/http"
+ _ "net/http/pprof"
"net/url"
"os"
"os/signal"
@@ -26,6 +28,7 @@ import (
"github.com/BurntSushi/toml"
"github.com/alecthomas/chroma/v2/quick"
+ "github.com/cenkalti/backoff/v5"
"github.com/google/generative-ai-go/genai"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
@@ -69,8 +72,8 @@ func addSaneDefaults(config *TomlConfig) {
config.DatabaseName = "milladb"
}
- if config.Temp == 0 {
- config.Temp = 0.5
+ if config.Temperature == 0 {
+ config.Temperature = 0.5
}
if config.RequestTimeout == 0 {
@@ -101,9 +104,49 @@ func addSaneDefaults(config *TomlConfig) {
config.PingTimeout = 20
}
+ if config.OllamaMirostatEta == 0 {
+ config.OllamaMirostatEta = 0.1
+ }
+
+ if config.OllamaMirostatTau == 0 {
+ config.OllamaMirostatTau = 5.0
+ }
+
+ if config.OllamaNumCtx == 0 {
+ config.OllamaNumCtx = 4096
+ }
+
+ if config.OllamaRepeatLastN == 0 {
+ config.OllamaRepeatLastN = 64
+ }
+
+ if config.OllamaRepeatPenalty == 0 {
+ config.OllamaRepeatPenalty = 1.1
+ }
+
+ if config.OllamaSeed == 0 {
+ config.OllamaSeed = 42
+ }
+
+ if config.OllamaNumPredict == 0 {
+ config.OllamaNumPredict = -1
+ }
+
+ if config.TopK == 0 {
+ config.TopK = 40
+ }
+
if config.TopP == 0.0 {
config.TopP = 0.9
}
+
+ if config.OllamaMinP == 0 {
+ config.OllamaMinP = 0.05
+ }
+
+ if config.Temperature == 0 {
+ config.Temperature = 0.7
+ }
}
func getTableFromChanName(channel, ircdName string) string {
@@ -316,7 +359,7 @@ func handleCustomCommand(
bigPrompt += log.Log + "\n"
}
- result := ChatGPTRequestProcessor(appConfig, client, event, &gptMemory, customCommand.Prompt)
+ result := ChatGPTRequestProcessor(appConfig, client, event, &gptMemory, customCommand.Prompt, customCommand.SystemPrompt)
if result != "" {
SendToIRC(client, event, result, appConfig.ChromaFormatter)
}
@@ -341,7 +384,7 @@ func handleCustomCommand(
})
}
- result := GeminiRequestProcessor(appConfig, client, event, &geminiMemory, customCommand.Prompt)
+ result := GeminiRequestProcessor(appConfig, client, event, &geminiMemory, customCommand.Prompt, customCommand.SystemPrompt)
if result != "" {
SendToIRC(client, event, result, appConfig.ChromaFormatter)
}
@@ -362,7 +405,7 @@ func handleCustomCommand(
})
}
- result := OllamaRequestProcessor(appConfig, client, event, &ollamaMemory, customCommand.Prompt)
+ result := OllamaRequestProcessor(appConfig, client, event, &ollamaMemory, customCommand.Prompt, customCommand.SystemPrompt)
if result != "" {
SendToIRC(client, event, result, appConfig.ChromaFormatter)
}
@@ -541,7 +584,7 @@ func runCommand(
appConfig.deleteLstate(args[1])
case "remind":
- if len(args) < 2 {
+ if len(args) < 2 { //nolint: mnd,gomnd
client.Cmd.Reply(event, errNotEnoughArgs.Error())
break
@@ -559,10 +602,9 @@ func runCommand(
client.Cmd.ReplyTo(event, " Ping!")
case "forget":
-
client.Cmd.Reply(event, "I no longer even know whether you're supposed to wear or drink a camel.'")
case "whois":
- if len(args) < 2 {
+ if len(args) < 2 { //nolint: mnd,gomnd
client.Cmd.Reply(event, errNotEnoughArgs.Error())
break
@@ -575,9 +617,8 @@ func runCommand(
upperLimit := 6
if len(args) == 1 {
- } else if len(args) == 2 {
+ } else if len(args) == 2 { //nolint: mnd,gomnd
argOne, err := strconv.Atoi(args[1])
-
if err != nil {
client.Cmd.Reply(event, errNotEnoughArgs.Error())
@@ -585,9 +626,8 @@ func runCommand(
}
upperLimit = argOne
- } else if len(args) == 3 {
+ } else if len(args) == 3 { //nolint: mnd,gomnd
argOne, err := strconv.Atoi(args[1])
-
if err != nil {
client.Cmd.Reply(event, errNotEnoughArgs.Error())
@@ -597,7 +637,6 @@ func runCommand(
lowerLimit = argOne
argTwo, err := strconv.Atoi(args[2])
-
if err != nil {
client.Cmd.Reply(event, errNotEnoughArgs.Error())
@@ -636,7 +675,7 @@ func runCommand(
func DoOllamaRequest(
appConfig *TomlConfig,
ollamaMemory *[]MemoryElement,
- prompt string,
+ prompt, systemPrompt string,
) (string, error) {
var jsonPayload []byte
@@ -665,14 +704,25 @@ func DoOllamaRequest(
KeepAlive: time.Duration(appConfig.KeepAlive),
Stream: false,
Messages: *ollamaMemory,
+ System: systemPrompt,
Options: OllamaRequestOptions{
- Temperature: appConfig.Temp,
+ Mirostat: appConfig.OllamaMirostat,
+ MirostatEta: appConfig.OllamaMirostatEta,
+ MirostatTau: appConfig.OllamaMirostatTau,
+ NumCtx: appConfig.OllamaNumCtx,
+ RepeatLastN: appConfig.OllamaRepeatLastN,
+ RepeatPenalty: appConfig.OllamaRepeatPenalty,
+ Temperature: appConfig.Temperature,
+ Seed: appConfig.OllamaSeed,
+ NumPredict: appConfig.OllamaNumPredict,
+ TopK: appConfig.TopK,
+ TopP: appConfig.TopP,
+ MinP: appConfig.OllamaMinP,
},
}
jsonPayload, err = json.Marshal(ollamaRequest)
if err != nil {
-
return "", err
}
@@ -683,7 +733,6 @@ func DoOllamaRequest(
request, err := http.NewRequest(http.MethodPost, appConfig.Endpoint, bytes.NewBuffer(jsonPayload))
if err != nil {
-
return "", err
}
@@ -715,16 +764,14 @@ func DoOllamaRequest(
},
}
}
- response, err := httpClient.Do(request)
+ response, err := httpClient.Do(request)
if err != nil {
return "", err
}
defer response.Body.Close()
- log.Println("response body:", response.Body)
-
var ollamaChatResponse OllamaChatMessagesResponse
err = json.NewDecoder(response.Body).Decode(&ollamaChatResponse)
@@ -732,6 +779,8 @@ func DoOllamaRequest(
return "", err
}
+ log.Println("ollama chat response: ", ollamaChatResponse)
+
return ollamaChatResponse.Messages.Content, nil
}
@@ -740,9 +789,9 @@ func OllamaRequestProcessor(
client *girc.Client,
event girc.Event,
ollamaMemory *[]MemoryElement,
- prompt string,
+ prompt, systemPrompt string,
) string {
- response, err := DoOllamaRequest(appConfig, ollamaMemory, prompt)
+ response, err := DoOllamaRequest(appConfig, ollamaMemory, prompt, systemPrompt)
if err != nil {
client.Cmd.ReplyTo(event, "error: "+err.Error())
@@ -807,7 +856,7 @@ func OllamaHandler(
return
}
- result := OllamaRequestProcessor(appConfig, client, event, ollamaMemory, prompt)
+ result := OllamaRequestProcessor(appConfig, client, event, ollamaMemory, prompt, appConfig.SystemPrompt)
if result != "" {
SendToIRC(client, event, result, appConfig.ChromaFormatter)
}
@@ -822,6 +871,7 @@ func (t *ProxyRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
if err != nil {
return nil, err
}
+
transport.Proxy = http.ProxyURL(proxyURL)
}
@@ -841,7 +891,7 @@ func (t *ProxyRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
func DoGeminiRequest(
appConfig *TomlConfig,
geminiMemory *[]*genai.Content,
- prompt string,
+ prompt, systemPrompt string,
) (string, error) {
httpProxyClient := &http.Client{Transport: &ProxyRoundTripper{
APIKey: appConfig.Apikey,
@@ -853,15 +903,37 @@ func DoGeminiRequest(
clientGemini, err := genai.NewClient(ctx, option.WithHTTPClient(httpProxyClient))
if err != nil {
-
- return "", err
+ return "", fmt.Errorf("Could not create a genai client.", err)
}
defer clientGemini.Close()
model := clientGemini.GenerativeModel(appConfig.Model)
- model.SetTemperature(float32(appConfig.Temp))
+ model.SetTemperature(float32(appConfig.Temperature))
model.SetTopK(appConfig.TopK)
model.SetTopP(appConfig.TopP)
+ model.SystemInstruction = &genai.Content{
+ Parts: []genai.Part{
+ genai.Text(systemPrompt),
+ },
+ }
+ model.SafetySettings = []*genai.SafetySetting{
+ {
+ Category: genai.HarmCategoryDangerousContent,
+ Threshold: genai.HarmBlockNone,
+ },
+ {
+ Category: genai.HarmCategoryHarassment,
+ Threshold: genai.HarmBlockNone,
+ },
+ {
+ Category: genai.HarmCategoryHateSpeech,
+ Threshold: genai.HarmBlockNone,
+ },
+ {
+ Category: genai.HarmCategorySexuallyExplicit,
+ Threshold: genai.HarmBlockNone,
+ },
+ }
cs := model.StartChat()
@@ -869,8 +941,7 @@ func DoGeminiRequest(
resp, err := cs.SendMessage(ctx, genai.Text(prompt))
if err != nil {
-
- return "", err
+ return "", fmt.Errorf("Gemini: Could not send message", err)
}
return returnGeminiResponse(resp), nil
@@ -881,9 +952,9 @@ func GeminiRequestProcessor(
client *girc.Client,
event girc.Event,
geminiMemory *[]*genai.Content,
- prompt string,
+ prompt, systemPrompt string,
) string {
- geminiResponse, err := DoGeminiRequest(appConfig, geminiMemory, prompt)
+ geminiResponse, err := DoGeminiRequest(appConfig, geminiMemory, prompt, systemPrompt)
if err != nil {
client.Cmd.ReplyTo(event, "error: "+err.Error())
@@ -969,7 +1040,7 @@ func GeminiHandler(
return
}
- result := GeminiRequestProcessor(appConfig, client, event, geminiMemory, prompt)
+ result := GeminiRequestProcessor(appConfig, client, event, geminiMemory, prompt, appConfig.SystemPrompt)
if result != "" {
SendToIRC(client, event, result, appConfig.ChromaFormatter)
@@ -980,7 +1051,7 @@ func GeminiHandler(
func DoChatGPTRequest(
appConfig *TomlConfig,
gptMemory *[]openai.ChatCompletionMessage,
- prompt string,
+ prompt, systemPrompt string,
) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second)
defer cancel()
@@ -1011,6 +1082,7 @@ func DoChatGPTRequest(
config := openai.DefaultConfig(appConfig.Apikey)
config.HTTPClient = &httpClient
+
if appConfig.Endpoint != "" {
config.BaseURL = appConfig.Endpoint
log.Print(config.BaseURL)
@@ -1019,6 +1091,11 @@ func DoChatGPTRequest(
gptClient := openai.NewClientWithConfig(config)
*gptMemory = append(*gptMemory, openai.ChatCompletionMessage{
+ Role: openai.ChatMessageRoleSystem,
+ Content: systemPrompt,
+ })
+
+ *gptMemory = append(*gptMemory, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: prompt,
})
@@ -1028,7 +1105,6 @@ func DoChatGPTRequest(
Messages: *gptMemory,
})
if err != nil {
-
return "", err
}
@@ -1040,9 +1116,9 @@ func ChatGPTRequestProcessor(
client *girc.Client,
event girc.Event,
gptMemory *[]openai.ChatCompletionMessage,
- prompt string,
+ prompt, systemPrompt string,
) string {
- resp, err := DoChatGPTRequest(appConfig, gptMemory, prompt)
+ resp, err := DoChatGPTRequest(appConfig, gptMemory, prompt, systemPrompt)
if err != nil {
client.Cmd.ReplyTo(event, "error: "+err.Error())
@@ -1115,14 +1191,16 @@ func ChatGPTHandler(
return
}
- result := ChatGPTRequestProcessor(appConfig, client, event, gptMemory, prompt)
+ result := ChatGPTRequestProcessor(appConfig, client, event, gptMemory, prompt, appConfig.SystemPrompt)
if result != "" {
SendToIRC(client, event, result, appConfig.ChromaFormatter)
}
})
}
-func connectToDB(appConfig *TomlConfig, ctx *context.Context, poolChan chan *pgxpool.Pool) {
+func connectToDB(appConfig *TomlConfig, ctx *context.Context, irc *girc.Client) {
+ var pool *pgxpool.Pool
+
dbURL := fmt.Sprintf(
"postgres://%s:%s@%s/%s",
appConfig.DatabaseUser,
@@ -1134,40 +1212,54 @@ func connectToDB(appConfig *TomlConfig, ctx *context.Context, poolChan chan *pgx
poolConfig, err := pgxpool.ParseConfig(dbURL)
if err != nil {
- LogErrorFatal(err)
+ LogError(err)
+
+ return
}
- pool, err := pgxpool.NewWithConfig(*ctx, poolConfig)
+ dbConnect := func() (*pgxpool.Pool, error) {
+ return pgxpool.NewWithConfig(*ctx, poolConfig)
+ }
+
+ pool, err = backoff.Retry(*ctx, dbConnect, backoff.WithBackOff(backoff.NewExponentialBackOff()))
if err != nil {
- LogErrorFatal(err)
- } else {
- log.Printf("%s connected to database", appConfig.IRCDName)
-
- for _, channel := range appConfig.ScrapeChannels {
- tableName := getTableFromChanName(channel[0], appConfig.IRCDName)
- query := fmt.Sprintf(
- `create table if not exists %s (
- id serial primary key,
- channel text not null,
- log text not null,
- nick text not null,
- dateadded timestamp default current_timestamp
- )`, tableName)
-
- _, err = pool.Exec(*ctx, query)
- if err != nil {
- LogErrorFatal(err)
- }
- }
+ LogError(err)
+ }
+
+ log.Printf("%s connected to database", appConfig.IRCDName)
- appConfig.pool = pool
- poolChan <- pool
+ for _, channel := range appConfig.ScrapeChannels {
+ tableName := getTableFromChanName(channel[0], appConfig.IRCDName)
+ query := fmt.Sprintf(
+ `create table if not exists %s (
+ id serial primary key,
+ channel text not null,
+ log text not null,
+ nick text not null,
+ dateadded timestamp default current_timestamp
+ )`, tableName)
+
+ _, err := pool.Exec(*ctx, query)
+ if err != nil {
+ LogError(err)
+
+ continue
+ }
}
+
+ appConfig.pool = pool
}
-func scrapeChannel(irc *girc.Client, poolChan chan *pgxpool.Pool, appConfig TomlConfig) {
+func scrapeChannel(irc *girc.Client, appConfig *TomlConfig) {
+ log.Print("spawning scraper")
+
irc.Handlers.AddBg(girc.PRIVMSG, func(_ *girc.Client, event girc.Event) {
- pool := <-poolChan
+ if appConfig.pool == nil {
+ log.Println("no db connection. cant write scrapes to db.")
+
+ return
+ }
+
tableName := getTableFromChanName(event.Params[0], appConfig.IRCDName)
query := fmt.Sprintf(
"insert into %s (channel,log,nick) values ('%s','%s','%s')",
@@ -1177,7 +1269,12 @@ func scrapeChannel(irc *girc.Client, poolChan chan *pgxpool.Pool, appConfig Toml
event.Source.Name,
)
- _, err := pool.Exec(context.Background(), query)
+ log.Println(query)
+
+ ctx, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second)
+ defer cancel()
+
+ _, err := appConfig.pool.Exec(ctx, query)
if err != nil {
LogError(err)
}
@@ -1202,7 +1299,6 @@ func populateWatchListWords(appConfig *TomlConfig) {
appConfig.WatchLists[watchlistName] = watchlist
}
}
-
}
func WatchListHandler(irc *girc.Client, appConfig TomlConfig) {
@@ -1269,8 +1365,6 @@ func runIRC(appConfig TomlConfig) {
var ORMemory []MemoryElement
- poolChan := make(chan *pgxpool.Pool, 1)
-
irc := girc.New(girc.Config{
Server: appConfig.IrcServer,
Port: appConfig.IrcPort,
@@ -1332,7 +1426,7 @@ func runIRC(appConfig TomlConfig) {
irc.Config.TLSConfig.Certificates = []tls.Certificate{cert}
}
- irc.Handlers.AddBg(girc.CONNECTED, func(c *girc.Client, _ girc.Event) {
+ irc.Handlers.AddBg(girc.CONNECTED, func(_ *girc.Client, _ girc.Event) {
for _, channel := range appConfig.IrcChannels {
IrcJoin(irc, channel)
}
@@ -1385,21 +1479,21 @@ func runIRC(appConfig TomlConfig) {
context, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second)
defer cancel()
- go connectToDB(&appConfig, &context, poolChan)
+ go connectToDB(&appConfig, &context, irc)
}
if len(appConfig.ScrapeChannels) > 0 {
- irc.Handlers.AddBg(girc.CONNECTED, func(c *girc.Client, _ girc.Event) {
+ irc.Handlers.AddBg(girc.CONNECTED, func(_ *girc.Client, _ girc.Event) {
for _, channel := range appConfig.ScrapeChannels {
IrcJoin(irc, channel)
}
})
- go scrapeChannel(irc, poolChan, appConfig)
+ go scrapeChannel(irc, &appConfig)
}
if len(appConfig.WatchLists) > 0 {
- irc.Handlers.AddBg(girc.CONNECTED, func(client *girc.Client, _ girc.Event) {
+ irc.Handlers.AddBg(girc.CONNECTED, func(_ *girc.Client, _ girc.Event) {
for _, watchlist := range appConfig.WatchLists {
log.Print("joining ", watchlist.AlertChannel)
IrcJoin(irc, watchlist.AlertChannel)
@@ -1417,36 +1511,51 @@ func runIRC(appConfig TomlConfig) {
if len(appConfig.Rss) > 0 {
irc.Handlers.AddBg(girc.CONNECTED, func(client *girc.Client, _ girc.Event) {
- go runRSS(&appConfig, irc)
+ for _, rss := range appConfig.Rss {
+ log.Print("RSS: joining ", rss.Channel)
+ IrcJoin(irc, rss.Channel)
+ }
})
- }
- for {
- var dialer proxy.Dialer
+ go runRSS(&appConfig, irc)
+ }
- if appConfig.IRCProxy != "" {
- proxyURL, err := url.Parse(appConfig.IRCProxy)
- if err != nil {
- LogErrorFatal(err)
- }
+ var dialer proxy.Dialer
- dialer, err = proxy.FromURL(proxyURL, &net.Dialer{Timeout: time.Duration(appConfig.RequestTimeout) * time.Second})
- if err != nil {
- LogErrorFatal(err)
- }
+ if appConfig.IRCProxy != "" {
+ proxyURL, err := url.Parse(appConfig.IRCProxy)
+ if err != nil {
+ LogErrorFatal(err)
}
- if err := irc.DialerConnect(dialer); err != nil {
- LogError(err)
- log.Println("reconnecting in " + strconv.Itoa(appConfig.MillaReconnectDelay))
- time.Sleep(time.Duration(appConfig.MillaReconnectDelay) * time.Second)
- } else {
- return
+ dialer, err = proxy.FromURL(proxyURL, &net.Dialer{Timeout: time.Duration(appConfig.RequestTimeout) * time.Second})
+ if err != nil {
+ LogErrorFatal(err)
}
}
+
+ connectToIRC := func() (string, error) {
+ return "", irc.DialerConnect(dialer)
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.MillaReconnectDelay)*time.Second)
+ defer cancel()
+
+ _, err := backoff.Retry(ctx, connectToIRC, backoff.WithBackOff(backoff.NewExponentialBackOff()))
+ if err != nil {
+ LogError(err)
+ } else {
+ return
+ }
+}
+
+func goroutines() interface{} {
+ return runtime.NumGoroutine()
}
func main() {
+ expvar.Publish("Goroutines", expvar.Func(goroutines))
+
quitChannel := make(chan os.Signal, 1)
signal.Notify(quitChannel, syscall.SIGINT, syscall.SIGTERM)
@@ -1480,5 +1589,10 @@ func main() {
go runIRC(v)
}
+ go func() {
+ err := http.ListenAndServe(":6060", nil)
+ log.Println(err)
+ }()
+
<-quitChannel
}