aboutsummaryrefslogtreecommitdiffstats
path: root/main.go
diff options
context:
space:
mode:
Diffstat (limited to 'main.go')
-rw-r--r--main.go289
1 files changed, 195 insertions, 94 deletions
diff --git a/main.go b/main.go
index 4b896ce..f8f0739 100644
--- a/main.go
+++ b/main.go
@@ -25,6 +25,7 @@ import (
"github.com/BurntSushi/toml"
"github.com/alecthomas/chroma/v2/quick"
"github.com/google/generative-ai-go/genai"
+ "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/lrstanley/girc"
openai "github.com/sashabaranov/go-openai"
@@ -41,58 +42,67 @@ var (
errUnsupportedType = errors.New("unsupported type")
)
-type CustomCommand struct {
- SQL string `json:"sql"`
- Limit int `json:"limit"`
- Prompt string `json:"prompt"`
+type LogModel struct {
+ // Id int64 `db:"id"`
+ // Channel string `db:"channel"`
+ Log string `db:"log"`
+ // Nick string `db:"nick"`
+ // DateAdded pgtype.Date `db:"dateadded"`
}
-type CustomCommands struct {
- CustomCommands map[string]CustomCommand `json:"customCommands"`
+type CustomCommand struct {
+ SQL string `toml:"sql"`
+ Limit int `toml:"limit"`
+ Prompt string `toml:"prompt"`
}
type TomlConfig struct {
- IrcServer string `toml:"ircServer"`
- IrcNick string `toml:"ircNick"`
- IrcSaslUser string `toml:"ircSaslUser"`
- IrcSaslPass string `toml:"ircSaslPass"`
- OllamaEndpoint string `toml:"ollamaEndpoint"`
- Model string `toml:"model"`
- ChromaStyle string `toml:"chromaStyle"`
- ChromaFormatter string `toml:"chromaFormatter"`
- Provider string `toml:"provider"`
- Apikey string `toml:"apikey"`
- OllamaSystem string `toml:"ollamaSystem"`
- ClientCertPath string `toml:"clientCertPath"`
- ServerPass string `toml:"serverPass"`
- Bind string `toml:"bind"`
- Name string `toml:"name"`
- DatabaseAddress string `toml:"databaseAddress"`
- DatabasePassword string `toml:"databasePassword"`
- DatabaseUser string `toml:"databaseUser"`
- DatabaseName string `toml:"databaseName"`
- LLMProxy string `toml:"llmProxy"`
- IRCProxy string `toml:"ircProxy"`
- IRCDName string `toml:"ircdName"`
- CommandsFile string `toml:"commandsFile"`
- Temp float64 `toml:"temp"`
- RequestTimeout int `toml:"requestTimeout"`
- MillaReconnectDelay int `toml:"millaReconnectDelay"`
- IrcPort int `toml:"ircPort"`
- KeepAlive int `toml:"keepAlive"`
- MemoryLimit int `toml:"memoryLimit"`
- PingDelay int `toml:"pingDelay"`
- PingTimeout int `toml:"pingTimeout"`
- TopP float32 `toml:"topP"`
- TopK int32 `toml:"topK"`
- EnableSasl bool `toml:"enableSasl"`
- SkipTLSVerify bool `toml:"skipTLSVerify"`
- UseTLS bool `toml:"useTLS"`
- DisableSTSFallback bool `toml:"disableSTSFallback"`
- AllowFlood bool `toml:"allowFlood"`
- Debug bool `toml:"debug"`
- Out bool `toml:"out"`
- AdminOnly bool `toml:"adminOnly"`
+ IrcServer string `toml:"ircServer"`
+ IrcNick string `toml:"ircNick"`
+ IrcSaslUser string `toml:"ircSaslUser"`
+ IrcSaslPass string `toml:"ircSaslPass"`
+ OllamaEndpoint string `toml:"ollamaEndpoint"`
+ Model string `toml:"model"`
+ ChromaStyle string `toml:"chromaStyle"`
+ ChromaFormatter string `toml:"chromaFormatter"`
+ Provider string `toml:"provider"`
+ Apikey string `toml:"apikey"`
+ OllamaSystem string `toml:"ollamaSystem"`
+ ClientCertPath string `toml:"clientCertPath"`
+ ServerPass string `toml:"serverPass"`
+ Bind string `toml:"bind"`
+ Name string `toml:"name"`
+ DatabaseAddress string `toml:"databaseAddress"`
+ DatabasePassword string `toml:"databasePassword"`
+ DatabaseUser string `toml:"databaseUser"`
+ DatabaseName string `toml:"databaseName"`
+ LLMProxy string `toml:"llmProxy"`
+ IRCProxy string `toml:"ircProxy"`
+ IRCDName string `toml:"ircdName"`
+ WebIRCPassword string `toml:"webIRCPassword"`
+ WebIRCGateway string `toml:"webIRCGateway"`
+ WebIRCHostname string `toml:"webIRCHostname"`
+ WebIRCAddress string `toml:"webIRCAddress"`
+ CustomCommands map[string]CustomCommand `toml:"customCommands"`
+ Temp float64 `toml:"temp"`
+ RequestTimeout int `toml:"requestTimeout"`
+ MillaReconnectDelay int `toml:"millaReconnectDelay"`
+ IrcPort int `toml:"ircPort"`
+ KeepAlive int `toml:"keepAlive"`
+ MemoryLimit int `toml:"memoryLimit"`
+ PingDelay int `toml:"pingDelay"`
+ PingTimeout int `toml:"pingTimeout"`
+ TopP float32 `toml:"topP"`
+ TopK int32 `toml:"topK"`
+ EnableSasl bool `toml:"enableSasl"`
+ SkipTLSVerify bool `toml:"skipTLSVerify"`
+ UseTLS bool `toml:"useTLS"`
+ DisableSTSFallback bool `toml:"disableSTSFallback"`
+ AllowFlood bool `toml:"allowFlood"`
+ Debug bool `toml:"debug"`
+ Out bool `toml:"out"`
+ AdminOnly bool `toml:"adminOnly"`
+ pool *pgxpool.Pool
Admins []string `toml:"admins"`
IrcChannels []string `toml:"ircChannels"`
ScrapeChannels []string `toml:"scrapeChannels"`
@@ -107,10 +117,6 @@ func addSaneDefaults(config *TomlConfig) {
config.IrcNick = "milla"
}
- if config.IrcSaslUser == "" {
- config.IrcSaslUser = "milla"
- }
-
if config.ChromaStyle == "" {
config.ChromaStyle = "rose-pine-moon"
}
@@ -119,10 +125,6 @@ func addSaneDefaults(config *TomlConfig) {
config.ChromaFormatter = "noop"
}
- if config.Provider == "" {
- config.Provider = "ollam"
- }
-
if config.DatabaseAddress == "" {
config.DatabaseAddress = "postgres"
}
@@ -135,12 +137,8 @@ func addSaneDefaults(config *TomlConfig) {
config.DatabaseName = "milladb"
}
- if config.CommandsFile == "" {
- config.CommandsFile = "./commands.json"
- }
-
if config.Temp == 0 {
- config.Temp = 0.5 //nollint:gomnd
+ config.Temp = 0.5
}
if config.RequestTimeout == 0 {
@@ -190,11 +188,11 @@ type OllamaChatMessagesResponse struct {
}
type OllamaChatRequest struct {
- Model string `json:"model"`
- Stream bool `json:"stream"`
- Keep_alive time.Duration `json:"keep_alive"`
- Options OllamaRequestOptions `json:"options"`
- Messages []MemoryElement `json:"messages"`
+ Model string `json:"model"`
+ Stream bool `json:"stream"`
+ KeepAlive time.Duration `json:"keep_alive"`
+ Options OllamaRequestOptions `json:"options"`
+ Messages []MemoryElement `json:"messages"`
}
type MemoryElement struct {
@@ -303,6 +301,9 @@ func getHelpString() string {
helpString += "help - show this help message\n"
helpString += "set - set a configuration value\n"
helpString += "get - get a configuration value\n"
+ helpString += "join - joins a given channel\n"
+ helpString += "leave - leaves a given channel\n"
+ helpString += "cmd - run a custom command defined in the customcommands file\n"
helpString += "getall - returns all config options with their value\n"
helpString += "memstats - returns the memory status currently being used\n"
@@ -352,7 +353,7 @@ func setFieldByName(v reflect.Value, field string, value string) error {
func byteToMByte(bytes uint64,
) uint64 {
- return bytes / 1024 / 1024 //nolint:gomnd
+ return bytes / 1024 / 1024
}
func runCommand(
@@ -383,7 +384,7 @@ func runCommand(
case "help":
sendToIRC(client, event, getHelpString(), "noop")
case "set":
- if len(args) < 3 { //nolint:gomnd
+ if len(args) < 3 {
client.Cmd.Reply(event, errNotEnoughArgs.Error())
break
@@ -394,7 +395,7 @@ func runCommand(
client.Cmd.Reply(event, err.Error())
}
case "get":
- if len(args) < 2 { //nolint:gomnd
+ if len(args) < 2 {
client.Cmd.Reply(event, errNotEnoughArgs.Error())
break
@@ -429,6 +430,81 @@ func runCommand(
client.Cmd.Reply(event, fmt.Sprintf("Alloc: %d MiB", byteToMByte(memStats.Alloc)))
client.Cmd.Reply(event, fmt.Sprintf("TotalAlloc: %d MiB", byteToMByte(memStats.TotalAlloc)))
client.Cmd.Reply(event, fmt.Sprintf("Sys: %d MiB", byteToMByte(memStats.Sys)))
+ case "join":
+ if len(args) < 2 {
+ client.Cmd.Reply(event, errNotEnoughArgs.Error())
+
+ break
+ }
+
+ client.Cmd.Join(args[1])
+ case "leave":
+ if len(args) < 2 {
+ client.Cmd.Reply(event, errNotEnoughArgs.Error())
+
+ break
+ }
+
+ client.Cmd.Part(args[1])
+ case "cmd":
+ if len(args) < 2 {
+ client.Cmd.Reply(event, errNotEnoughArgs.Error())
+
+ break
+ }
+
+ customCommand := appConfig.CustomCommands[args[1]]
+
+ if customCommand.SQL == "" {
+ client.Cmd.Reply(event, "empty sql commands in the custom command")
+
+ break
+ }
+
+ if appConfig.pool == nil {
+ client.Cmd.Reply(event, "no database connection")
+
+ break
+ }
+
+ log.Println(customCommand.SQL)
+
+ rows, err := appConfig.pool.Query(context.Background(), customCommand.SQL)
+ defer rows.Close()
+
+ if err != nil {
+ client.Cmd.Reply(event, "error: "+err.Error())
+
+ break
+ }
+
+ var gptMemory []openai.ChatCompletionMessage
+
+ logs, err := pgx.CollectRows(rows, pgx.RowToStructByName[LogModel])
+ if err != nil {
+ log.Println(err.Error())
+
+ break
+ }
+
+ log.Println(logs)
+ logs = logs[:customCommand.Limit]
+
+ if err != nil {
+ log.Println(err.Error())
+
+ break
+ }
+
+ for _, log := range logs {
+ gptMemory = append(gptMemory, openai.ChatCompletionMessage{
+ Role: openai.ChatMessageRoleUser,
+ Content: log.Log,
+ })
+ }
+
+ chatGPTRequest(appConfig, client, event, &gptMemory, customCommand.Prompt)
+
default:
client.Cmd.Reply(event, errUnknCmd.Error())
}
@@ -457,10 +533,10 @@ func doOllamaRequest(
*ollamaMemory = append(*ollamaMemory, memoryElement)
ollamaRequest := OllamaChatRequest{
- Model: appConfig.Model,
- Keep_alive: time.Duration(appConfig.KeepAlive),
- Stream: false,
- Messages: *ollamaMemory,
+ Model: appConfig.Model,
+ KeepAlive: time.Duration(appConfig.KeepAlive),
+ Stream: false,
+ Messages: *ollamaMemory,
Options: OllamaRequestOptions{
Temperature: appConfig.Temp,
},
@@ -468,9 +544,9 @@ func doOllamaRequest(
jsonPayload, err = json.Marshal(ollamaRequest)
if err != nil {
- client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
+ client.Cmd.ReplyTo(event, "error: "+err.Error())
- return nil, err
+ return nil, fmt.Errorf("could not marshal json payload: %v", err)
}
log.Printf("json payload: %s", string(jsonPayload))
@@ -480,9 +556,9 @@ func doOllamaRequest(
request, err := http.NewRequest(http.MethodPost, appConfig.OllamaEndpoint, bytes.NewBuffer(jsonPayload))
if err != nil {
- client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
+ client.Cmd.ReplyTo(event, "error: "+err.Error())
- return nil, err
+ return nil, fmt.Errorf("could not make a new http request: %v", err)
}
request = request.WithContext(ctx)
@@ -531,10 +607,11 @@ func ollamaRequest(
}
if err != nil {
- client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
+ client.Cmd.ReplyTo(event, "error: "+err.Error())
return
}
+
defer response.Body.Close()
log.Println("response body:", response.Body)
@@ -545,7 +622,7 @@ func ollamaRequest(
err = json.NewDecoder(response.Body).Decode(&ollamaChatResponse)
if err != nil {
- client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
+ client.Cmd.ReplyTo(event, "error: "+err.Error())
}
assistantElement := MemoryElement{
@@ -563,7 +640,7 @@ func ollamaRequest(
appConfig.ChromaFormatter,
appConfig.ChromaStyle)
if err != nil {
- client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
+ client.Cmd.ReplyTo(event, "error: "+err.Error())
return
}
@@ -580,17 +657,21 @@ func ollamaHandler(
if !strings.HasPrefix(event.Last(), appConfig.IrcNick+": ") {
return
}
+
if appConfig.AdminOnly {
byAdmin := false
+
for _, admin := range appConfig.Admins {
if event.Source.Name == admin {
byAdmin = true
}
}
+
if !byAdmin {
return
}
}
+
prompt := strings.TrimPrefix(event.Last(), appConfig.IrcNick+": ")
log.Println(prompt)
@@ -616,7 +697,7 @@ func doGeminiRequest(
clientGemini, err := genai.NewClient(ctx, option.WithAPIKey(appConfig.Apikey))
if err != nil {
- client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
+ client.Cmd.ReplyTo(event, "error: "+err.Error())
return ""
}
@@ -633,7 +714,7 @@ func doGeminiRequest(
resp, err := cs.SendMessage(ctx, genai.Text(prompt))
if err != nil {
- client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
+ client.Cmd.ReplyTo(event, "error: "+err.Error())
return ""
}
@@ -678,7 +759,7 @@ func geminiRequest(
appConfig.ChromaFormatter,
appConfig.ChromaStyle)
if err != nil {
- client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
+ client.Cmd.ReplyTo(event, "error: "+err.Error())
return
}
@@ -695,17 +776,21 @@ func geminiHandler(
if !strings.HasPrefix(event.Last(), appConfig.IrcNick+": ") {
return
}
+
if appConfig.AdminOnly {
byAdmin := false
+
for _, admin := range appConfig.Admins {
if event.Source.Name == admin {
byAdmin = true
}
}
+
if !byAdmin {
return
}
}
+
prompt := strings.TrimPrefix(event.Last(), appConfig.IrcNick+": ")
log.Println(prompt)
@@ -735,7 +820,7 @@ func doChatGPTRequest(
proxyURL, err := url.Parse(appConfig.IRCProxy)
if err != nil {
cancel()
- client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
+ client.Cmd.ReplyTo(event, "error: "+err.Error())
log.Fatal(err.Error())
}
@@ -743,7 +828,7 @@ func doChatGPTRequest(
dialer, err := proxy.FromURL(proxyURL, &net.Dialer{Timeout: time.Duration(appConfig.RequestTimeout) * time.Second})
if err != nil {
cancel()
- client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
+ client.Cmd.ReplyTo(event, "error: "+err.Error())
log.Fatal(err.Error())
}
@@ -782,7 +867,7 @@ func chatGPTRequest(
) {
resp, err := doChatGPTRequest(appConfig, client, event, gptMemory, prompt)
if err != nil {
- client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
+ client.Cmd.ReplyTo(event, "error: "+err.Error())
return
}
@@ -805,7 +890,7 @@ func chatGPTRequest(
appConfig.ChromaFormatter,
appConfig.ChromaStyle)
if err != nil {
- client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
+ client.Cmd.ReplyTo(event, "error: "+err.Error())
return
}
@@ -822,17 +907,21 @@ func chatGPTHandler(
if !strings.HasPrefix(event.Last(), appConfig.IrcNick+": ") {
return
}
+
if appConfig.AdminOnly {
byAdmin := false
+
for _, admin := range appConfig.Admins {
if event.Source.Name == admin {
byAdmin = true
}
}
+
if !byAdmin {
return
}
}
+
prompt := strings.TrimPrefix(event.Last(), appConfig.IrcNick+": ")
log.Println(prompt)
@@ -846,7 +935,7 @@ func chatGPTHandler(
})
}
-func connectToDB(appConfig TomlConfig, ctx *context.Context, poolChan chan *pgxpool.Pool) {
+func connectToDB(appConfig *TomlConfig, ctx *context.Context, poolChan chan *pgxpool.Pool) {
for {
if appConfig.DatabaseUser == "" {
appConfig.DatabaseUser = os.Getenv("MILLA_DB_USER")
@@ -895,6 +984,7 @@ func connectToDB(appConfig TomlConfig, ctx *context.Context, poolChan chan *pgxp
nick text not null,
dateadded timestamp default current_timestamp
)`, tableName)
+
_, err = pool.Exec(*ctx, query)
if err != nil {
log.Println(err.Error())
@@ -902,13 +992,14 @@ func connectToDB(appConfig TomlConfig, ctx *context.Context, poolChan chan *pgxp
}
}
+ appConfig.pool = pool
poolChan <- pool
}
}
}
func scrapeChannel(irc *girc.Client, poolChan chan *pgxpool.Pool, appConfig TomlConfig) {
- irc.Handlers.AddBg(girc.PRIVMSG, func(client *girc.Client, event girc.Event) {
+ irc.Handlers.AddBg(girc.PRIVMSG, func(_ *girc.Client, event girc.Event) {
pool := <-poolChan
tableName := getTableFromChanName(event.Params[0], appConfig.IRCDName)
query := fmt.Sprintf(
@@ -953,6 +1044,13 @@ func runIRC(appConfig TomlConfig) {
},
})
+ if appConfig.WebIRCGateway != "" {
+ irc.Config.WebIRC.Address = appConfig.WebIRCAddress
+ irc.Config.WebIRC.Gateway = appConfig.WebIRCGateway
+ irc.Config.WebIRC.Hostname = appConfig.WebIRCHostname
+ irc.Config.WebIRC.Password = appConfig.WebIRCPassword
+ }
+
if appConfig.Debug {
irc.Config.Debug = os.Stdout
}
@@ -1003,7 +1101,7 @@ func runIRC(appConfig TomlConfig) {
irc.Config.TLSConfig.Certificates = []tls.Certificate{cert}
}
- irc.Handlers.AddBg(girc.CONNECTED, func(c *girc.Client, e girc.Event) {
+ irc.Handlers.AddBg(girc.CONNECTED, func(c *girc.Client, _ girc.Event) {
for _, channel := range appConfig.IrcChannels {
c.Cmd.Join(channel)
}
@@ -1022,11 +1120,11 @@ func runIRC(appConfig TomlConfig) {
context, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second)
defer cancel()
- go connectToDB(appConfig, &context, poolChan)
+ go connectToDB(&appConfig, &context, poolChan)
}
if len(appConfig.ScrapeChannels) > 0 {
- irc.Handlers.AddBg(girc.CONNECTED, func(c *girc.Client, e girc.Event) {
+ irc.Handlers.AddBg(girc.CONNECTED, func(c *girc.Client, _ girc.Event) {
for _, channel := range appConfig.ScrapeChannels {
c.Cmd.Join(channel)
}
@@ -1080,10 +1178,13 @@ func main() {
log.Fatal(err)
}
+ for key, value := range config.Ircd {
+ addSaneDefaults(&value)
+ value.IRCDName = key
+ config.Ircd[key] = value
+ }
+
for k, v := range config.Ircd {
- addSaneDefaults(&v)
- v.IRCDName = k
- config.Ircd[k] = v
log.Println(k, v)
}