mirror of
https://github.com/ddworken/hishtory.git
synced 2024-11-22 00:03:58 +01:00
* Add ability to configure custom OpenAI API endpoint for #186 * Ensure the AiCompletionEndpoint field is always initialized
This commit is contained in:
parent
a0e7f30c10
commit
c1729f1ee2
@ -331,7 +331,7 @@ func (s *Server) aiSuggestionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if numDevices == 0 {
|
||||
panic(fmt.Errorf("rejecting OpenAI request for user_id=%#v since it does not exist", req.UserId))
|
||||
}
|
||||
suggestions, usage, err := ai.GetAiSuggestionsViaOpenAiApi(req.Query, req.ShellName, req.OsName, req.NumberCompletions)
|
||||
suggestions, usage, err := ai.GetAiSuggestionsViaOpenAiApi(ai.DefaultOpenAiEndpoint, req.Query, req.ShellName, req.OsName, req.NumberCompletions)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to query OpenAI API: %w", err))
|
||||
}
|
||||
|
@ -27,10 +27,10 @@ func DebouncedGetAiSuggestions(ctx context.Context, shellName, query string, num
|
||||
}
|
||||
|
||||
func GetAiSuggestions(ctx context.Context, shellName, query string, numberCompletions int) ([]string, error) {
|
||||
if os.Getenv("OPENAI_API_KEY") == "" {
|
||||
if os.Getenv("OPENAI_API_KEY") == "" && hctx.GetConf(ctx).AiCompletionEndpoint == ai.DefaultOpenAiEndpoint {
|
||||
return GetAiSuggestionsViaHishtoryApi(ctx, shellName, query, numberCompletions)
|
||||
} else {
|
||||
suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(query, shellName, getOsName(), numberCompletions)
|
||||
suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(hctx.GetConf(ctx).AiCompletionEndpoint, query, shellName, getOsName(), numberCompletions)
|
||||
return suggestions, err
|
||||
}
|
||||
}
|
||||
|
@ -1572,6 +1572,17 @@ func testConfigGetSet(t *testing.T, tester shellTester) {
|
||||
if out != "Command \"Exit Code\" Timestamp foobar \n" {
|
||||
t.Fatalf("unexpected config-get output: %#v", out)
|
||||
}
|
||||
|
||||
// For OpenAI endpoints
|
||||
out = tester.RunInteractiveShell(t, `hishtory config-get ai-completion-endpoint`)
|
||||
if out != "https://api.openai.com/v1/chat/completions\n" {
|
||||
t.Fatalf("unexpected config-get output: %#v", out)
|
||||
}
|
||||
tester.RunInteractiveShell(t, `hishtory config-set ai-completion-endpoint https://example.com/foo/bar`)
|
||||
out = tester.RunInteractiveShell(t, `hishtory config-get ai-completion-endpoint`)
|
||||
if out != "https://example.com/foo/bar\n" {
|
||||
t.Fatalf("unexpected config-get output: %#v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func clearControlRSearchFromConfig(t testing.TB) {
|
||||
@ -2166,7 +2177,6 @@ func testTui_ai(t *testing.T) {
|
||||
})
|
||||
out = stripTuiCommandPrefix(t, out)
|
||||
testutils.CompareGoldens(t, out, "TestTui-AiQuery-Disabled")
|
||||
|
||||
}
|
||||
|
||||
func testControlR(t *testing.T, tester shellTester, shellName string, onlineStatus OnlineStatus) {
|
||||
|
@ -146,6 +146,16 @@ var getColorScheme = &cobra.Command{
|
||||
},
|
||||
}
|
||||
|
||||
var getAiCompletionEndpoint = &cobra.Command{
|
||||
Use: "ai-completion-endpoint",
|
||||
Short: "The AI endpoint to use for AI completions",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
ctx := hctx.MakeContext()
|
||||
config := hctx.GetConf(ctx)
|
||||
fmt.Println(config.AiCompletionEndpoint)
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(configGetCmd)
|
||||
configGetCmd.AddCommand(getEnableControlRCmd)
|
||||
@ -159,4 +169,5 @@ func init() {
|
||||
configGetCmd.AddCommand(getPresavingCmd)
|
||||
configGetCmd.AddCommand(getColorScheme)
|
||||
configGetCmd.AddCommand(getDefaultFilterCmd)
|
||||
configGetCmd.AddCommand(getAiCompletionEndpoint)
|
||||
}
|
||||
|
@ -217,6 +217,18 @@ func validateColor(color string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var setAiCompletionEndpoint = &cobra.Command{
|
||||
Use: "ai-completion-endpoint",
|
||||
Short: "The AI endpoint to use for AI completions",
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
ctx := hctx.MakeContext()
|
||||
config := hctx.GetConf(ctx)
|
||||
config.AiCompletionEndpoint = args[0]
|
||||
lib.CheckFatalError(hctx.SetConfig(config))
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(configSetCmd)
|
||||
configSetCmd.AddCommand(setEnableControlRCmd)
|
||||
@ -229,6 +241,7 @@ func init() {
|
||||
configSetCmd.AddCommand(setPresavingCmd)
|
||||
configSetCmd.AddCommand(setColorSchemeCmd)
|
||||
configSetCmd.AddCommand(setDefaultFilterCommand)
|
||||
configSetCmd.AddCommand(setAiCompletionEndpoint)
|
||||
setColorSchemeCmd.AddCommand(setColorSchemeSelectedText)
|
||||
setColorSchemeCmd.AddCommand(setColorSchemeSelectedBackground)
|
||||
setColorSchemeCmd.AddCommand(setColorSchemeBorderColor)
|
||||
|
@ -205,6 +205,8 @@ type ClientConfig struct {
|
||||
ColorScheme ColorScheme `json:"color_scheme"`
|
||||
// A default filter that will be applied to all search queries
|
||||
DefaultFilter string `json:"default_filter"`
|
||||
// The endpoint to use for AI suggestions
|
||||
AiCompletionEndpoint string `json:"ai_completion_endpoint"`
|
||||
}
|
||||
|
||||
type ColorScheme struct {
|
||||
@ -272,6 +274,9 @@ func GetConfig() (ClientConfig, error) {
|
||||
if config.ColorScheme.BorderColor == "" {
|
||||
config.ColorScheme.BorderColor = GetDefaultColorScheme().BorderColor
|
||||
}
|
||||
if config.AiCompletionEndpoint == "" {
|
||||
config.AiCompletionEndpoint = "https://api.openai.com/v1/chat/completions"
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
|
@ -1,17 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/ddworken/hishtory/shared/ai"
|
||||
)
|
||||
|
||||
func main() {
|
||||
resp, _, err := ai.GetAiSuggestionsViaOpenAiApi("Find all CSV files in the current directory or subdirectories and select the first column, then prepend `foo` to each line", "bash", "MacOS", 3)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Println(strings.Join(resp, "\n"))
|
||||
}
|
@ -12,6 +12,8 @@ import (
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
const DefaultOpenAiEndpoint = "https://api.openai.com/v1/chat/completions"
|
||||
|
||||
type openAiRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []openAiMessage `json:"messages"`
|
||||
@ -51,7 +53,7 @@ type TestOnlyOverrideAiSuggestionRequest struct {
|
||||
|
||||
var TestOnlyOverrideAiSuggestions map[string][]string = make(map[string][]string)
|
||||
|
||||
func GetAiSuggestionsViaOpenAiApi(query, shellName, osName string, numberCompletions int) ([]string, OpenAiUsage, error) {
|
||||
func GetAiSuggestionsViaOpenAiApi(apiEndpoint, query, shellName, osName string, numberCompletions int) ([]string, OpenAiUsage, error) {
|
||||
if results := TestOnlyOverrideAiSuggestions[query]; len(results) > 0 {
|
||||
return results, OpenAiUsage{}, nil
|
||||
}
|
||||
@ -63,7 +65,7 @@ func GetAiSuggestionsViaOpenAiApi(query, shellName, osName string, numberComplet
|
||||
shellName = "bash"
|
||||
}
|
||||
apiKey := os.Getenv("OPENAI_API_KEY")
|
||||
if apiKey == "" {
|
||||
if apiKey == "" && apiEndpoint == DefaultOpenAiEndpoint {
|
||||
return nil, OpenAiUsage{}, fmt.Errorf("OPENAI_API_KEY environment variable is not set")
|
||||
}
|
||||
client := &http.Client{}
|
||||
@ -82,12 +84,14 @@ func GetAiSuggestionsViaOpenAiApi(query, shellName, osName string, numberComplet
|
||||
if err != nil {
|
||||
return nil, OpenAiUsage{}, fmt.Errorf("failed to serialize JSON for OpenAI API: %w", err)
|
||||
}
|
||||
req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", bytes.NewBuffer(apiReqStr))
|
||||
req, err := http.NewRequest("POST", apiEndpoint, bytes.NewBuffer(apiReqStr))
|
||||
if err != nil {
|
||||
return nil, OpenAiUsage{}, fmt.Errorf("failed to create OpenAI API request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
if apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, OpenAiUsage{}, fmt.Errorf("failed to query OpenAI API: %w", err)
|
||||
|
@ -13,7 +13,7 @@ func TestLiveOpenAiApi(t *testing.T) {
|
||||
if os.Getenv("OPENAI_API_KEY") == "" {
|
||||
t.Skip("Skipping test since OPENAI_API_KEY is not set")
|
||||
}
|
||||
results, _, err := GetAiSuggestionsViaOpenAiApi("list files in the current directory", "bash", "Linux", 3)
|
||||
results, _, err := GetAiSuggestionsViaOpenAiApi("https://api.openai.com/v1/chat/completions", "list files in the current directory", "bash", "Linux", 3)
|
||||
require.NoError(t, err)
|
||||
resultsContainsLs := false
|
||||
for _, result := range results {
|
||||
|
Loading…
Reference in New Issue
Block a user