Wire through the shell name into AI suggestions so that we can get more precise AI suggestions for the current shell

This commit is contained in:
David Dworken 2024-02-19 12:12:04 -08:00
parent 22d43098fe
commit 87c2cde688
No known key found for this signature in database
6 changed files with 36 additions and 34 deletions

View File

@ -17,20 +17,20 @@ import (
var mostRecentQuery string
func DebouncedGetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) {
func DebouncedGetAiSuggestions(ctx context.Context, shellName, query string, numberCompletions int) ([]string, error) {
mostRecentQuery = query
time.Sleep(time.Millisecond * 300)
if mostRecentQuery == query {
return GetAiSuggestions(ctx, query, numberCompletions)
return GetAiSuggestions(ctx, shellName, query, numberCompletions)
}
return nil, nil
}
func GetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) {
func GetAiSuggestions(ctx context.Context, shellName, query string, numberCompletions int) ([]string, error) {
if os.Getenv("OPENAI_API_KEY") == "" {
return GetAiSuggestionsViaHishtoryApi(ctx, query, numberCompletions)
return GetAiSuggestionsViaHishtoryApi(ctx, shellName, query, numberCompletions)
} else {
suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(query, getShellName(), getOsName(), numberCompletions)
suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(query, shellName, getOsName(), numberCompletions)
return suggestions, err
}
}
@ -55,12 +55,7 @@ func getOsName() string {
}
}
func getShellName() string {
// TODO: Wire the real shell name in here
return "bash"
}
func GetAiSuggestionsViaHishtoryApi(ctx context.Context, query string, numberCompletions int) ([]string, error) {
func GetAiSuggestionsViaHishtoryApi(ctx context.Context, shellName, query string, numberCompletions int) ([]string, error) {
hctx.GetLogger().Infof("Running OpenAI query for %#v", query)
req := ai.AiSuggestionRequest{
DeviceId: hctx.GetConf(ctx).DeviceId,
@ -68,7 +63,7 @@ func GetAiSuggestionsViaHishtoryApi(ctx context.Context, query string, numberCom
Query: query,
NumberCompletions: numberCompletions,
OsName: getOsName(),
ShellName: getShellName(),
ShellName: shellName,
}
reqData, err := json.Marshal(req)
if err != nil {

View File

@ -50,7 +50,11 @@ var tqueryCmd = &cobra.Command{
DisableFlagParsing: true,
Run: func(cmd *cobra.Command, args []string) {
ctx := hctx.MakeContext()
lib.CheckFatalError(tui.TuiQuery(ctx, strings.Join(args, " ")))
shellName := "bash"
if os.Getenv("HISHTORY_SHELL_NAME") != "" {
shellName = os.Getenv("HISHTORY_SHELL_NAME")
}
lib.CheckFatalError(tui.TuiQuery(ctx, shellName, strings.Join(args, " ")))
},
}

View File

@ -29,7 +29,7 @@ end
function __hishtory_on_control_r
set -l tmp (mktemp -t fish.XXXXXX)
set -x init_query (commandline -b)
HISHTORY_TERM_INTEGRATION=1 hishtory tquery $init_query > $tmp
HISHTORY_TERM_INTEGRATION=1 HISHTORY_SHELL_NAME=fish hishtory tquery $init_query > $tmp
set -l res $status
commandline -f repaint
if [ -s $tmp ]

View File

@ -46,7 +46,7 @@ PROMPT_COMMAND="__hishtory_postcommand; $PROMPT_COMMAND"
export HISTTIMEFORMAT=$HISTTIMEFORMAT
__history_control_r() {
READLINE_LINE=$(HISHTORY_TERM_INTEGRATION=1 hishtory tquery "$READLINE_LINE")
READLINE_LINE=$(HISHTORY_TERM_INTEGRATION=1 HISHTORY_SHELL_NAME=bash hishtory tquery "$READLINE_LINE")
READLINE_POINT=0x7FFFFFFF
}

View File

@ -33,7 +33,7 @@ function _hishtory_precmd() {
}
_hishtory_widget() {
BUFFER=$(HISHTORY_TERM_INTEGRATION=1 hishtory tquery $BUFFER)
BUFFER=$(HISHTORY_TERM_INTEGRATION=1 HISHTORY_SHELL_NAME=zsh hishtory tquery $BUFFER)
CURSOR=${#BUFFER}
zle reset-prompt
}

View File

@ -182,6 +182,9 @@ type model struct {
// A banner from the backend to be displayed. Generally an empty string.
banner string
// The currently executing shell. Defaults to bash if not specified. Used for more precise AI suggestions.
shellName string
}
type doneDownloadingMsg struct{}
@ -205,7 +208,7 @@ type asyncQueryFinishedMsg struct {
overriddenSearchQuery *string
}
func initialModel(ctx context.Context, initialQuery string) model {
func initialModel(ctx context.Context, shellName, initialQuery string) model {
s := spinner.New()
s.Spinner = spinner.Dot
s.Style = lipgloss.NewStyle().Foreground(lipgloss.Color("205"))
@ -231,7 +234,7 @@ func initialModel(ctx context.Context, initialQuery string) model {
queryInput.SetValue(initialQuery)
}
CURRENT_QUERY_FOR_HIGHLIGHTING = initialQuery
return model{ctx: ctx, spinner: s, isLoading: true, table: nil, tableEntries: []*data.HistoryEntry{}, runQuery: &initialQuery, queryInput: queryInput, help: help.New()}
return model{ctx: ctx, spinner: s, isLoading: true, table: nil, tableEntries: []*data.HistoryEntry{}, runQuery: &initialQuery, queryInput: queryInput, help: help.New(), shellName: shellName}
}
func (m model) Init() tea.Cmd {
@ -252,7 +255,7 @@ func updateTable(m model, rows []table.Row, entries []*data.HistoryEntry, search
initialCursor = m.table.Cursor()
}
if forceUpdateTable || m.table == nil {
t, err := makeTable(m.ctx, rows)
t, err := makeTable(m.ctx, m.shellName, rows)
if err != nil {
m.fatalErr = err
return m
@ -299,7 +302,7 @@ func runQueryAndUpdateTable(m model, forceUpdateTable, maintainCursor bool) tea.
// The default filter was cleared for this session, so don't apply it
defaultFilter = ""
}
rows, entries, searchErr := getRows(m.ctx, conf.DisplayedColumns, defaultFilter, query, PADDED_NUM_ENTRIES)
rows, entries, searchErr := getRows(m.ctx, conf.DisplayedColumns, m.shellName, defaultFilter, query, PADDED_NUM_ENTRIES)
return asyncQueryFinishedMsg{queryId, rows, entries, searchErr, forceUpdateTable, maintainCursor, nil}
}
}
@ -493,8 +496,8 @@ func renderNullableTable(m model, helpText string) string {
return baseStyle.Render(m.table.View())
}
func getRowsFromAiSuggestions(ctx context.Context, columnNames []string, query string) ([]table.Row, []*data.HistoryEntry, error) {
suggestions, err := ai.DebouncedGetAiSuggestions(ctx, strings.TrimPrefix(query, "?"), 5)
func getRowsFromAiSuggestions(ctx context.Context, columnNames []string, shellName, query string) ([]table.Row, []*data.HistoryEntry, error) {
suggestions, err := ai.DebouncedGetAiSuggestions(ctx, shellName, strings.TrimPrefix(query, "?"), 5)
if err != nil {
hctx.GetLogger().Infof("failed to get AI query suggestions: %v", err)
return nil, nil, fmt.Errorf("failed to get AI query suggestions: %w", err)
@ -525,11 +528,11 @@ func getRowsFromAiSuggestions(ctx context.Context, columnNames []string, query s
return rows, entries, nil
}
func getRows(ctx context.Context, columnNames []string, defaultFilter, query string, numEntries int) ([]table.Row, []*data.HistoryEntry, error) {
func getRows(ctx context.Context, columnNames []string, shellName, defaultFilter, query string, numEntries int) ([]table.Row, []*data.HistoryEntry, error) {
db := hctx.GetDb(ctx)
config := hctx.GetConf(ctx)
if config.AiCompletion && !config.IsOffline && strings.HasPrefix(query, "?") && len(query) > 1 {
return getRowsFromAiSuggestions(ctx, columnNames, query)
return getRowsFromAiSuggestions(ctx, columnNames, shellName, query)
}
searchResults, err := lib.Search(ctx, db, defaultFilter+" "+query, numEntries)
if err != nil {
@ -588,10 +591,10 @@ func getTerminalSize() (int, int, error) {
var bigQueryResults []table.Row
func makeTableColumns(ctx context.Context, columnNames []string, rows []table.Row) ([]table.Column, error) {
func makeTableColumns(ctx context.Context, shellName string, columnNames []string, rows []table.Row) ([]table.Column, error) {
// Handle an initial query with no results
if len(rows) == 0 || len(rows[0]) == 0 {
allRows, _, err := getRows(ctx, columnNames, hctx.GetConf(ctx).DefaultFilter, "", 25)
allRows, _, err := getRows(ctx, columnNames, shellName, hctx.GetConf(ctx).DefaultFilter, "", 25)
if err != nil {
return nil, err
}
@ -604,7 +607,7 @@ func makeTableColumns(ctx context.Context, columnNames []string, rows []table.Ro
}
allRows = append(allRows, row)
}
return makeTableColumns(ctx, columnNames, allRows)
return makeTableColumns(ctx, shellName, columnNames, allRows)
}
// Calculate the minimum amount of space that we need for each column for the current actual search
@ -617,7 +620,7 @@ func makeTableColumns(ctx context.Context, columnNames []string, rows []table.Ro
// Calculate the maximum column width that is useful for each column if we search for the empty string
if bigQueryResults == nil {
bigRows, _, err := getRows(ctx, columnNames, "", "", 1000)
bigRows, _, err := getRows(ctx, columnNames, shellName, "", "", 1000)
if err != nil {
return nil, err
}
@ -678,9 +681,9 @@ func min(a, b int) int {
return b
}
func makeTable(ctx context.Context, rows []table.Row) (table.Model, error) {
func makeTable(ctx context.Context, shellName string, rows []table.Row) (table.Model, error) {
config := hctx.GetConf(ctx)
columns, err := makeTableColumns(ctx, config.DisplayedColumns, rows)
columns, err := makeTableColumns(ctx, shellName, config.DisplayedColumns, rows)
if err != nil {
return table.Model{}, err
}
@ -887,22 +890,22 @@ func configureColorProfile(ctx context.Context) {
}
}
func TuiQuery(ctx context.Context, initialQuery string) error {
func TuiQuery(ctx context.Context, shellName, initialQuery string) error {
configureColorProfile(ctx)
p := tea.NewProgram(initialModel(ctx, initialQuery), tea.WithOutput(os.Stderr))
p := tea.NewProgram(initialModel(ctx, shellName, initialQuery), tea.WithOutput(os.Stderr))
// Async: Get the initial set of rows
go func() {
LAST_DISPATCHED_QUERY_ID++
queryId := LAST_DISPATCHED_QUERY_ID
LAST_DISPATCHED_QUERY_TIMESTAMP = time.Now()
conf := hctx.GetConf(ctx)
rows, entries, err := getRows(ctx, conf.DisplayedColumns, conf.DefaultFilter, initialQuery, PADDED_NUM_ENTRIES)
rows, entries, err := getRows(ctx, conf.DisplayedColumns, shellName, conf.DefaultFilter, initialQuery, PADDED_NUM_ENTRIES)
if err == nil || initialQuery == "" {
p.Send(asyncQueryFinishedMsg{queryId: queryId, rows: rows, entries: entries, searchErr: err, forceUpdateTable: true, maintainCursor: false, overriddenSearchQuery: nil})
} else {
// initialQuery is likely invalid in some way, let's just drop it
emptyQuery := ""
rows, entries, err := getRows(ctx, hctx.GetConf(ctx).DisplayedColumns, conf.DefaultFilter, emptyQuery, PADDED_NUM_ENTRIES)
rows, entries, err := getRows(ctx, hctx.GetConf(ctx).DisplayedColumns, shellName, conf.DefaultFilter, emptyQuery, PADDED_NUM_ENTRIES)
p.Send(asyncQueryFinishedMsg{queryId: queryId, rows: rows, entries: entries, searchErr: err, forceUpdateTable: true, maintainCursor: false, overriddenSearchQuery: &emptyQuery})
}
}()