diff options
Diffstat (limited to 'main.go')
-rw-r--r-- | main.go | 147 |
1 files changed, 96 insertions, 51 deletions
@@ -13,8 +13,10 @@ import ( "time" "github.com/alecthomas/chroma/v2/quick" + "github.com/google/generative-ai-go/genai" "github.com/lrstanley/girc" "github.com/pelletier/go-toml/v2" + "google.golang.org/api/option" ) type TomlConfig struct { @@ -33,6 +35,8 @@ type TomlConfig struct { Model string ChromaStyle string ChromaFormatter string + Provider string + Apikey string } type OllamaResponse struct { @@ -80,72 +84,113 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) { } }) - irc.Handlers.AddBg(girc.PRIVMSG, func(client *girc.Client, event girc.Event) { - if strings.HasPrefix(event.Last(), appConfig.IrcNick+": ") { - prompt := strings.TrimPrefix(event.Last(), appConfig.IrcNick+": ") - log.Println(prompt) - - ollamaRequest := OllamaRequest{ - Model: appConfig.Model, - System: appConfig.OllamaSystem, - Prompt: prompt, - Stream: false, - Format: "json", - Options: OllamaRequestOptions{ - Temperature: appConfig.OllamaTemp, - }, - } + if appConfig.Provider == "ollama" { + irc.Handlers.AddBg(girc.PRIVMSG, func(client *girc.Client, event girc.Event) { + if strings.HasPrefix(event.Last(), appConfig.IrcNick+": ") { + prompt := strings.TrimPrefix(event.Last(), appConfig.IrcNick+": ") + log.Println(prompt) - jsonPayload, err := json.Marshal(ollamaRequest) - if err != nil { - client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error()))) + ollamaRequest := OllamaRequest{ + Model: appConfig.Model, + System: appConfig.OllamaSystem, + Prompt: prompt, + Stream: false, + Format: "json", + Options: OllamaRequestOptions{ + Temperature: appConfig.OllamaTemp, + }, + } - return - } + jsonPayload, err := json.Marshal(ollamaRequest) + if err != nil { + client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error()))) - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second) - defer cancel() + return + } - request, err := http.NewRequest(http.MethodPost, appConfig.OllamaEndpoint, bytes.NewBuffer(jsonPayload)) - request = request.WithContext(ctx) - if err != nil { - client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error()))) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second) + defer cancel() - return - } + request, err := http.NewRequest(http.MethodPost, appConfig.OllamaEndpoint, bytes.NewBuffer(jsonPayload)) + request = request.WithContext(ctx) + if err != nil { + client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error()))) - request.Header.Set("Content-Type", "application/json") + return + } - httpClient := http.Client{} + request.Header.Set("Content-Type", "application/json") - response, err := httpClient.Do(request) - if err != nil { - client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error()))) + httpClient := http.Client{} - return - } - defer response.Body.Close() + response, err := httpClient.Do(request) + if err != nil { + client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error()))) - var ollamaResponse OllamaResponse - err = json.NewDecoder(response.Body).Decode(&ollamaResponse) - if err != nil { - client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error()))) + return + } + defer response.Body.Close() - return - } + var ollamaResponse OllamaResponse + err = json.NewDecoder(response.Body).Decode(&ollamaResponse) + if err != nil { + client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error()))) - var writer bytes.Buffer - err = quick.Highlight(&writer, ollamaResponse.Response, "markdown", appConfig.ChromaFormatter, appConfig.ChromaStyle) - if err != nil { - client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error()))) + return + } - return + var writer bytes.Buffer + err = quick.Highlight(&writer, ollamaResponse.Response, "markdown", appConfig.ChromaFormatter, appConfig.ChromaStyle) + if err != nil { + client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error()))) + + return + } + + fmt.Println(writer.String()) + client.Cmd.ReplyTo(event, girc.Fmt("\033[0m"+writer.String())) } + }) + } else if appConfig.Provider == "gemini" { + irc.Handlers.AddBg(girc.PRIVMSG, func(client *girc.Client, event girc.Event) { + if strings.HasPrefix(event.Last(), appConfig.IrcNick+": ") { + prompt := strings.TrimPrefix(event.Last(), appConfig.IrcNick+": ") + log.Println(prompt) - fmt.Println(writer.String()) - client.Cmd.ReplyTo(event, girc.Fmt("\033[0m"+writer.String())) - } - }) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second) + defer cancel() + + client_gemini, err := genai.NewClient(ctx, option.WithAPIKey(appConfig.Apikey)) + if err != nil { + client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error()))) + + return + } + defer client_gemini.Close() + + model := client_gemini.GenerativeModel(appConfig.Model) + resp, err := model.GenerateContent(ctx, genai.Text(prompt)) + if err != nil { + client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error()))) + + return + } + + fmt.Println(resp) + + // var writer bytes.Buffer + // err = quick.Highlight(&writer, resp, "markdown", appConfig.ChromaFormatter, appConfig.ChromaStyle) + // if err != nil { + // client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error()))) + + // return + // } + + // fmt.Println(writer.String()) + // client.Cmd.ReplyTo(event, girc.Fmt("\033[0m"+writer.String())) + } + }) + } ircChan <- irc |