Add ability to configure custom OpenAI API endpoint for #186 (#194)

* Add ability to configure custom OpenAI API endpoint for #186

* Ensure the AiCompletionEndpoint field is always initialized
This commit is contained in:
David Dworken 2024-03-26 22:13:57 -07:00 committed by GitHub
parent 46e92803be
commit 21b401bc14
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 52 additions and 26 deletions

View File

@ -331,7 +331,7 @@ func (s *Server) aiSuggestionHandler(w http.ResponseWriter, r *http.Request) {
if numDevices == 0 {
panic(fmt.Errorf("rejecting OpenAI request for user_id=%#v since it does not exist", req.UserId))
}
suggestions, usage, err := ai.GetAiSuggestionsViaOpenAiApi(req.Query, req.ShellName, req.OsName, req.NumberCompletions)
suggestions, usage, err := ai.GetAiSuggestionsViaOpenAiApi(ai.DefaultOpenAiEndpoint, req.Query, req.ShellName, req.OsName, req.NumberCompletions)
if err != nil {
panic(fmt.Errorf("failed to query OpenAI API: %w", err))
}

View File

@ -27,10 +27,10 @@ func DebouncedGetAiSuggestions(ctx context.Context, shellName, query string, num
}
func GetAiSuggestions(ctx context.Context, shellName, query string, numberCompletions int) ([]string, error) {
if os.Getenv("OPENAI_API_KEY") == "" {
if os.Getenv("OPENAI_API_KEY") == "" && hctx.GetConf(ctx).AiCompletionEndpoint == ai.DefaultOpenAiEndpoint {
return GetAiSuggestionsViaHishtoryApi(ctx, shellName, query, numberCompletions)
} else {
suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(query, shellName, getOsName(), numberCompletions)
suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(hctx.GetConf(ctx).AiCompletionEndpoint, query, shellName, getOsName(), numberCompletions)
return suggestions, err
}
}

View File

@ -1572,6 +1572,17 @@ func testConfigGetSet(t *testing.T, tester shellTester) {
if out != "Command \"Exit Code\" Timestamp foobar \n" {
t.Fatalf("unexpected config-get output: %#v", out)
}
// For OpenAI endpoints
out = tester.RunInteractiveShell(t, `hishtory config-get ai-completion-endpoint`)
if out != "https://api.openai.com/v1/chat/completions\n" {
t.Fatalf("unexpected config-get output: %#v", out)
}
tester.RunInteractiveShell(t, `hishtory config-set ai-completion-endpoint https://example.com/foo/bar`)
out = tester.RunInteractiveShell(t, `hishtory config-get ai-completion-endpoint`)
if out != "https://example.com/foo/bar\n" {
t.Fatalf("unexpected config-get output: %#v", out)
}
}
func clearControlRSearchFromConfig(t testing.TB) {
@ -2166,7 +2177,6 @@ func testTui_ai(t *testing.T) {
})
out = stripTuiCommandPrefix(t, out)
testutils.CompareGoldens(t, out, "TestTui-AiQuery-Disabled")
}
func testControlR(t *testing.T, tester shellTester, shellName string, onlineStatus OnlineStatus) {

View File

@ -146,6 +146,16 @@ var getColorScheme = &cobra.Command{
},
}
var getAiCompletionEndpoint = &cobra.Command{
Use: "ai-completion-endpoint",
Short: "The AI endpoint to use for AI completions",
Run: func(cmd *cobra.Command, args []string) {
ctx := hctx.MakeContext()
config := hctx.GetConf(ctx)
fmt.Println(config.AiCompletionEndpoint)
},
}
func init() {
rootCmd.AddCommand(configGetCmd)
configGetCmd.AddCommand(getEnableControlRCmd)
@ -159,4 +169,5 @@ func init() {
configGetCmd.AddCommand(getPresavingCmd)
configGetCmd.AddCommand(getColorScheme)
configGetCmd.AddCommand(getDefaultFilterCmd)
configGetCmd.AddCommand(getAiCompletionEndpoint)
}

View File

@ -217,6 +217,18 @@ func validateColor(color string) error {
return nil
}
var setAiCompletionEndpoint = &cobra.Command{
Use: "ai-completion-endpoint",
Short: "The AI endpoint to use for AI completions",
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
ctx := hctx.MakeContext()
config := hctx.GetConf(ctx)
config.AiCompletionEndpoint = args[0]
lib.CheckFatalError(hctx.SetConfig(config))
},
}
func init() {
rootCmd.AddCommand(configSetCmd)
configSetCmd.AddCommand(setEnableControlRCmd)
@ -229,6 +241,7 @@ func init() {
configSetCmd.AddCommand(setPresavingCmd)
configSetCmd.AddCommand(setColorSchemeCmd)
configSetCmd.AddCommand(setDefaultFilterCommand)
configSetCmd.AddCommand(setAiCompletionEndpoint)
setColorSchemeCmd.AddCommand(setColorSchemeSelectedText)
setColorSchemeCmd.AddCommand(setColorSchemeSelectedBackground)
setColorSchemeCmd.AddCommand(setColorSchemeBorderColor)

View File

@ -205,6 +205,8 @@ type ClientConfig struct {
ColorScheme ColorScheme `json:"color_scheme"`
// A default filter that will be applied to all search queries
DefaultFilter string `json:"default_filter"`
// The endpoint to use for AI suggestions
AiCompletionEndpoint string `json:"ai_completion_endpoint"`
}
type ColorScheme struct {
@ -272,6 +274,9 @@ func GetConfig() (ClientConfig, error) {
if config.ColorScheme.BorderColor == "" {
config.ColorScheme.BorderColor = GetDefaultColorScheme().BorderColor
}
if config.AiCompletionEndpoint == "" {
config.AiCompletionEndpoint = "https://api.openai.com/v1/chat/completions"
}
return config, nil
}

View File

@ -1,17 +0,0 @@
package main
import (
"fmt"
"log"
"strings"
"github.com/ddworken/hishtory/shared/ai"
)
func main() {
resp, _, err := ai.GetAiSuggestionsViaOpenAiApi("Find all CSV files in the current directory or subdirectories and select the first column, then prepend `foo` to each line", "bash", "MacOS", 3)
if err != nil {
log.Fatal(err)
}
fmt.Println(strings.Join(resp, "\n"))
}

View File

@ -12,6 +12,8 @@ import (
"golang.org/x/exp/slices"
)
const DefaultOpenAiEndpoint = "https://api.openai.com/v1/chat/completions"
type openAiRequest struct {
Model string `json:"model"`
Messages []openAiMessage `json:"messages"`
@ -51,7 +53,7 @@ type TestOnlyOverrideAiSuggestionRequest struct {
var TestOnlyOverrideAiSuggestions map[string][]string = make(map[string][]string)
func GetAiSuggestionsViaOpenAiApi(query, shellName, osName string, numberCompletions int) ([]string, OpenAiUsage, error) {
func GetAiSuggestionsViaOpenAiApi(apiEndpoint, query, shellName, osName string, numberCompletions int) ([]string, OpenAiUsage, error) {
if results := TestOnlyOverrideAiSuggestions[query]; len(results) > 0 {
return results, OpenAiUsage{}, nil
}
@ -63,7 +65,7 @@ func GetAiSuggestionsViaOpenAiApi(query, shellName, osName string, numberComplet
shellName = "bash"
}
apiKey := os.Getenv("OPENAI_API_KEY")
if apiKey == "" {
if apiKey == "" && apiEndpoint == DefaultOpenAiEndpoint {
return nil, OpenAiUsage{}, fmt.Errorf("OPENAI_API_KEY environment variable is not set")
}
client := &http.Client{}
@ -82,12 +84,14 @@ func GetAiSuggestionsViaOpenAiApi(query, shellName, osName string, numberComplet
if err != nil {
return nil, OpenAiUsage{}, fmt.Errorf("failed to serialize JSON for OpenAI API: %w", err)
}
req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", bytes.NewBuffer(apiReqStr))
req, err := http.NewRequest("POST", apiEndpoint, bytes.NewBuffer(apiReqStr))
if err != nil {
return nil, OpenAiUsage{}, fmt.Errorf("failed to create OpenAI API request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
if apiKey != "" {
req.Header.Set("Authorization", "Bearer "+apiKey)
}
resp, err := client.Do(req)
if err != nil {
return nil, OpenAiUsage{}, fmt.Errorf("failed to query OpenAI API: %w", err)

View File

@ -13,7 +13,7 @@ func TestLiveOpenAiApi(t *testing.T) {
if os.Getenv("OPENAI_API_KEY") == "" {
t.Skip("Skipping test since OPENAI_API_KEY is not set")
}
results, _, err := GetAiSuggestionsViaOpenAiApi("list files in the current directory", "bash", "Linux", 3)
results, _, err := GetAiSuggestionsViaOpenAiApi("https://api.openai.com/v1/chat/completions", "list files in the current directory", "bash", "Linux", 3)
require.NoError(t, err)
resultsContainsLs := false
for _, result := range results {