aboutsummaryrefslogtreecommitdiffstats
path: root/main.go
diff options
context:
space:
mode:
authorterminaldweller <devi@terminaldweller.com>2024-05-10 06:47:52 +0000
committerterminaldweller <devi@terminaldweller.com>2024-05-10 06:47:52 +0000
commit404067002141c9d548765d9cc44bfdd916beea4e (patch)
tree9ebc38091b5b68ee39ca04c0e6a199f2f091442c /main.go
parentrunpod serverless ollama - WIP (diff)
downloadmilla-404067002141c9d548765d9cc44bfdd916beea4e.tar.gz
milla-404067002141c9d548765d9cc44bfdd916beea4e.zip
WIP
Diffstat (limited to '')
-rw-r--r--main.go204
1 files changed, 149 insertions, 55 deletions
diff --git a/main.go b/main.go
index 2fda234..88dc33f 100644
--- a/main.go
+++ b/main.go
@@ -9,6 +9,7 @@ import (
"fmt"
"log"
"net/http"
+ "net/url"
"os"
"strings"
"time"
@@ -18,6 +19,7 @@ import (
"github.com/lrstanley/girc"
"github.com/pelletier/go-toml/v2"
openai "github.com/sashabaranov/go-openai"
+ "golang.org/x/net/proxy"
"google.golang.org/api/option"
)
@@ -27,7 +29,7 @@ type TomlConfig struct {
IrcNick string
IrcSaslUser string
IrcSaslPass string
- IrcChannel string
+ IrcChannels []string
OllamaEndpoint string
Temp float64
OllamaSystem string
@@ -41,6 +43,10 @@ type TomlConfig struct {
Apikey string
TopP float32
TopK int32
+ Chat bool
+ Admins []string
+ Color bool
+ SkipTLSVerify bool
}
type OllamaResponse struct {
@@ -51,6 +57,15 @@ type OllamaRequestOptions struct {
Temperature float64 `json:"temperature"`
}
+type OllamaChatMessage struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+}
+
+type OllamaChatMessages struct {
+ Messages []OllamaChatMessage `json:"messages"`
+}
+
type OllamaRequest struct {
Model string `json:"model"`
System string `json:"system"`
@@ -60,6 +75,19 @@ type OllamaRequest struct {
Options OllamaRequestOptions `json:"options"`
}
+type OllamaChatRequest struct {
+ Model string `json:"model"`
+ Stream bool `json:"stream"`
+ Keep_alive time.Duration `json:"keep_alive"`
+ Options OllamaRequestOptions `json:"options"`
+ Format string `json:"format"`
+ Messages OllamaChatMessages `json:"messages"`
+}
+
+type OllamaChatResponse struct {
+ Messages OllamaChatMessages `json:"messages"`
+}
+
func printResponse(resp *genai.GenerateContentResponse) string {
result := ""
@@ -77,13 +105,14 @@ func printResponse(resp *genai.GenerateContentResponse) string {
func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) {
irc := girc.New(girc.Config{
- Server: appConfig.IrcServer,
- Port: appConfig.IrcPort,
- Nick: appConfig.IrcNick,
- User: appConfig.IrcNick,
- Name: appConfig.IrcNick,
- SSL: true,
- TLSConfig: &tls.Config{InsecureSkipVerify: true},
+ Server: appConfig.IrcServer,
+ Port: appConfig.IrcPort,
+ Nick: appConfig.IrcNick,
+ User: appConfig.IrcNick,
+ Name: appConfig.IrcNick,
+ SSL: true,
+ TLSConfig: &tls.Config{InsecureSkipVerify: appConfig.SkipTLSVerify,
+ ServerName: appConfig.IrcServer},
})
saslUser := appConfig.IrcSaslUser
@@ -97,8 +126,7 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) {
}
irc.Handlers.AddBg(girc.CONNECTED, func(c *girc.Client, e girc.Event) {
- channels := strings.Split(appConfig.IrcChannel, " ")
- for _, channel := range channels {
+ for _, channel := range appConfig.IrcChannels {
c.Cmd.Join(channel)
}
})
@@ -109,23 +137,55 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) {
prompt := strings.TrimPrefix(event.Last(), appConfig.IrcNick+": ")
log.Println(prompt)
- ollamaRequest := OllamaRequest{
- Model: appConfig.Model,
- System: appConfig.OllamaSystem,
- Prompt: prompt,
- Stream: false,
- Format: "json",
- Options: OllamaRequestOptions{
- Temperature: appConfig.Temp,
- },
+ var jsonPayload []byte
+ var err error
+
+ if appConfig.Chat {
+ ollamaRequest := OllamaChatRequest{
+ Model: appConfig.Model,
+ Stream: false,
+ Format: "json",
+ Messages: OllamaChatMessages{
+ []OllamaChatMessage{{
+ Role: "user",
+ Content: prompt,
+ }},
+ },
+ Options: OllamaRequestOptions{
+ Temperature: appConfig.Temp,
+ },
+ }
+ jsonPayload, err = json.Marshal(ollamaRequest)
+ if err != nil {
+ client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
+
+ return
+ }
+ } else {
+ ollamaRequest := OllamaRequest{
+ Model: appConfig.Model,
+ System: appConfig.OllamaSystem,
+ Prompt: prompt,
+ Stream: false,
+ Format: "json",
+ Options: OllamaRequestOptions{
+ Temperature: appConfig.Temp,
+ },
+ }
+ jsonPayload, err = json.Marshal(ollamaRequest)
+ if err != nil {
+ client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
+
+ return
+ }
}
- jsonPayload, err := json.Marshal(ollamaRequest)
- if err != nil {
- client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error())))
+ // jsonPayload, err := json.Marshal(ollamaRequest)
+ // if err != nil {
+ // client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
- return
- }
+ // return
+ // }
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second)
defer cancel()
@@ -133,7 +193,7 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) {
request, err := http.NewRequest(http.MethodPost, appConfig.OllamaEndpoint, bytes.NewBuffer(jsonPayload))
request = request.WithContext(ctx)
if err != nil {
- client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error())))
+ client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
return
}
@@ -141,10 +201,24 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) {
request.Header.Set("Content-Type", "application/json")
httpClient := http.Client{}
+ allProxy := os.Getenv("ALL_PROXY")
+ if allProxy != "" {
+ proxyUrl, err := url.Parse(allProxy)
+ if err != nil {
+ client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
+
+ return
+ }
+ transport := &http.Transport{
+ Proxy: http.ProxyURL(proxyUrl),
+ }
+
+ httpClient.Transport = transport
+ }
response, err := httpClient.Do(request)
if err != nil {
- client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error())))
+ client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
return
}
@@ -153,7 +227,7 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) {
var ollamaResponse OllamaResponse
err = json.NewDecoder(response.Body).Decode(&ollamaResponse)
if err != nil {
- client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error())))
+ client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
return
}
@@ -165,16 +239,16 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) {
appConfig.ChromaFormatter,
appConfig.ChromaStyle)
if err != nil {
- client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error())))
+ client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
return
}
- client.Cmd.ReplyTo(event, girc.Fmt("\033[0m"+writer.String()))
+ log.Println(writer.String())
+ client.Cmd.Reply(event, writer.String())
}
})
} else if appConfig.Provider == "gemini" {
- log.Println("fuck prime")
irc.Handlers.AddBg(girc.PRIVMSG, func(client *girc.Client, event girc.Event) {
if strings.HasPrefix(event.Last(), appConfig.IrcNick+": ") {
prompt := strings.TrimPrefix(event.Last(), appConfig.IrcNick+": ")
@@ -183,37 +257,35 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second)
defer cancel()
- // dialer := proxy.FromEnvironment()
+ dialer := proxy.FromEnvironment()
- // transport := http.Transport{
- // Dial: dialer.Dial,
- // }
- // httpClient := http.Client{
- // Transport: &transport,
- // Timeout: time.Duration(appConfig.RequestTimeout) * time.Second,
- // }
+ transport := http.Transport{
+ Dial: dialer.Dial,
+ }
+ httpClient := http.Client{
+ Transport: &transport,
+ Timeout: time.Duration(appConfig.RequestTimeout) * time.Second,
+ }
- // clientGemini, err := genai.NewClient(ctx, option.WithAPIKey(appConfig.Apikey), option.WithHTTPClient(&httpClient))
- clientGemini, err := genai.NewClient(ctx, option.WithAPIKey(appConfig.Apikey))
- if err != nil {
- client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error())))
+ clientGemini, err := genai.NewClient(ctx, option.WithAPIKey(appConfig.Apikey), option.WithHTTPClient(&httpClient))
+ // clientGemini, err := genai.NewClient(ctx, option.WithAPIKey(appConfig.Apikey))
+ // if err != nil {
+ // client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
- return
- }
+ // return
+ // }
defer clientGemini.Close()
model := clientGemini.GenerativeModel(appConfig.Model)
model.SetTemperature(float32(appConfig.Temp))
model.SetTopK(appConfig.TopK)
model.SetTopP(appConfig.TopP)
- log.Println("fuck")
resp, err := model.GenerateContent(ctx, genai.Text(prompt))
if err != nil {
- client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error())))
+ client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
return
}
- log.Println("fuck two")
var writer bytes.Buffer
err = quick.Highlight(
@@ -223,13 +295,13 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) {
appConfig.ChromaFormatter,
appConfig.ChromaStyle)
if err != nil {
- client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error())))
+ client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
return
}
log.Println(writer.String())
- client.Cmd.ReplyTo(event, girc.Fmt("\033[0m"+writer.String()))
+ client.Cmd.Reply(event, writer.String())
}
})
} else if appConfig.Provider == "chatgpt" {
@@ -241,7 +313,25 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second)
defer cancel()
- gptClient := openai.NewClient(appConfig.Apikey)
+ allProxy := os.Getenv("ALL_PROXY")
+ config := openai.DefaultConfig(appConfig.Apikey)
+ if allProxy != "" {
+ proxyUrl, err := url.Parse(allProxy)
+ if err != nil {
+ client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
+
+ return
+ }
+ transport := &http.Transport{
+ Proxy: http.ProxyURL(proxyUrl),
+ }
+
+ config.HTTPClient = &http.Client{
+ Transport: transport,
+ }
+ }
+
+ gptClient := openai.NewClientWithConfig(config)
messages := make([]openai.ChatCompletionMessage, 0)
@@ -255,7 +345,7 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) {
Messages: messages,
})
if err != nil {
- client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error())))
+ client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
return
}
@@ -268,13 +358,17 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) {
appConfig.ChromaFormatter,
appConfig.ChromaStyle)
if err != nil {
- client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error())))
+ client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error()))
return
}
log.Println(writer.String())
- client.Cmd.ReplyTo(event, girc.Fmt("\033[0m"+writer.String()))
+ lines := strings.Split(writer.String(), "\n")
+
+ for _, line := range lines {
+ client.Cmd.Reply(event, line)
+ }
}
})
}
@@ -284,7 +378,7 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) {
for {
if err := irc.Connect(); err != nil {
log.Println(err)
- log.Println("reconnecting in 30 seconds")
+ log.Println("reconnecting in {appConfig.MillaReconnectDelay/1000}")
time.Sleep(time.Duration(appConfig.MillaReconnectDelay) * time.Second)
} else {
return
@@ -295,7 +389,7 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) {
func main() {
var appConfig TomlConfig
- configPath := flag.String("config", "./config-gemini.toml", "path to the config file")
+ configPath := flag.String("config", "./config.toml", "path to the config file")
flag.Parse()