diff --git a/backend/server/internal/server/api_handlers.go b/backend/server/internal/server/api_handlers.go index d815e57..eef981b 100644 --- a/backend/server/internal/server/api_handlers.go +++ b/backend/server/internal/server/api_handlers.go @@ -332,7 +332,7 @@ func (s *Server) aiSuggestionHandler(w http.ResponseWriter, r *http.Request) { if numDevices == 0 { panic(fmt.Errorf("rejecting OpenAI request for user_id=%#v since it does not exist", req.UserId)) } - suggestions, usage, err := ai.GetAiSuggestionsViaOpenAiApi(ai.DefaultOpenAiEndpoint, req.Query, req.ShellName, req.OsName, req.NumberCompletions) + suggestions, usage, err := ai.GetAiSuggestionsViaOpenAiApi(ai.DefaultOpenAiEndpoint, req.Query, req.ShellName, req.OsName, req.Model, req.NumberCompletions) if err != nil { panic(fmt.Errorf("failed to query OpenAI API: %w", err)) } diff --git a/client/ai/ai.go b/client/ai/ai.go index 67acc74..bf180ef 100644 --- a/client/ai/ai.go +++ b/client/ai/ai.go @@ -30,7 +30,7 @@ func GetAiSuggestions(ctx context.Context, shellName, query string, numberComple if os.Getenv("OPENAI_API_KEY") == "" && hctx.GetConf(ctx).AiCompletionEndpoint == ai.DefaultOpenAiEndpoint { return GetAiSuggestionsViaHishtoryApi(ctx, shellName, query, numberCompletions) } else { - suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(hctx.GetConf(ctx).AiCompletionEndpoint, query, shellName, getOsName(), numberCompletions) + suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(hctx.GetConf(ctx).AiCompletionEndpoint, query, shellName, getOsName(), os.Getenv("OPENAI_API_MODEL"), numberCompletions) return suggestions, err } } @@ -64,6 +64,7 @@ func GetAiSuggestionsViaHishtoryApi(ctx context.Context, shellName, query string NumberCompletions: numberCompletions, OsName: getOsName(), ShellName: shellName, + Model: os.Getenv("OPENAI_API_MODEL"), } reqData, err := json.Marshal(req) if err != nil { diff --git a/shared/ai/ai.go b/shared/ai/ai.go index 4bdb52d..c226bc0 100644 --- a/shared/ai/ai.go +++ b/shared/ai/ai.go @@ -54,7 +54,7 @@ type TestOnlyOverrideAiSuggestionRequest struct { var TestOnlyOverrideAiSuggestions map[string][]string = make(map[string][]string) -func GetAiSuggestionsViaOpenAiApi(apiEndpoint, query, shellName, osName string, numberCompletions int) ([]string, OpenAiUsage, error) { +func GetAiSuggestionsViaOpenAiApi(apiEndpoint, query, shellName, osName, overriddenOpenAiModel string, numberCompletions int) ([]string, OpenAiUsage, error) { if results := TestOnlyOverrideAiSuggestions[query]; len(results) > 0 { return results, OpenAiUsage{}, nil } @@ -63,7 +63,7 @@ func GetAiSuggestionsViaOpenAiApi(apiEndpoint, query, shellName, osName string, if apiKey == "" && apiEndpoint == DefaultOpenAiEndpoint { return nil, OpenAiUsage{}, fmt.Errorf("OPENAI_API_KEY environment variable is not set") } - apiReqStr, err := json.Marshal(createOpenAiRequest(query, shellName, osName, numberCompletions)) + apiReqStr, err := json.Marshal(createOpenAiRequest(query, shellName, osName, overriddenOpenAiModel, numberCompletions)) if err != nil { return nil, OpenAiUsage{}, fmt.Errorf("failed to serialize JSON for OpenAI API: %w", err) } @@ -112,13 +112,14 @@ type AiSuggestionRequest struct { NumberCompletions int `json:"number_completions"` ShellName string `json:"shell_name"` OsName string `json:"os_name"` + Model string `json:"model"` } type AiSuggestionResponse struct { Suggestions []string `json:"suggestions"` } -func createOpenAiRequest(query, shellName, osName string, numberCompletions int) openAiRequest { +func createOpenAiRequest(query, shellName, osName, overriddenOpenAiModel string, numberCompletions int) openAiRequest { if osName == "" { osName = "Linux" } @@ -126,11 +127,14 @@ func createOpenAiRequest(query, shellName, osName string, numberCompletions int) shellName = "bash" } - model := os.Getenv("OPENAI_API_MODEL") - if model == "" { - // According to https://platform.openai.com/docs/models gpt-4o-mini is the best model - // by performance/price ratio. - model = "gpt-4o-mini" + // According to https://platform.openai.com/docs/models gpt-4o-mini is the best model + // by performance/price ratio. + model := "gpt-4o-mini" + if envModel := os.Getenv("OPENAI_API_MODEL"); envModel != "" { + model = envModel + } + if overriddenOpenAiModel != "" { + model = overriddenOpenAiModel } if envNumberCompletions := os.Getenv("OPENAI_API_NUMBER_COMPLETIONS"); envNumberCompletions != "" { diff --git a/shared/ai/ai_test.go b/shared/ai/ai_test.go index 7551f33..509a669 100644 --- a/shared/ai/ai_test.go +++ b/shared/ai/ai_test.go @@ -5,15 +5,20 @@ import ( "strings" "testing" + "github.com/ddworken/hishtory/shared/testutils" "github.com/stretchr/testify/require" ) // A basic sanity test that our integration with the OpenAI API is correct and is returning reasonable results (at least for a very basic query) func TestLiveOpenAiApi(t *testing.T) { if os.Getenv("OPENAI_API_KEY") == "" { - t.Skip("Skipping test since OPENAI_API_KEY is not set") + if testutils.IsGithubAction() { + t.Fatal("OPENAI_API_KEY is not set, cannot run TestLiveOpenAiApi") + } else { + t.Skip("Skipping test since OPENAI_API_KEY is not set") + } } - results, _, err := GetAiSuggestionsViaOpenAiApi("https://api.openai.com/v1/chat/completions", "list files in the current directory", "bash", "Linux", 3) + results, _, err := GetAiSuggestionsViaOpenAiApi("https://api.openai.com/v1/chat/completions", "list files in the current directory", "bash", "Linux", "", 3) require.NoError(t, err) resultsContainsLs := false for _, result := range results {