hishtory/client/ai/ai.go
David Dworken 21b401bc14
Add ability to configure custom OpenAI API endpoint for #186 (#194)
* Add ability to configure custom OpenAI API endpoint for #186

* Ensure the AiCompletionEndpoint field is always initialized
2024-03-26 22:13:57 -07:00

84 lines
2.5 KiB
Go

package ai
import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"os/exec"
	"runtime"
	"sync"
	"time"

	"github.com/ddworken/hishtory/client/data"
	"github.com/ddworken/hishtory/client/hctx"
	"github.com/ddworken/hishtory/client/lib"
	"github.com/ddworken/hishtory/shared/ai"
)
var mostRecentQuery string
func DebouncedGetAiSuggestions(ctx context.Context, shellName, query string, numberCompletions int) ([]string, error) {
mostRecentQuery = query
time.Sleep(time.Millisecond * 300)
if mostRecentQuery == query {
return GetAiSuggestions(ctx, shellName, query, numberCompletions)
}
return nil, nil
}
// GetAiSuggestions returns up to numberCompletions suggestions for query.
//
// When the OPENAI_API_KEY environment variable is empty and the configured
// AiCompletionEndpoint equals ai.DefaultOpenAiEndpoint, the request is sent
// through GetAiSuggestionsViaHishtoryApi. In every other case the configured
// endpoint is queried via ai.GetAiSuggestionsViaOpenAiApi, whose second
// return value is discarded.
func GetAiSuggestions(ctx context.Context, shellName, query string, numberCompletions int) ([]string, error) {
	endpoint := hctx.GetConf(ctx).AiCompletionEndpoint
	useHishtoryApi := os.Getenv("OPENAI_API_KEY") == "" && endpoint == ai.DefaultOpenAiEndpoint
	if useHishtoryApi {
		return GetAiSuggestionsViaHishtoryApi(ctx, shellName, query, numberCompletions)
	}
	suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(endpoint, query, shellName, getOsName(), numberCompletions)
	return suggestions, err
}
// getOsName returns a human-readable name for the operating system.
//
// On darwin it returns "MacOS". On linux it guesses the distribution from the
// first package manager found on PATH, checked in the order apt-get, dnf,
// pacman; any system with apt-get is therefore reported as "Ubuntu Linux".
// If none is found it returns "Linux". Every other platform is reported as
// its raw runtime.GOOS value.
func getOsName() string {
	if runtime.GOOS == "darwin" {
		return "MacOS"
	}
	if runtime.GOOS != "linux" {
		return runtime.GOOS
	}

	// Order matters: the first package manager that resolves wins.
	distros := []struct{ packageManager, osName string }{
		{"apt-get", "Ubuntu Linux"},
		{"dnf", "Fedora Linux"},
		{"pacman", "Arch Linux"},
	}
	for _, d := range distros {
		if _, err := exec.LookPath(d.packageManager); err == nil {
			return d.osName
		}
	}
	return "Linux"
}
// GetAiSuggestionsViaHishtoryApi requests suggestions for query by POSTing a
// JSON-encoded ai.AiSuggestionRequest to /api/v1/ai-suggest via lib.ApiPost.
//
// The request carries the device ID and the user ID derived from the
// configured user secret, along with the query, the requested number of
// completions, the OS name from getOsName, and shellName. The query and the
// resulting suggestions are both logged at info level.
//
// It returns the Suggestions field of the decoded ai.AiSuggestionResponse, or
// a wrapped error if marshalling the request, the POST, or decoding the
// response fails.
func GetAiSuggestionsViaHishtoryApi(ctx context.Context, shellName, query string, numberCompletions int) ([]string, error) {
	hctx.GetLogger().Infof("Running OpenAI query for %#v", query)

	conf := hctx.GetConf(ctx)
	payload, err := json.Marshal(ai.AiSuggestionRequest{
		DeviceId:          conf.DeviceId,
		UserId:            data.UserId(conf.UserSecret),
		Query:             query,
		NumberCompletions: numberCompletions,
		OsName:            getOsName(),
		ShellName:         shellName,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal AiSuggestionRequest: %w", err)
	}

	body, err := lib.ApiPost(ctx, "/api/v1/ai-suggest", "application/json", payload)
	if err != nil {
		return nil, fmt.Errorf("failed to query /api/v1/ai-suggest: %w", err)
	}

	var parsed ai.AiSuggestionResponse
	if err := json.Unmarshal(body, &parsed); err != nil {
		return nil, fmt.Errorf("failed to parse /api/v1/ai-suggest response: %w", err)
	}

	hctx.GetLogger().Infof("For OpenAI query=%#v ==> %#v", query, parsed.Suggestions)
	return parsed.Suggestions, nil
}