mirror of
https://github.com/ddworken/hishtory.git
synced 2024-11-24 01:03:14 +01:00
* Add ability to configure custom OpenAI API endpoint for #186 * Ensure the AiCompletionEndpoint field is always initialized
This commit is contained in:
parent
46e92803be
commit
21b401bc14
@ -331,7 +331,7 @@ func (s *Server) aiSuggestionHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
if numDevices == 0 {
|
if numDevices == 0 {
|
||||||
panic(fmt.Errorf("rejecting OpenAI request for user_id=%#v since it does not exist", req.UserId))
|
panic(fmt.Errorf("rejecting OpenAI request for user_id=%#v since it does not exist", req.UserId))
|
||||||
}
|
}
|
||||||
suggestions, usage, err := ai.GetAiSuggestionsViaOpenAiApi(req.Query, req.ShellName, req.OsName, req.NumberCompletions)
|
suggestions, usage, err := ai.GetAiSuggestionsViaOpenAiApi(ai.DefaultOpenAiEndpoint, req.Query, req.ShellName, req.OsName, req.NumberCompletions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Errorf("failed to query OpenAI API: %w", err))
|
panic(fmt.Errorf("failed to query OpenAI API: %w", err))
|
||||||
}
|
}
|
||||||
|
@ -27,10 +27,10 @@ func DebouncedGetAiSuggestions(ctx context.Context, shellName, query string, num
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetAiSuggestions(ctx context.Context, shellName, query string, numberCompletions int) ([]string, error) {
|
func GetAiSuggestions(ctx context.Context, shellName, query string, numberCompletions int) ([]string, error) {
|
||||||
if os.Getenv("OPENAI_API_KEY") == "" {
|
if os.Getenv("OPENAI_API_KEY") == "" && hctx.GetConf(ctx).AiCompletionEndpoint == ai.DefaultOpenAiEndpoint {
|
||||||
return GetAiSuggestionsViaHishtoryApi(ctx, shellName, query, numberCompletions)
|
return GetAiSuggestionsViaHishtoryApi(ctx, shellName, query, numberCompletions)
|
||||||
} else {
|
} else {
|
||||||
suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(query, shellName, getOsName(), numberCompletions)
|
suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(hctx.GetConf(ctx).AiCompletionEndpoint, query, shellName, getOsName(), numberCompletions)
|
||||||
return suggestions, err
|
return suggestions, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1572,6 +1572,17 @@ func testConfigGetSet(t *testing.T, tester shellTester) {
|
|||||||
if out != "Command \"Exit Code\" Timestamp foobar \n" {
|
if out != "Command \"Exit Code\" Timestamp foobar \n" {
|
||||||
t.Fatalf("unexpected config-get output: %#v", out)
|
t.Fatalf("unexpected config-get output: %#v", out)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For OpenAI endpoints
|
||||||
|
out = tester.RunInteractiveShell(t, `hishtory config-get ai-completion-endpoint`)
|
||||||
|
if out != "https://api.openai.com/v1/chat/completions\n" {
|
||||||
|
t.Fatalf("unexpected config-get output: %#v", out)
|
||||||
|
}
|
||||||
|
tester.RunInteractiveShell(t, `hishtory config-set ai-completion-endpoint https://example.com/foo/bar`)
|
||||||
|
out = tester.RunInteractiveShell(t, `hishtory config-get ai-completion-endpoint`)
|
||||||
|
if out != "https://example.com/foo/bar\n" {
|
||||||
|
t.Fatalf("unexpected config-get output: %#v", out)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func clearControlRSearchFromConfig(t testing.TB) {
|
func clearControlRSearchFromConfig(t testing.TB) {
|
||||||
@ -2166,7 +2177,6 @@ func testTui_ai(t *testing.T) {
|
|||||||
})
|
})
|
||||||
out = stripTuiCommandPrefix(t, out)
|
out = stripTuiCommandPrefix(t, out)
|
||||||
testutils.CompareGoldens(t, out, "TestTui-AiQuery-Disabled")
|
testutils.CompareGoldens(t, out, "TestTui-AiQuery-Disabled")
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func testControlR(t *testing.T, tester shellTester, shellName string, onlineStatus OnlineStatus) {
|
func testControlR(t *testing.T, tester shellTester, shellName string, onlineStatus OnlineStatus) {
|
||||||
|
@ -146,6 +146,16 @@ var getColorScheme = &cobra.Command{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var getAiCompletionEndpoint = &cobra.Command{
|
||||||
|
Use: "ai-completion-endpoint",
|
||||||
|
Short: "The AI endpoint to use for AI completions",
|
||||||
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
ctx := hctx.MakeContext()
|
||||||
|
config := hctx.GetConf(ctx)
|
||||||
|
fmt.Println(config.AiCompletionEndpoint)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
rootCmd.AddCommand(configGetCmd)
|
rootCmd.AddCommand(configGetCmd)
|
||||||
configGetCmd.AddCommand(getEnableControlRCmd)
|
configGetCmd.AddCommand(getEnableControlRCmd)
|
||||||
@ -159,4 +169,5 @@ func init() {
|
|||||||
configGetCmd.AddCommand(getPresavingCmd)
|
configGetCmd.AddCommand(getPresavingCmd)
|
||||||
configGetCmd.AddCommand(getColorScheme)
|
configGetCmd.AddCommand(getColorScheme)
|
||||||
configGetCmd.AddCommand(getDefaultFilterCmd)
|
configGetCmd.AddCommand(getDefaultFilterCmd)
|
||||||
|
configGetCmd.AddCommand(getAiCompletionEndpoint)
|
||||||
}
|
}
|
||||||
|
@ -217,6 +217,18 @@ func validateColor(color string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var setAiCompletionEndpoint = &cobra.Command{
|
||||||
|
Use: "ai-completion-endpoint",
|
||||||
|
Short: "The AI endpoint to use for AI completions",
|
||||||
|
Args: cobra.ExactArgs(1),
|
||||||
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
ctx := hctx.MakeContext()
|
||||||
|
config := hctx.GetConf(ctx)
|
||||||
|
config.AiCompletionEndpoint = args[0]
|
||||||
|
lib.CheckFatalError(hctx.SetConfig(config))
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
rootCmd.AddCommand(configSetCmd)
|
rootCmd.AddCommand(configSetCmd)
|
||||||
configSetCmd.AddCommand(setEnableControlRCmd)
|
configSetCmd.AddCommand(setEnableControlRCmd)
|
||||||
@ -229,6 +241,7 @@ func init() {
|
|||||||
configSetCmd.AddCommand(setPresavingCmd)
|
configSetCmd.AddCommand(setPresavingCmd)
|
||||||
configSetCmd.AddCommand(setColorSchemeCmd)
|
configSetCmd.AddCommand(setColorSchemeCmd)
|
||||||
configSetCmd.AddCommand(setDefaultFilterCommand)
|
configSetCmd.AddCommand(setDefaultFilterCommand)
|
||||||
|
configSetCmd.AddCommand(setAiCompletionEndpoint)
|
||||||
setColorSchemeCmd.AddCommand(setColorSchemeSelectedText)
|
setColorSchemeCmd.AddCommand(setColorSchemeSelectedText)
|
||||||
setColorSchemeCmd.AddCommand(setColorSchemeSelectedBackground)
|
setColorSchemeCmd.AddCommand(setColorSchemeSelectedBackground)
|
||||||
setColorSchemeCmd.AddCommand(setColorSchemeBorderColor)
|
setColorSchemeCmd.AddCommand(setColorSchemeBorderColor)
|
||||||
|
@ -205,6 +205,8 @@ type ClientConfig struct {
|
|||||||
ColorScheme ColorScheme `json:"color_scheme"`
|
ColorScheme ColorScheme `json:"color_scheme"`
|
||||||
// A default filter that will be applied to all search queries
|
// A default filter that will be applied to all search queries
|
||||||
DefaultFilter string `json:"default_filter"`
|
DefaultFilter string `json:"default_filter"`
|
||||||
|
// The endpoint to use for AI suggestions
|
||||||
|
AiCompletionEndpoint string `json:"ai_completion_endpoint"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ColorScheme struct {
|
type ColorScheme struct {
|
||||||
@ -272,6 +274,9 @@ func GetConfig() (ClientConfig, error) {
|
|||||||
if config.ColorScheme.BorderColor == "" {
|
if config.ColorScheme.BorderColor == "" {
|
||||||
config.ColorScheme.BorderColor = GetDefaultColorScheme().BorderColor
|
config.ColorScheme.BorderColor = GetDefaultColorScheme().BorderColor
|
||||||
}
|
}
|
||||||
|
if config.AiCompletionEndpoint == "" {
|
||||||
|
config.AiCompletionEndpoint = "https://api.openai.com/v1/chat/completions"
|
||||||
|
}
|
||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,17 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/ddworken/hishtory/shared/ai"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
resp, _, err := ai.GetAiSuggestionsViaOpenAiApi("Find all CSV files in the current directory or subdirectories and select the first column, then prepend `foo` to each line", "bash", "MacOS", 3)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
fmt.Println(strings.Join(resp, "\n"))
|
|
||||||
}
|
|
@ -12,6 +12,8 @@ import (
|
|||||||
"golang.org/x/exp/slices"
|
"golang.org/x/exp/slices"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const DefaultOpenAiEndpoint = "https://api.openai.com/v1/chat/completions"
|
||||||
|
|
||||||
type openAiRequest struct {
|
type openAiRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Messages []openAiMessage `json:"messages"`
|
Messages []openAiMessage `json:"messages"`
|
||||||
@ -51,7 +53,7 @@ type TestOnlyOverrideAiSuggestionRequest struct {
|
|||||||
|
|
||||||
var TestOnlyOverrideAiSuggestions map[string][]string = make(map[string][]string)
|
var TestOnlyOverrideAiSuggestions map[string][]string = make(map[string][]string)
|
||||||
|
|
||||||
func GetAiSuggestionsViaOpenAiApi(query, shellName, osName string, numberCompletions int) ([]string, OpenAiUsage, error) {
|
func GetAiSuggestionsViaOpenAiApi(apiEndpoint, query, shellName, osName string, numberCompletions int) ([]string, OpenAiUsage, error) {
|
||||||
if results := TestOnlyOverrideAiSuggestions[query]; len(results) > 0 {
|
if results := TestOnlyOverrideAiSuggestions[query]; len(results) > 0 {
|
||||||
return results, OpenAiUsage{}, nil
|
return results, OpenAiUsage{}, nil
|
||||||
}
|
}
|
||||||
@ -63,7 +65,7 @@ func GetAiSuggestionsViaOpenAiApi(query, shellName, osName string, numberComplet
|
|||||||
shellName = "bash"
|
shellName = "bash"
|
||||||
}
|
}
|
||||||
apiKey := os.Getenv("OPENAI_API_KEY")
|
apiKey := os.Getenv("OPENAI_API_KEY")
|
||||||
if apiKey == "" {
|
if apiKey == "" && apiEndpoint == DefaultOpenAiEndpoint {
|
||||||
return nil, OpenAiUsage{}, fmt.Errorf("OPENAI_API_KEY environment variable is not set")
|
return nil, OpenAiUsage{}, fmt.Errorf("OPENAI_API_KEY environment variable is not set")
|
||||||
}
|
}
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
@ -82,12 +84,14 @@ func GetAiSuggestionsViaOpenAiApi(query, shellName, osName string, numberComplet
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, OpenAiUsage{}, fmt.Errorf("failed to serialize JSON for OpenAI API: %w", err)
|
return nil, OpenAiUsage{}, fmt.Errorf("failed to serialize JSON for OpenAI API: %w", err)
|
||||||
}
|
}
|
||||||
req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", bytes.NewBuffer(apiReqStr))
|
req, err := http.NewRequest("POST", apiEndpoint, bytes.NewBuffer(apiReqStr))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, OpenAiUsage{}, fmt.Errorf("failed to create OpenAI API request: %w", err)
|
return nil, OpenAiUsage{}, fmt.Errorf("failed to create OpenAI API request: %w", err)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
if apiKey != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, OpenAiUsage{}, fmt.Errorf("failed to query OpenAI API: %w", err)
|
return nil, OpenAiUsage{}, fmt.Errorf("failed to query OpenAI API: %w", err)
|
||||||
|
@ -13,7 +13,7 @@ func TestLiveOpenAiApi(t *testing.T) {
|
|||||||
if os.Getenv("OPENAI_API_KEY") == "" {
|
if os.Getenv("OPENAI_API_KEY") == "" {
|
||||||
t.Skip("Skipping test since OPENAI_API_KEY is not set")
|
t.Skip("Skipping test since OPENAI_API_KEY is not set")
|
||||||
}
|
}
|
||||||
results, _, err := GetAiSuggestionsViaOpenAiApi("list files in the current directory", "bash", "Linux", 3)
|
results, _, err := GetAiSuggestionsViaOpenAiApi("https://api.openai.com/v1/chat/completions", "list files in the current directory", "bash", "Linux", 3)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
resultsContainsLs := false
|
resultsContainsLs := false
|
||||||
for _, result := range results {
|
for _, result := range results {
|
||||||
|
Loading…
Reference in New Issue
Block a user