aboutsummaryrefslogtreecommitdiffstats
path: root/main.go
diff options
context:
space:
mode:
authorterminaldweller <devi@terminaldweller.com>2024-05-13 23:23:01 +0000
committerterminaldweller <devi@terminaldweller.com>2024-05-13 23:23:01 +0000
commit2c02180d2bb6d74c03a967595bbf5a414a86aa75 (patch)
treefcc9c73be09b3aa95ac48984378ccaafe44413e9 /main.go
parentupdated the readme, added example config file (diff)
downloadmilla-2c02180d2bb6d74c03a967595bbf5a414a86aa75.tar.gz
milla-2c02180d2bb6d74c03a967595bbf5a414a86aa75.zip
added a distroless build, made sure milla works with gvisor's runsc as the runtime
Diffstat (limited to 'main.go')
-rw-r--r--main.go113
1 files changed, 54 insertions, 59 deletions
diff --git a/main.go b/main.go
index b470738..68120e5 100644
--- a/main.go
+++ b/main.go
@@ -10,9 +10,9 @@ import (
"fmt"
"log"
"net/http"
- "net/url"
"os"
"reflect"
+ "regexp"
"strconv"
"strings"
"time"
@@ -22,6 +22,7 @@ import (
"github.com/google/generative-ai-go/genai"
"github.com/lrstanley/girc"
openai "github.com/sashabaranov/go-openai"
+ "golang.org/x/net/proxy"
"google.golang.org/api/option"
)
@@ -77,7 +78,6 @@ func NewTomlConfig() *TomlConfig {
ChromaStyle: "rose-pine-moon",
ChromaFormatter: "noop",
Provider: "ollama",
- ClientCertPath: "milla.pem",
Temp: 0.5, //nolint:gomnd
RequestTimeout: 10, //nolint:gomnd
MillaReconnectDelay: 30, //nolint:gomnd
@@ -137,26 +137,26 @@ func returnGeminiResponse(resp *genai.GenerateContentResponse) string {
return result
}
-// func extractLast256ColorEscapeCode(str string) (string, error) {
-// pattern256F := `\033\[38;5;(\d+)m`
-// // pattern256B := `\033\[48;5;(\d+)m`
-// // pattern16mF := `\033\[38;2;(\d+);(\d+);(\d+)m`
-// // pattern16mB := `\033\[48;2;(\d+);(\d+);(\d+)m`
+func extractLast256ColorEscapeCode(str string) (string, error) {
+ pattern256F := `\033\[38;5;(\d+)m`
+ // pattern256B := `\033\[48;5;(\d+)m`
+ // pattern16mF := `\033\[38;2;(\d+);(\d+);(\d+)m`
+ // pattern16mB := `\033\[48;2;(\d+);(\d+);(\d+)m`
-// r, err := regexp.Compile(pattern256F)
-// if err != nil {
-// return "", fmt.Errorf("failed to compile regular expression: %w", err)
-// }
+ r, err := regexp.Compile(pattern256F)
+ if err != nil {
+ return "", fmt.Errorf("failed to compile regular expression: %w", err)
+ }
-// matches := r.FindAllStringSubmatch(str, -1)
-// if len(matches) == 0 {
-// return "", nil
-// }
+ matches := r.FindAllStringSubmatch(str, -1)
+ if len(matches) == 0 {
+ return "", nil
+ }
-// lastMatch := matches[len(matches)-1]
+ lastMatch := matches[len(matches)-1]
-// return lastMatch[1], nil
-// }
+ return lastMatch[1], nil
+}
func chunker(inputString string, chromaFormatter string) []string {
chunks := strings.Split(inputString, "\n")
@@ -169,17 +169,16 @@ func chunker(inputString string, chromaFormatter string) []string {
case "terminal16":
fallthrough
case "terminal256":
- // for count, chunk := range chunks {
- // lastColorCode, err := extractLast256ColorEscapeCode(chunk)
- // if err != nil {
- // continue
- // }
+ for count, chunk := range chunks {
+ lastColorCode, err := extractLast256ColorEscapeCode(chunk)
+ if err != nil {
+ continue
+ }
- // if count <= len(chunks)-2 {
- // chunks[count+1] = fmt.Sprintf("\033[38;5;%sm", lastColorCode) + chunks[count+1]
- // }
- // }
- fallthrough
+ if count <= len(chunks)-2 {
+ chunks[count+1] = fmt.Sprintf("\033[38;5;%sm", lastColorCode) + chunks[count+1]
+ }
+ }
case "terminal16m":
fallthrough
default:
@@ -384,20 +383,14 @@ func ollamaHandler(
request.Header.Set("Content-Type", "application/json")
- httpClient := http.Client{}
- allProxy := os.Getenv("ALL_PROXY")
- if allProxy != "" {
- proxyURL, err := url.Parse(allProxy)
- if err != nil {
- client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
+ var httpClient http.Client
- return
- }
- transport := &http.Transport{
- Proxy: http.ProxyURL(proxyURL),
- }
+ dialer := proxy.FromEnvironment()
- httpClient.Transport = transport
+ httpClient = http.Client{
+ Transport: &http.Transport{
+ Dial: dialer.Dial,
+ },
}
response, err := httpClient.Do(request)
@@ -564,24 +557,19 @@ func chatGPTHandler(
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second)
defer cancel()
- allProxy := os.Getenv("ALL_PROXY")
- config := openai.DefaultConfig(appConfig.Apikey)
- if allProxy != "" {
- proxyURL, err := url.Parse(allProxy)
- if err != nil {
- client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
+ var httpClient http.Client
- return
- }
- transport := &http.Transport{
- Proxy: http.ProxyURL(proxyURL),
- }
+ dialer := proxy.FromEnvironment()
- config.HTTPClient = &http.Client{
- Transport: transport,
- }
+ httpClient = http.Client{
+ Transport: &http.Transport{
+ Dial: dialer.Dial,
+ },
}
+ config := openai.DefaultConfig(appConfig.Apikey)
+ config.HTTPClient = &httpClient
+
gptClient := openai.NewClientWithConfig(config)
*gptMemory = append(*gptMemory, openai.ChatCompletionMessage{
@@ -645,7 +633,7 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) {
DisableSTSFallback: appConfig.DisableSTSFallback,
GlobalFormat: true,
TLSConfig: &tls.Config{
- InsecureSkipVerify: appConfig.SkipTLSVerify,
+ InsecureSkipVerify: appConfig.SkipTLSVerify, // #nosec G402
ServerName: appConfig.IrcServer,
},
})
@@ -676,9 +664,16 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) {
}
}
- // if appConfig.EnableSasl && appConfig.ClientCertPath != "" {
- // // TODO - add client cert support
- // }
+ if appConfig.EnableSasl && appConfig.ClientCertPath != "" {
+ cert, err := tls.LoadX509KeyPair(appConfig.ClientCertPath, appConfig.ClientCertPath)
+ if err != nil {
+ log.Println("invalid client certificate.")
+
+ return
+ }
+
+ irc.Config.TLSConfig.Certificates = []tls.Certificate{cert}
+ }
irc.Handlers.AddBg(girc.CONNECTED, func(c *girc.Client, e girc.Event) {
for _, channel := range appConfig.IrcChannels {
@@ -700,7 +695,7 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) {
for {
if err := irc.Connect(); err != nil {
log.Println(err)
- log.Println("reconnecting in" + strconv.Itoa(appConfig.MillaReconnectDelay))
+ log.Println("reconnecting in " + strconv.Itoa(appConfig.MillaReconnectDelay))
time.Sleep(time.Duration(appConfig.MillaReconnectDelay) * time.Second)
} else {
return