Improve AI suggestions by specifying shell name and OS in OpenAI query

This commit is contained in:
David Dworken 2023-12-19 16:49:31 -08:00
parent d3baf03dde
commit 8fd809fdc8
5 changed files with 35 additions and 6 deletions

View File

@ -319,7 +319,7 @@ func (s *Server) aiSuggestionHandler(w http.ResponseWriter, r *http.Request) {
if numDevices == 0 { if numDevices == 0 {
panic(fmt.Errorf("rejecting OpenAI request for user_id=%#v since it does not exist", req.UserId)) panic(fmt.Errorf("rejecting OpenAI request for user_id=%#v since it does not exist", req.UserId))
} }
suggestions, usage, err := ai.GetAiSuggestionsViaOpenAiApi(req.Query, req.NumberCompletions) suggestions, usage, err := ai.GetAiSuggestionsViaOpenAiApi(req.Query, req.ShellName, req.OsName, req.NumberCompletions)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to query OpenAI API: %w", err)) panic(fmt.Errorf("failed to query OpenAI API: %w", err))
} }

View File

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
"runtime"
"time" "time"
"github.com/ddworken/hishtory/client/data" "github.com/ddworken/hishtory/client/data"
@ -28,11 +29,26 @@ func GetAiSuggestions(ctx context.Context, query string, numberCompletions int)
if os.Getenv("OPENAI_API_KEY") == "" { if os.Getenv("OPENAI_API_KEY") == "" {
return GetAiSuggestionsViaHishtoryApi(ctx, query, numberCompletions) return GetAiSuggestionsViaHishtoryApi(ctx, query, numberCompletions)
} else { } else {
suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(query, numberCompletions) suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(query, getShellName(), getOsName(), numberCompletions)
return suggestions, err return suggestions, err
} }
} }
func getOsName() string {
switch runtime.GOOS {
case "linux":
return "Linux"
case "darwin":
return "MacOS"
default:
return runtime.GOOS
}
}
func getShellName() string {
return "bash"
}
func GetAiSuggestionsViaHishtoryApi(ctx context.Context, query string, numberCompletions int) ([]string, error) { func GetAiSuggestionsViaHishtoryApi(ctx context.Context, query string, numberCompletions int) ([]string, error) {
hctx.GetLogger().Infof("Running OpenAI query for %#v", query) hctx.GetLogger().Infof("Running OpenAI query for %#v", query)
req := ai.AiSuggestionRequest{ req := ai.AiSuggestionRequest{
@ -40,6 +56,8 @@ func GetAiSuggestionsViaHishtoryApi(ctx context.Context, query string, numberCom
UserId: data.UserId(hctx.GetConf(ctx).UserSecret), UserId: data.UserId(hctx.GetConf(ctx).UserSecret),
Query: query, Query: query,
NumberCompletions: numberCompletions, NumberCompletions: numberCompletions,
OsName: getOsName(),
ShellName: getShellName(),
} }
reqData, err := json.Marshal(req) reqData, err := json.Marshal(req)
if err != nil { if err != nil {

View File

@ -9,7 +9,7 @@ import (
) )
func main() { func main() {
resp, _, err := ai.GetAiSuggestionsViaOpenAiApi("Find all CSV files in the current directory or subdirectories and select the first column, then prepend `foo` to each line", 3) resp, _, err := ai.GetAiSuggestionsViaOpenAiApi("Find all CSV files in the current directory or subdirectories and select the first column, then prepend `foo` to each line", "bash", "MacOS", 3)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }

View File

@ -51,11 +51,17 @@ type TestOnlyOverrideAiSuggestionRequest struct {
var TestOnlyOverrideAiSuggestions map[string][]string = make(map[string][]string) var TestOnlyOverrideAiSuggestions map[string][]string = make(map[string][]string)
func GetAiSuggestionsViaOpenAiApi(query string, numberCompletions int) ([]string, OpenAiUsage, error) { func GetAiSuggestionsViaOpenAiApi(query, shellName, osName string, numberCompletions int) ([]string, OpenAiUsage, error) {
if results := TestOnlyOverrideAiSuggestions[query]; len(results) > 0 { if results := TestOnlyOverrideAiSuggestions[query]; len(results) > 0 {
return results, OpenAiUsage{}, nil return results, OpenAiUsage{}, nil
} }
hctx.GetLogger().Infof("Running OpenAI query for %#v", query) hctx.GetLogger().Infof("Running OpenAI query for %#v", query)
if osName == "" {
osName = "Linux"
}
if shellName == "" {
shellName = "bash"
}
apiKey := os.Getenv("OPENAI_API_KEY") apiKey := os.Getenv("OPENAI_API_KEY")
if apiKey == "" { if apiKey == "" {
return nil, OpenAiUsage{}, fmt.Errorf("OPENAI_API_KEY environment variable is not set") return nil, OpenAiUsage{}, fmt.Errorf("OPENAI_API_KEY environment variable is not set")
@ -65,7 +71,10 @@ func GetAiSuggestionsViaOpenAiApi(query string, numberCompletions int) ([]string
Model: "gpt-3.5-turbo", Model: "gpt-3.5-turbo",
NumberCompletions: numberCompletions, NumberCompletions: numberCompletions,
Messages: []openAiMessage{ Messages: []openAiMessage{
{Role: "system", Content: "You are an expert programmer that loves to help people with writing shell commands. You always reply with just a shell command and no additional context, information, or formatting. Your replies will be directly executed in bash, so ensure that they are correct and do not contain anything other than a bash command."}, {Role: "system", Content: "You are an expert programmer that loves to help people with writing shell commands. " +
"You always reply with just a shell command and no additional context, information, or formatting. " +
"Your replies will be directly executed in " + shellName + " on " + osName +
", so ensure that they are correct and do not contain anything other than a shell command."},
{Role: "user", Content: query}, {Role: "user", Content: query},
}, },
} }
@ -111,6 +120,8 @@ type AiSuggestionRequest struct {
UserId string `json:"user_id"` UserId string `json:"user_id"`
Query string `json:"query"` Query string `json:"query"`
NumberCompletions int `json:"number_completions"` NumberCompletions int `json:"number_completions"`
ShellName string `json:"shell_name"`
OsName string `json:"os_name"`
} }
type AiSuggestionResponse struct { type AiSuggestionResponse struct {

View File

@ -13,7 +13,7 @@ func TestLiveOpenAiApi(t *testing.T) {
if os.Getenv("OPENAI_API_KEY") == "" { if os.Getenv("OPENAI_API_KEY") == "" {
t.Skip("Skipping test since OPENAI_API_KEY is not set") t.Skip("Skipping test since OPENAI_API_KEY is not set")
} }
results, _, err := GetAiSuggestionsViaOpenAiApi("list files in the current directory", 3) results, _, err := GetAiSuggestionsViaOpenAiApi("list files in the current directory", "MacOS", "Linux", 3)
require.NoError(t, err) require.NoError(t, err)
resultsContainsLs := false resultsContainsLs := false
for _, result := range results { for _, result := range results {