From 21b401bc1484fce756302fdaf13e62229f98c70f Mon Sep 17 00:00:00 2001
From: David Dworken
Date: Tue, 26 Mar 2024 22:13:57 -0700
Subject: [PATCH] Add ability to configure custom OpenAI API endpoint for #186 (#194)

* Add ability to configure custom OpenAI API endpoint for #186

* Ensure the AiCompletionEndpoint field is always initialized
---
 backend/server/internal/server/api_handlers.go |  2 +-
 client/ai/ai.go                                |  4 ++--
 client/client_test.go                          | 12 +++++++++++-
 client/cmd/configGet.go                        | 11 +++++++++++
 client/cmd/configSet.go                        | 13 +++++++++++++
 client/hctx/hctx.go                            |  5 +++++
 scripts/aimain.go                              | 17 -----------------
 shared/ai/ai.go                                | 12 ++++++++----
 shared/ai/ai_test.go                           |  2 +-
 9 files changed, 52 insertions(+), 26 deletions(-)
 delete mode 100644 scripts/aimain.go

diff --git a/backend/server/internal/server/api_handlers.go b/backend/server/internal/server/api_handlers.go
index a8ee141..2500b3d 100644
--- a/backend/server/internal/server/api_handlers.go
+++ b/backend/server/internal/server/api_handlers.go
@@ -331,7 +331,7 @@ func (s *Server) aiSuggestionHandler(w http.ResponseWriter, r *http.Request) {
 	if numDevices == 0 {
 		panic(fmt.Errorf("rejecting OpenAI request for user_id=%#v since it does not exist", req.UserId))
 	}
-	suggestions, usage, err := ai.GetAiSuggestionsViaOpenAiApi(req.Query, req.ShellName, req.OsName, req.NumberCompletions)
+	suggestions, usage, err := ai.GetAiSuggestionsViaOpenAiApi(ai.DefaultOpenAiEndpoint, req.Query, req.ShellName, req.OsName, req.NumberCompletions)
 	if err != nil {
 		panic(fmt.Errorf("failed to query OpenAI API: %w", err))
 	}
diff --git a/client/ai/ai.go b/client/ai/ai.go
index 5078cd9..f32c1e4 100644
--- a/client/ai/ai.go
+++ b/client/ai/ai.go
@@ -27,10 +27,10 @@ func DebouncedGetAiSuggestions(ctx context.Context, shellName, query string, num
 }
 
 func GetAiSuggestions(ctx context.Context, shellName, query string, numberCompletions int) ([]string, error) {
-	if os.Getenv("OPENAI_API_KEY") == "" {
+	if os.Getenv("OPENAI_API_KEY") == "" && hctx.GetConf(ctx).AiCompletionEndpoint == ai.DefaultOpenAiEndpoint {
 		return GetAiSuggestionsViaHishtoryApi(ctx, shellName, query, numberCompletions)
 	} else {
-		suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(query, shellName, getOsName(), numberCompletions)
+		suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(hctx.GetConf(ctx).AiCompletionEndpoint, query, shellName, getOsName(), numberCompletions)
 		return suggestions, err
 	}
 }
diff --git a/client/client_test.go b/client/client_test.go
index d58d471..eb0edbb 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -1572,6 +1572,17 @@ func testConfigGetSet(t *testing.T, tester shellTester) {
 	if out != "Command \"Exit Code\" Timestamp foobar \n" {
 		t.Fatalf("unexpected config-get output: %#v", out)
 	}
+
+	// For OpenAI endpoints
+	out = tester.RunInteractiveShell(t, `hishtory config-get ai-completion-endpoint`)
+	if out != "https://api.openai.com/v1/chat/completions\n" {
+		t.Fatalf("unexpected config-get output: %#v", out)
+	}
+	tester.RunInteractiveShell(t, `hishtory config-set ai-completion-endpoint https://example.com/foo/bar`)
+	out = tester.RunInteractiveShell(t, `hishtory config-get ai-completion-endpoint`)
+	if out != "https://example.com/foo/bar\n" {
+		t.Fatalf("unexpected config-get output: %#v", out)
+	}
 }
 
 func clearControlRSearchFromConfig(t testing.TB) {
@@ -2166,7 +2177,6 @@ func testTui_ai(t *testing.T) {
 	})
 	out = stripTuiCommandPrefix(t, out)
 	testutils.CompareGoldens(t, out, "TestTui-AiQuery-Disabled")
-
 }
 
 func testControlR(t *testing.T, tester shellTester, shellName string, onlineStatus OnlineStatus) {
diff --git a/client/cmd/configGet.go b/client/cmd/configGet.go
index 41ea858..8a08b14 100644
--- a/client/cmd/configGet.go
+++ b/client/cmd/configGet.go
@@ -146,6 +146,16 @@ var getColorScheme = &cobra.Command{
 	},
 }
 
+var getAiCompletionEndpoint = &cobra.Command{
+	Use:   "ai-completion-endpoint",
+	Short: "The AI endpoint to use for AI completions",
+	Run: func(cmd *cobra.Command, args []string) {
+		ctx := hctx.MakeContext()
+		config := hctx.GetConf(ctx)
+		fmt.Println(config.AiCompletionEndpoint)
+	},
+}
+
 func init() {
 	rootCmd.AddCommand(configGetCmd)
 	configGetCmd.AddCommand(getEnableControlRCmd)
@@ -159,4 +169,5 @@ func init() {
 	configGetCmd.AddCommand(getPresavingCmd)
 	configGetCmd.AddCommand(getColorScheme)
 	configGetCmd.AddCommand(getDefaultFilterCmd)
+	configGetCmd.AddCommand(getAiCompletionEndpoint)
 }
diff --git a/client/cmd/configSet.go b/client/cmd/configSet.go
index 27e688b..03cc9af 100644
--- a/client/cmd/configSet.go
+++ b/client/cmd/configSet.go
@@ -217,6 +217,18 @@ func validateColor(color string) error {
 	return nil
 }
 
+var setAiCompletionEndpoint = &cobra.Command{
+	Use:   "ai-completion-endpoint",
+	Short: "The AI endpoint to use for AI completions",
+	Args:  cobra.ExactArgs(1),
+	Run: func(cmd *cobra.Command, args []string) {
+		ctx := hctx.MakeContext()
+		config := hctx.GetConf(ctx)
+		config.AiCompletionEndpoint = args[0]
+		lib.CheckFatalError(hctx.SetConfig(config))
+	},
+}
+
 func init() {
 	rootCmd.AddCommand(configSetCmd)
 	configSetCmd.AddCommand(setEnableControlRCmd)
@@ -229,6 +241,7 @@ func init() {
 	configSetCmd.AddCommand(setPresavingCmd)
 	configSetCmd.AddCommand(setColorSchemeCmd)
 	configSetCmd.AddCommand(setDefaultFilterCommand)
+	configSetCmd.AddCommand(setAiCompletionEndpoint)
 	setColorSchemeCmd.AddCommand(setColorSchemeSelectedText)
 	setColorSchemeCmd.AddCommand(setColorSchemeSelectedBackground)
 	setColorSchemeCmd.AddCommand(setColorSchemeBorderColor)
diff --git a/client/hctx/hctx.go b/client/hctx/hctx.go
index 316a03b..9757c5b 100644
--- a/client/hctx/hctx.go
+++ b/client/hctx/hctx.go
@@ -205,6 +205,8 @@ type ClientConfig struct {
 	ColorScheme ColorScheme `json:"color_scheme"`
 	// A default filter that will be applied to all search queries
 	DefaultFilter string `json:"default_filter"`
+	// The endpoint to use for AI suggestions
+	AiCompletionEndpoint string `json:"ai_completion_endpoint"`
 }
 
 type ColorScheme struct {
@@ -272,6 +274,9 @@ func GetConfig() (ClientConfig, error) {
 	if config.ColorScheme.BorderColor == "" {
 		config.ColorScheme.BorderColor = GetDefaultColorScheme().BorderColor
 	}
+	if config.AiCompletionEndpoint == "" {
+		config.AiCompletionEndpoint = "https://api.openai.com/v1/chat/completions"
+	}
 	return config, nil
 }
 
diff --git a/scripts/aimain.go b/scripts/aimain.go
deleted file mode 100644
index 7fd11c5..0000000
--- a/scripts/aimain.go
+++ /dev/null
@@ -1,17 +0,0 @@
-package main
-
-import (
-	"fmt"
-	"log"
-	"strings"
-
-	"github.com/ddworken/hishtory/shared/ai"
-)
-
-func main() {
-	resp, _, err := ai.GetAiSuggestionsViaOpenAiApi("Find all CSV files in the current directory or subdirectories and select the first column, then prepend `foo` to each line", "bash", "MacOS", 3)
-	if err != nil {
-		log.Fatal(err)
-	}
-	fmt.Println(strings.Join(resp, "\n"))
-}
diff --git a/shared/ai/ai.go b/shared/ai/ai.go
index c6093ef..926b311 100644
--- a/shared/ai/ai.go
+++ b/shared/ai/ai.go
@@ -12,6 +12,8 @@ import (
 	"golang.org/x/exp/slices"
 )
 
+const DefaultOpenAiEndpoint = "https://api.openai.com/v1/chat/completions"
+
 type openAiRequest struct {
 	Model    string          `json:"model"`
 	Messages []openAiMessage `json:"messages"`
@@ -51,7 +53,7 @@ type TestOnlyOverrideAiSuggestionRequest struct {
 
 var TestOnlyOverrideAiSuggestions map[string][]string = make(map[string][]string)
 
-func GetAiSuggestionsViaOpenAiApi(query, shellName, osName string, numberCompletions int) ([]string, OpenAiUsage, error) {
+func GetAiSuggestionsViaOpenAiApi(apiEndpoint, query, shellName, osName string, numberCompletions int) ([]string, OpenAiUsage, error) {
 	if results := TestOnlyOverrideAiSuggestions[query]; len(results) > 0 {
 		return results, OpenAiUsage{}, nil
 	}
@@ -63,7 +65,7 @@ func GetAiSuggestionsViaOpenAiApi(query, shellName, osName string, numberComplet
 		shellName = "bash"
 	}
 	apiKey := os.Getenv("OPENAI_API_KEY")
-	if apiKey == "" {
+	if apiKey == "" && apiEndpoint == DefaultOpenAiEndpoint {
 		return nil, OpenAiUsage{}, fmt.Errorf("OPENAI_API_KEY environment variable is not set")
 	}
 	client := &http.Client{}
@@ -82,12 +84,14 @@ func GetAiSuggestionsViaOpenAiApi(query, shellName, osName string, numberComplet
 	if err != nil {
 		return nil, OpenAiUsage{}, fmt.Errorf("failed to serialize JSON for OpenAI API: %w", err)
 	}
-	req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", bytes.NewBuffer(apiReqStr))
+	req, err := http.NewRequest("POST", apiEndpoint, bytes.NewBuffer(apiReqStr))
 	if err != nil {
 		return nil, OpenAiUsage{}, fmt.Errorf("failed to create OpenAI API request: %w", err)
 	}
 	req.Header.Set("Content-Type", "application/json")
-	req.Header.Set("Authorization", "Bearer "+apiKey)
+	if apiKey != "" {
+		req.Header.Set("Authorization", "Bearer "+apiKey)
+	}
 	resp, err := client.Do(req)
 	if err != nil {
 		return nil, OpenAiUsage{}, fmt.Errorf("failed to query OpenAI API: %w", err)
 	}
diff --git a/shared/ai/ai_test.go b/shared/ai/ai_test.go
index 8454fd0..7551f33 100644
--- a/shared/ai/ai_test.go
+++ b/shared/ai/ai_test.go
@@ -13,7 +13,7 @@ func TestLiveOpenAiApi(t *testing.T) {
 	if os.Getenv("OPENAI_API_KEY") == "" {
 		t.Skip("Skipping test since OPENAI_API_KEY is not set")
 	}
-	results, _, err := GetAiSuggestionsViaOpenAiApi("list files in the current directory", "bash", "Linux", 3)
+	results, _, err := GetAiSuggestionsViaOpenAiApi("https://api.openai.com/v1/chat/completions", "list files in the current directory", "bash", "Linux", 3)
 	require.NoError(t, err)
 	resultsContainsLs := false
 	for _, result := range results {
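
For illustration, a minimal sketch (not part of the patch) of driving the updated shared/ai API against a custom, OpenAI-compatible endpoint: essentially the deleted scripts/aimain.go adapted to the new apiEndpoint parameter. The endpoint URL below is a placeholder for a self-hosted server; as the patch shows, the Authorization header is only attached when OPENAI_API_KEY is set.

// Not part of the patch: a sketch of calling the updated shared/ai API
// with a custom endpoint. The URL is a placeholder, not a real server.
package main

import (
	"fmt"
	"log"
	"strings"

	"github.com/ddworken/hishtory/shared/ai"
)

func main() {
	// With a non-default endpoint, OPENAI_API_KEY may be left unset; the
	// Authorization header is only sent when the key is present.
	endpoint := "https://llm.internal.example.com/v1/chat/completions" // placeholder endpoint
	suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(endpoint, "list files in the current directory", "bash", "Linux", 3)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(strings.Join(suggestions, "\n"))
}

In normal use a client would instead set the endpoint once with "hishtory config-set ai-completion-endpoint <url>" and inspect it with "hishtory config-get ai-completion-endpoint", which is exactly what the new client_test.go assertions exercise.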