Add ability for the client to configure the model via an environment variable

This commit is contained in:
David Dworken 2024-08-11 12:15:44 -07:00
parent 820e3b8567
commit f64f97095f
No known key found for this signature in database
4 changed files with 22 additions and 12 deletions

View File

@ -332,7 +332,7 @@ func (s *Server) aiSuggestionHandler(w http.ResponseWriter, r *http.Request) {
if numDevices == 0 {
panic(fmt.Errorf("rejecting OpenAI request for user_id=%#v since it does not exist", req.UserId))
}
suggestions, usage, err := ai.GetAiSuggestionsViaOpenAiApi(ai.DefaultOpenAiEndpoint, req.Query, req.ShellName, req.OsName, req.NumberCompletions)
suggestions, usage, err := ai.GetAiSuggestionsViaOpenAiApi(ai.DefaultOpenAiEndpoint, req.Query, req.ShellName, req.OsName, req.Model, req.NumberCompletions)
if err != nil {
panic(fmt.Errorf("failed to query OpenAI API: %w", err))
}

View File

@ -30,7 +30,7 @@ func GetAiSuggestions(ctx context.Context, shellName, query string, numberComple
if os.Getenv("OPENAI_API_KEY") == "" && hctx.GetConf(ctx).AiCompletionEndpoint == ai.DefaultOpenAiEndpoint {
return GetAiSuggestionsViaHishtoryApi(ctx, shellName, query, numberCompletions)
} else {
suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(hctx.GetConf(ctx).AiCompletionEndpoint, query, shellName, getOsName(), numberCompletions)
suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(hctx.GetConf(ctx).AiCompletionEndpoint, query, shellName, getOsName(), os.Getenv("OPENAI_API_MODEL"), numberCompletions)
return suggestions, err
}
}
@ -64,6 +64,7 @@ func GetAiSuggestionsViaHishtoryApi(ctx context.Context, shellName, query string
NumberCompletions: numberCompletions,
OsName: getOsName(),
ShellName: shellName,
Model: os.Getenv("OPENAI_API_MODEL"),
}
reqData, err := json.Marshal(req)
if err != nil {

View File

@ -54,7 +54,7 @@ type TestOnlyOverrideAiSuggestionRequest struct {
var TestOnlyOverrideAiSuggestions map[string][]string = make(map[string][]string)
func GetAiSuggestionsViaOpenAiApi(apiEndpoint, query, shellName, osName string, numberCompletions int) ([]string, OpenAiUsage, error) {
func GetAiSuggestionsViaOpenAiApi(apiEndpoint, query, shellName, osName, overriddenOpenAiModel string, numberCompletions int) ([]string, OpenAiUsage, error) {
if results := TestOnlyOverrideAiSuggestions[query]; len(results) > 0 {
return results, OpenAiUsage{}, nil
}
@ -63,7 +63,7 @@ func GetAiSuggestionsViaOpenAiApi(apiEndpoint, query, shellName, osName string,
if apiKey == "" && apiEndpoint == DefaultOpenAiEndpoint {
return nil, OpenAiUsage{}, fmt.Errorf("OPENAI_API_KEY environment variable is not set")
}
apiReqStr, err := json.Marshal(createOpenAiRequest(query, shellName, osName, numberCompletions))
apiReqStr, err := json.Marshal(createOpenAiRequest(query, shellName, osName, overriddenOpenAiModel, numberCompletions))
if err != nil {
return nil, OpenAiUsage{}, fmt.Errorf("failed to serialize JSON for OpenAI API: %w", err)
}
@ -112,13 +112,14 @@ type AiSuggestionRequest struct {
NumberCompletions int `json:"number_completions"`
ShellName string `json:"shell_name"`
OsName string `json:"os_name"`
Model string `json:"model"`
}
type AiSuggestionResponse struct {
Suggestions []string `json:"suggestions"`
}
func createOpenAiRequest(query, shellName, osName string, numberCompletions int) openAiRequest {
func createOpenAiRequest(query, shellName, osName, overriddenOpenAiModel string, numberCompletions int) openAiRequest {
if osName == "" {
osName = "Linux"
}
@ -126,11 +127,14 @@ func createOpenAiRequest(query, shellName, osName string, numberCompletions int)
shellName = "bash"
}
model := os.Getenv("OPENAI_API_MODEL")
if model == "" {
// According to https://platform.openai.com/docs/models gpt-4o-mini is the best model
// by performance/price ratio.
model = "gpt-4o-mini"
// According to https://platform.openai.com/docs/models gpt-4o-mini is the best model
// by performance/price ratio.
model := "gpt-4o-mini"
if envModel := os.Getenv("OPENAI_API_MODEL"); envModel != "" {
model = envModel
}
if overriddenOpenAiModel != "" {
model = overriddenOpenAiModel
}
if envNumberCompletions := os.Getenv("OPENAI_API_NUMBER_COMPLETIONS"); envNumberCompletions != "" {

View File

@ -5,15 +5,20 @@ import (
"strings"
"testing"
"github.com/ddworken/hishtory/shared/testutils"
"github.com/stretchr/testify/require"
)
// A basic sanity test that our integration with the OpenAI API is correct and is returning reasonable results (at least for a very basic query)
func TestLiveOpenAiApi(t *testing.T) {
if os.Getenv("OPENAI_API_KEY") == "" {
t.Skip("Skipping test since OPENAI_API_KEY is not set")
if testutils.IsGithubAction() {
t.Fatal("OPENAI_API_KEY is not set, cannot run TestLiveOpenAiApi")
} else {
t.Skip("Skipping test since OPENAI_API_KEY is not set")
}
}
results, _, err := GetAiSuggestionsViaOpenAiApi("https://api.openai.com/v1/chat/completions", "list files in the current directory", "bash", "Linux", 3)
results, _, err := GetAiSuggestionsViaOpenAiApi("https://api.openai.com/v1/chat/completions", "list files in the current directory", "bash", "Linux", "", 3)
require.NoError(t, err)
resultsContainsLs := false
for _, result := range results {