Add ability to configure custom OpenAI API endpoint for #186 (#194)

* Add ability to configure custom OpenAI API endpoint for #186

* Ensure the AiCompletionEndpoint field is always initialized
This commit is contained in:
David Dworken 2024-03-26 22:13:57 -07:00 committed by GitHub
parent 46e92803be
commit 21b401bc14
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 52 additions and 26 deletions

View File

@ -331,7 +331,7 @@ func (s *Server) aiSuggestionHandler(w http.ResponseWriter, r *http.Request) {
if numDevices == 0 { if numDevices == 0 {
panic(fmt.Errorf("rejecting OpenAI request for user_id=%#v since it does not exist", req.UserId)) panic(fmt.Errorf("rejecting OpenAI request for user_id=%#v since it does not exist", req.UserId))
} }
suggestions, usage, err := ai.GetAiSuggestionsViaOpenAiApi(req.Query, req.ShellName, req.OsName, req.NumberCompletions) suggestions, usage, err := ai.GetAiSuggestionsViaOpenAiApi(ai.DefaultOpenAiEndpoint, req.Query, req.ShellName, req.OsName, req.NumberCompletions)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to query OpenAI API: %w", err)) panic(fmt.Errorf("failed to query OpenAI API: %w", err))
} }

View File

@ -27,10 +27,10 @@ func DebouncedGetAiSuggestions(ctx context.Context, shellName, query string, num
} }
func GetAiSuggestions(ctx context.Context, shellName, query string, numberCompletions int) ([]string, error) { func GetAiSuggestions(ctx context.Context, shellName, query string, numberCompletions int) ([]string, error) {
if os.Getenv("OPENAI_API_KEY") == "" { if os.Getenv("OPENAI_API_KEY") == "" && hctx.GetConf(ctx).AiCompletionEndpoint == ai.DefaultOpenAiEndpoint {
return GetAiSuggestionsViaHishtoryApi(ctx, shellName, query, numberCompletions) return GetAiSuggestionsViaHishtoryApi(ctx, shellName, query, numberCompletions)
} else { } else {
suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(query, shellName, getOsName(), numberCompletions) suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(hctx.GetConf(ctx).AiCompletionEndpoint, query, shellName, getOsName(), numberCompletions)
return suggestions, err return suggestions, err
} }
} }

View File

