diff options
Diffstat (limited to 'main.go')
-rw-r--r-- | main.go | 204 |
1 files changed, 149 insertions, 55 deletions
@@ -9,6 +9,7 @@ import ( "fmt" "log" "net/http" + "net/url" "os" "strings" "time" @@ -18,6 +19,7 @@ import ( "github.com/lrstanley/girc" "github.com/pelletier/go-toml/v2" openai "github.com/sashabaranov/go-openai" + "golang.org/x/net/proxy" "google.golang.org/api/option" ) @@ -27,7 +29,7 @@ type TomlConfig struct { IrcNick string IrcSaslUser string IrcSaslPass string - IrcChannel string + IrcChannels []string OllamaEndpoint string Temp float64 OllamaSystem string @@ -41,6 +43,10 @@ type TomlConfig struct { Apikey string TopP float32 TopK int32 + Chat bool + Admins []string + Color bool + SkipTLSVerify bool } type OllamaResponse struct { @@ -51,6 +57,15 @@ type OllamaRequestOptions struct { Temperature float64 `json:"temperature"` } +type OllamaChatMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type OllamaChatMessages struct { + Messages []OllamaChatMessage `json:"messages"` +} + type OllamaRequest struct { Model string `json:"model"` System string `json:"system"` @@ -60,6 +75,19 @@ type OllamaRequest struct { Options OllamaRequestOptions `json:"options"` } +type OllamaChatRequest struct { + Model string `json:"model"` + Stream bool `json:"stream"` + Keep_alive time.Duration `json:"keep_alive"` + Options OllamaRequestOptions `json:"options"` + Format string `json:"format"` + Messages OllamaChatMessages `json:"messages"` +} + +type OllamaChatResponse struct { + Messages OllamaChatMessages `json:"messages"` +} + func printResponse(resp *genai.GenerateContentResponse) string { result := "" @@ -77,13 +105,14 @@ func printResponse(resp *genai.GenerateContentResponse) string { func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) { irc := girc.New(girc.Config{ - Server: appConfig.IrcServer, - Port: appConfig.IrcPort, - Nick: appConfig.IrcNick, - User: appConfig.IrcNick, - Name: appConfig.IrcNick, - SSL: true, - TLSConfig: &tls.Config{InsecureSkipVerify: true}, + Server: appConfig.IrcServer, + Port: appConfig.IrcPort, + Nick: appConfig.IrcNick, + User: appConfig.IrcNick, + Name: appConfig.IrcNick, + SSL: true, + TLSConfig: &tls.Config{InsecureSkipVerify: appConfig.SkipTLSVerify, + ServerName: appConfig.IrcServer}, }) saslUser := appConfig.IrcSaslUser @@ -97,8 +126,7 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) { } irc.Handlers.AddBg(girc.CONNECTED, func(c *girc.Client, e girc.Event) { - channels := strings.Split(appConfig.IrcChannel, " ") - for _, channel := range channels { + for _, channel := range appConfig.IrcChannels { c.Cmd.Join(channel) } }) @@ -109,23 +137,55 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) { prompt := strings.TrimPrefix(event.Last(), appConfig.IrcNick+": ") log.Println(prompt) - ollamaRequest := OllamaRequest{ - Model: appConfig.Model, - System: appConfig.OllamaSystem, - Prompt: prompt, - Stream: false, - Format: "json", - Options: OllamaRequestOptions{ - Temperature: appConfig.Temp, - }, + var jsonPayload []byte + var err error + + if appConfig.Chat { + ollamaRequest := OllamaChatRequest{ + Model: appConfig.Model, + Stream: false, + Format: "json", + Messages: OllamaChatMessages{ + []OllamaChatMessage{{ + Role: "user", + Content: prompt, + }}, + }, + Options: OllamaRequestOptions{ + Temperature: appConfig.Temp, + }, + } + jsonPayload, err = json.Marshal(ollamaRequest) + if err != nil { + client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + + return + } + } else { + ollamaRequest := OllamaRequest{ + Model: appConfig.Model, + System: appConfig.OllamaSystem, + Prompt: prompt, + Stream: false, + Format: "json", + Options: OllamaRequestOptions{ + Temperature: appConfig.Temp, + }, + } + jsonPayload, err = json.Marshal(ollamaRequest) + if err != nil { + client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + + return + } } - jsonPayload, err := json.Marshal(ollamaRequest) - if err != nil { - client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error()))) + // jsonPayload, err := json.Marshal(ollamaRequest) + // if err != nil { + // client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) - return - } + // return + // } ctx, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second) defer cancel() @@ -133,7 +193,7 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) { request, err := http.NewRequest(http.MethodPost, appConfig.OllamaEndpoint, bytes.NewBuffer(jsonPayload)) request = request.WithContext(ctx) if err != nil { - client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error()))) + client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) return } @@ -141,10 +201,24 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) { request.Header.Set("Content-Type", "application/json") httpClient := http.Client{} + allProxy := os.Getenv("ALL_PROXY") + if allProxy != "" { + proxyUrl, err := url.Parse(allProxy) + if err != nil { + client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + + return + } + transport := &http.Transport{ + Proxy: http.ProxyURL(proxyUrl), + } + + httpClient.Transport = transport + } response, err := httpClient.Do(request) if err != nil { - client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error()))) + client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) return } @@ -153,7 +227,7 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) { var ollamaResponse OllamaResponse err = json.NewDecoder(response.Body).Decode(&ollamaResponse) if err != nil { - client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error()))) + client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) return } @@ -165,16 +239,16 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) { appConfig.ChromaFormatter, appConfig.ChromaStyle) if err != nil { - client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error()))) + client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) return } - client.Cmd.ReplyTo(event, girc.Fmt("\033[0m"+writer.String())) + log.Println(writer.String()) + client.Cmd.Reply(event, writer.String()) } }) } else if appConfig.Provider == "gemini" { - log.Println("fuck prime") irc.Handlers.AddBg(girc.PRIVMSG, func(client *girc.Client, event girc.Event) { if strings.HasPrefix(event.Last(), appConfig.IrcNick+": ") { prompt := strings.TrimPrefix(event.Last(), appConfig.IrcNick+": ") @@ -183,37 +257,35 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) { ctx, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second) defer cancel() - // dialer := proxy.FromEnvironment() + dialer := proxy.FromEnvironment() - // transport := http.Transport{ - // Dial: dialer.Dial, - // } - // httpClient := http.Client{ - // Transport: &transport, - // Timeout: time.Duration(appConfig.RequestTimeout) * time.Second, - // } + transport := http.Transport{ + Dial: dialer.Dial, + } + httpClient := http.Client{ + Transport: &transport, + Timeout: time.Duration(appConfig.RequestTimeout) * time.Second, + } - // clientGemini, err := genai.NewClient(ctx, option.WithAPIKey(appConfig.Apikey), option.WithHTTPClient(&httpClient)) - clientGemini, err := genai.NewClient(ctx, option.WithAPIKey(appConfig.Apikey)) - if err != nil { - client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error()))) + clientGemini, err := genai.NewClient(ctx, option.WithAPIKey(appConfig.Apikey), option.WithHTTPClient(&httpClient)) + // clientGemini, err := genai.NewClient(ctx, option.WithAPIKey(appConfig.Apikey)) + // if err != nil { + // client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) - return - } + // return + // } defer clientGemini.Close() model := clientGemini.GenerativeModel(appConfig.Model) model.SetTemperature(float32(appConfig.Temp)) model.SetTopK(appConfig.TopK) model.SetTopP(appConfig.TopP) - log.Println("fuck") resp, err := model.GenerateContent(ctx, genai.Text(prompt)) if err != nil { - client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error()))) + client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) return } - log.Println("fuck two") var writer bytes.Buffer err = quick.Highlight( @@ -223,13 +295,13 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) { appConfig.ChromaFormatter, appConfig.ChromaStyle) if err != nil { - client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error()))) + client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) return } log.Println(writer.String()) - client.Cmd.ReplyTo(event, girc.Fmt("\033[0m"+writer.String())) + client.Cmd.Reply(event, writer.String()) } }) } else if appConfig.Provider == "chatgpt" { @@ -241,7 +313,25 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) { ctx, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second) defer cancel() - gptClient := openai.NewClient(appConfig.Apikey) + allProxy := os.Getenv("ALL_PROXY") + config := openai.DefaultConfig(appConfig.Apikey) + if allProxy != "" { + proxyUrl, err := url.Parse(allProxy) + if err != nil { + client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) + + return + } + transport := &http.Transport{ + Proxy: http.ProxyURL(proxyUrl), + } + + config.HTTPClient = &http.Client{ + Transport: transport, + } + } + + gptClient := openai.NewClientWithConfig(config) messages := make([]openai.ChatCompletionMessage, 0) @@ -255,7 +345,7 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) { Messages: messages, }) if err != nil { - client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error()))) + client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) return } @@ -268,13 +358,17 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) { appConfig.ChromaFormatter, appConfig.ChromaStyle) if err != nil { - client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error()))) + client.Cmd.ReplyTo(event, fmt.Sprintf("error: %s", err.Error())) return } log.Println(writer.String()) - client.Cmd.ReplyTo(event, girc.Fmt("\033[0m"+writer.String())) + lines := strings.Split(writer.String(), "\n") + + for _, line := range lines { + client.Cmd.Reply(event, line) + } } }) } @@ -284,7 +378,7 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) { for { if err := irc.Connect(); err != nil { log.Println(err) - log.Println("reconnecting in 30 seconds") + log.Println("reconnecting in {appConfig.MillaReconnectDelay/1000}") time.Sleep(time.Duration(appConfig.MillaReconnectDelay) * time.Second) } else { return @@ -295,7 +389,7 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) { func main() { var appConfig TomlConfig - configPath := flag.String("config", "./config-gemini.toml", "path to the config file") + configPath := flag.String("config", "./config.toml", "path to the config file") flag.Parse() |