aboutsummaryrefslogtreecommitdiffstats
path: root/main.go
diff options
context:
space:
mode:
authorterminaldweller <devi@terminaldweller.com>2024-02-12 16:30:55 +0000
committerterminaldweller <devi@terminaldweller.com>2024-02-12 16:30:55 +0000
commitb5a177a235d3b5834c81a2e97d3cae7e90584a5f (patch)
treec77248b21ecc4666813bd820f01d55a91a660924 /main.go
parentchannel var can now accept a list (diff)
downloadmilla-b5a177a235d3b5834c81a2e97d3cae7e90584a5f.tar.gz
milla-b5a177a235d3b5834c81a2e97d3cae7e90584a5f.zip
gemini support, [wip]
Diffstat (limited to 'main.go')
-rw-r--r--main.go147
1 files changed, 96 insertions, 51 deletions
diff --git a/main.go b/main.go
index 49c1d00..d7a8fb8 100644
--- a/main.go
+++ b/main.go
@@ -13,8 +13,10 @@ import (
"time"
"github.com/alecthomas/chroma/v2/quick"
+ "github.com/google/generative-ai-go/genai"
"github.com/lrstanley/girc"
"github.com/pelletier/go-toml/v2"
+ "google.golang.org/api/option"
)
type TomlConfig struct {
@@ -33,6 +35,8 @@ type TomlConfig struct {
Model string
ChromaStyle string
ChromaFormatter string
+ Provider string
+ Apikey string
}
type OllamaResponse struct {
@@ -80,72 +84,113 @@ func runIRC(appConfig TomlConfig, ircChan chan *girc.Client) {
}
})
- irc.Handlers.AddBg(girc.PRIVMSG, func(client *girc.Client, event girc.Event) {
- if strings.HasPrefix(event.Last(), appConfig.IrcNick+": ") {
- prompt := strings.TrimPrefix(event.Last(), appConfig.IrcNick+": ")
- log.Println(prompt)
-
- ollamaRequest := OllamaRequest{
- Model: appConfig.Model,
- System: appConfig.OllamaSystem,
- Prompt: prompt,
- Stream: false,
- Format: "json",
- Options: OllamaRequestOptions{
- Temperature: appConfig.OllamaTemp,
- },
- }
+ if appConfig.Provider == "ollama" {
+ irc.Handlers.AddBg(girc.PRIVMSG, func(client *girc.Client, event girc.Event) {
+ if strings.HasPrefix(event.Last(), appConfig.IrcNick+": ") {
+ prompt := strings.TrimPrefix(event.Last(), appConfig.IrcNick+": ")
+ log.Println(prompt)
- jsonPayload, err := json.Marshal(ollamaRequest)
- if err != nil {
- client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error())))
+ ollamaRequest := OllamaRequest{
+ Model: appConfig.Model,
+ System: appConfig.OllamaSystem,
+ Prompt: prompt,
+ Stream: false,
+ Format: "json",
+ Options: OllamaRequestOptions{
+ Temperature: appConfig.OllamaTemp,
+ },
+ }
- return
- }
+ jsonPayload, err := json.Marshal(ollamaRequest)
+ if err != nil {
+ client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error())))
- ctx, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second)
- defer cancel()
+ return
+ }
- request, err := http.NewRequest(http.MethodPost, appConfig.OllamaEndpoint, bytes.NewBuffer(jsonPayload))
- request = request.WithContext(ctx)
- if err != nil {
- client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error())))
+ ctx, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second)
+ defer cancel()
- return
- }
+ request, err := http.NewRequest(http.MethodPost, appConfig.OllamaEndpoint, bytes.NewBuffer(jsonPayload))
+ request = request.WithContext(ctx)
+ if err != nil {
+ client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error())))
- request.Header.Set("Content-Type", "application/json")
+ return
+ }
- httpClient := http.Client{}
+ request.Header.Set("Content-Type", "application/json")
- response, err := httpClient.Do(request)
- if err != nil {
- client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error())))
+ httpClient := http.Client{}
- return
- }
- defer response.Body.Close()
+ response, err := httpClient.Do(request)
+ if err != nil {
+ client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error())))
- var ollamaResponse OllamaResponse
- err = json.NewDecoder(response.Body).Decode(&ollamaResponse)
- if err != nil {
- client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error())))
+ return
+ }
+ defer response.Body.Close()
- return
- }
+ var ollamaResponse OllamaResponse
+ err = json.NewDecoder(response.Body).Decode(&ollamaResponse)
+ if err != nil {
+ client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error())))
- var writer bytes.Buffer
- err = quick.Highlight(&writer, ollamaResponse.Response, "markdown", appConfig.ChromaFormatter, appConfig.ChromaStyle)
- if err != nil {
- client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error())))
+ return
+ }
- return
+ var writer bytes.Buffer
+ err = quick.Highlight(&writer, ollamaResponse.Response, "markdown", appConfig.ChromaFormatter, appConfig.ChromaStyle)
+ if err != nil {
+ client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error())))
+
+ return
+ }
+
+ fmt.Println(writer.String())
+ client.Cmd.ReplyTo(event, girc.Fmt("\033[0m"+writer.String()))
}
+ })
+ } else if appConfig.Provider == "gemini" {
+ irc.Handlers.AddBg(girc.PRIVMSG, func(client *girc.Client, event girc.Event) {
+ if strings.HasPrefix(event.Last(), appConfig.IrcNick+": ") {
+ prompt := strings.TrimPrefix(event.Last(), appConfig.IrcNick+": ")
+ log.Println(prompt)
- fmt.Println(writer.String())
- client.Cmd.ReplyTo(event, girc.Fmt("\033[0m"+writer.String()))
- }
- })
+ ctx, cancel := context.WithTimeout(context.Background(), time.Duration(appConfig.RequestTimeout)*time.Second)
+ defer cancel()
+
+ client_gemini, err := genai.NewClient(ctx, option.WithAPIKey(appConfig.Apikey))
+ if err != nil {
+ client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error())))
+
+ return
+ }
+ defer client_gemini.Close()
+
+ model := client_gemini.GenerativeModel(appConfig.Model)
+ resp, err := model.GenerateContent(ctx, genai.Text(prompt))
+ if err != nil {
+ client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error())))
+
+ return
+ }
+
+ fmt.Println(resp)
+
+ // var writer bytes.Buffer
+ // err = quick.Highlight(&writer, resp, "markdown", appConfig.ChromaFormatter, appConfig.ChromaStyle)
+ // if err != nil {
+ // client.Cmd.ReplyTo(event, girc.Fmt(fmt.Sprintf("error: %s", err.Error())))
+
+ // return
+ // }
+
+ // fmt.Println(writer.String())
+ // client.Cmd.ReplyTo(event, girc.Fmt("\033[0m"+writer.String()))
+ }
+ })
+ }
ircChan <- irc