@ -1572,6 +1572,17 @@ func testConfigGetSet(t *testing.T, tester shellTester) {
if out != "Command \"Exit Code\" Timestamp foobar \n" { if out != "Command \"Exit Code\" Timestamp foobar \n" {
t.Fatalf("unexpected config-get output: %#v", out) t.Fatalf("unexpected config-get output: %#v", out)
} }
// For OpenAI endpoints
out = tester.RunInteractiveShell(t, `hishtory config-get ai-completion-endpoint`)
if out != "https://api.openai.com/v1/chat/completions\n" {
t.Fatalf("unexpected config-get output: %#v", out)
}
tester.RunInteractiveShell(t, `hishtory config-set ai-completion-endpoint https://example.com/foo/bar`)
out = tester.RunInteractiveShell(t, `hishtory config-get ai-completion-endpoint`)
if out != "https://example.com/foo/bar\n" {
t.Fatalf("unexpected config-get output: %#v", out)
}
} }
func clearControlRSearchFromConfig(t testing.TB) { func clearControlRSearchFromConfig(t testing.TB) {
@ -2166,7 +2177,6 @@ func testTui_ai(t *testing.T) {
}) })
out = stripTuiCommandPrefix(t, out) out = stripTuiCommandPrefix(t, out)
testutils.CompareGoldens(t, out, "TestTui-AiQuery-Disabled") testutils.CompareGoldens(t, out, "TestTui-AiQuery-Disabled")
} }
func testControlR(t *testing.T, tester shellTester, shellName string, onlineStatus OnlineStatus) { func testControlR(t *testing.T, tester shellTester, shellName string, onlineStatus OnlineStatus) {

View File

@ -146,6 +146,16 @@ var getColorScheme = &cobra.Command{
}, },
} }
var getAiCompletionEndpoint = &cobra.Command{
Use: "ai-completion-endpoint",
Short: "The AI endpoint to use for AI completions",
Run: func(cmd *cobra.Command, args []string) {
ctx := hctx.MakeContext()
config := hctx.GetConf(ctx)
fmt.Println(config.AiCompletionEndpoint)
},
}
func init() { func init() {
rootCmd.AddCommand(configGetCmd) rootCmd.AddCommand(configGetCmd)
configGetCmd.AddCommand(getEnableControlRCmd) configGetCmd.AddCommand(getEnableControlRCmd)
@ -159,4 +169,5 @@ func init() {
configGetCmd.AddCommand(getPresavingCmd) configGetCmd.AddCommand(getPresavingCmd)
configGetCmd.AddCommand(getColorScheme) configGetCmd.AddCommand(getColorScheme)
configGetCmd.AddCommand(getDefaultFilterCmd) configGetCmd.AddCommand(getDefaultFilterCmd)
configGetCmd.AddCommand(getAiCompletionEndpoint)
} }

View File

@ -217,6 +217,18 @@ func validateColor(color string) error {
return nil return nil
} }
var setAiCompletionEndpoint = &cobra.Command{
Use: "ai-completion-endpoint",
Short: "The AI endpoint to use for AI completions",
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
ctx := hctx.MakeContext()
config := hctx.GetConf(ctx)
config.AiCompletionEndpoint = args[0]
lib.CheckFatalError(hctx.SetConfig(config))
},
}
func init() { func init() {
rootCmd.AddCommand(configSetCmd) rootCmd.AddCommand(configSetCmd)
configSetCmd.AddCommand(setEnableControlRCmd) configSetCmd.AddCommand(setEnableControlRCmd)
@ -229,6 +241,7 @@ func init() {
configSetCmd.AddCommand(setPresavingCmd) configSetCmd.AddCommand(setPresavingCmd)
configSetCmd.AddCommand(setColorSchemeCmd) configSetCmd.AddCommand(setColorSchemeCmd)
configSetCmd.AddCommand(setDefaultFilterCommand) configSetCmd.AddCommand(setDefaultFilterCommand)
configSetCmd.AddCommand(setAiCompletionEndpoint)
setColorSchemeCmd.AddCommand(setColorSchemeSelectedText) setColorSchemeCmd.AddCommand(setColorSchemeSelectedText)
setColorSchemeCmd.AddCommand(setColorSchemeSelectedBackground) setColorSchemeCmd.AddCommand(setColorSchemeSelectedBackground)
setColorSchemeCmd.AddCommand(setColorSchemeBorderColor) setColorSchemeCmd.AddCommand(setColorSchemeBorderColor)

View File

@ -205,6 +205,8 @@ type ClientConfig struct {
ColorScheme ColorScheme `json:"color_scheme"` ColorScheme ColorScheme `json:"color_scheme"`
// A default filter that will be applied to all search queries // A default filter that will be applied to all search queries
DefaultFilter string `json:"default_filter"` DefaultFilter string `json:"default_filter"`
// The endpoint to use for AI suggestions
AiCompletionEndpoint string `json:"ai_completion_endpoint"`
} }
type ColorScheme struct { type ColorScheme struct {
@ -272,6 +274,9 @@ func GetConfig() (ClientConfig, error) {
if config.ColorScheme.BorderColor == "" { if config.ColorScheme.BorderColor == "" {
config.ColorScheme.BorderColor = GetDefaultColorScheme().BorderColor config.ColorScheme.BorderColor = GetDefaultColorScheme().BorderColor
} }
if config.AiCompletionEndpoint == "" {
config.AiCompletionEndpoint = "https://api.openai.com/v1/chat/completions"
}
return config, nil return config, nil
} }

View File

@ -1,17 +0,0 @@
package main
import (
"fmt"
"log"
"strings"
"github.com/ddworken/hishtory/shared/ai"
)
func main() {
resp, _, err := ai.GetAiSuggestionsViaOpenAiApi("Find all CSV files in the current directory or subdirectories and select the first column, then prepend `foo` to each line", "bash", "MacOS", 3)
if err != nil {
log.Fatal(err)
}
fmt.Println(strings.Join(resp, "\n"))
}

View File

@ -12,6 +12,8 @@ import (
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
) )
const DefaultOpenAiEndpoint = "https://api.openai.com/v1/chat/completions"
type openAiRequest struct { type openAiRequest struct {
Model string `json:"model"` Model string `json:"model"`
Messages []openAiMessage `json:"messages"` Messages []openAiMessage `json:"messages"`
@ -51,7 +53,7 @@ type TestOnlyOverrideAiSuggestionRequest struct {
var TestOnlyOverrideAiSuggestions map[string][]string = make(map[string][]string) var TestOnlyOverrideAiSuggestions map[string][]string = make(map[string][]string)
func GetAiSuggestionsViaOpenAiApi(query, shellName, osName string, numberCompletions int) ([]string, OpenAiUsage, error) { func GetAiSuggestionsViaOpenAiApi(apiEndpoint, query, shellName, osName string, numberCompletions int) ([]string, OpenAiUsage, error) {
if results := TestOnlyOverrideAiSuggestions[query]; len(results) > 0 { if results := TestOnlyOverrideAiSuggestions[query]; len(results) > 0 {
return results, OpenAiUsage{}, nil return results, OpenAiUsage{}, nil
} }
@ -63,7 +65,7 @@ func GetAiSuggestionsViaOpenAiApi(query, shellName, osName string, numberComplet
shellName = "bash" shellName = "bash"
} }
apiKey := os.Getenv("OPENAI_API_KEY") apiKey := os.Getenv("OPENAI_API_KEY")
if apiKey == "" { if apiKey == "" && apiEndpoint == DefaultOpenAiEndpoint {
return nil, OpenAiUsage{}, fmt.Errorf("OPENAI_API_KEY environment variable is not set") return nil, OpenAiUsage{}, fmt.Errorf("OPENAI_API_KEY environment variable is not set")
} }
client := &http.Client{} client := &http.Client{}
@ -82,12 +84,14 @@ func GetAiSuggestionsViaOpenAiApi(query, shellName, osName string, numberComplet
if err != nil { if err != nil {
return nil, OpenAiUsage{}, fmt.Errorf("failed to serialize JSON for OpenAI API: %w", err) return nil, OpenAiUsage{}, fmt.Errorf("failed to serialize JSON for OpenAI API: %w", err)
} }
req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", bytes.NewBuffer(apiReqStr)) req, err := http.NewRequest("POST", apiEndpoint, bytes.NewBuffer(apiReqStr))
if err != nil { if err != nil {
return nil, OpenAiUsage{}, fmt.Errorf("failed to create OpenAI API request: %w", err) return nil, OpenAiUsage{}, fmt.Errorf("failed to create OpenAI API request: %w", err)
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
if apiKey != "" {
req.Header.Set("Authorization", "Bearer "+apiKey) req.Header.Set("Authorization", "Bearer "+apiKey)
}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return nil, OpenAiUsage{}, fmt.Errorf("failed to query OpenAI API: %w", err) return nil, OpenAiUsage{}, fmt.Errorf("failed to query OpenAI API: %w", err)

View File

@ -13,7 +13,7 @@ func TestLiveOpenAiApi(t *testing.T) {
if os.Getenv("OPENAI_API_KEY") == "" { if os.Getenv("OPENAI_API_KEY") == "" {
t.Skip("Skipping test since OPENAI_API_KEY is not set") t.Skip("Skipping test since OPENAI_API_KEY is not set")
} }
results, _, err := GetAiSuggestionsViaOpenAiApi("list files in the current directory", "bash", "Linux", 3) results, _, err := GetAiSuggestionsViaOpenAiApi("https://api.openai.com/v1/chat/completions", "list files in the current directory", "bash", "Linux", 3)
require.NoError(t, err) require.NoError(t, err)
resultsContainsLs := false resultsContainsLs := false
for _, result := range results { for _, result := range results {