diff --git a/client/ai/ai.go b/client/ai/ai.go index 028aa4f..5078cd9 100644 --- a/client/ai/ai.go +++ b/client/ai/ai.go @@ -17,20 +17,20 @@ import ( var mostRecentQuery string -func DebouncedGetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) { +func DebouncedGetAiSuggestions(ctx context.Context, shellName, query string, numberCompletions int) ([]string, error) { mostRecentQuery = query time.Sleep(time.Millisecond * 300) if mostRecentQuery == query { - return GetAiSuggestions(ctx, query, numberCompletions) + return GetAiSuggestions(ctx, shellName, query, numberCompletions) } return nil, nil } -func GetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) { +func GetAiSuggestions(ctx context.Context, shellName, query string, numberCompletions int) ([]string, error) { if os.Getenv("OPENAI_API_KEY") == "" { - return GetAiSuggestionsViaHishtoryApi(ctx, query, numberCompletions) + return GetAiSuggestionsViaHishtoryApi(ctx, shellName, query, numberCompletions) } else { - suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(query, getShellName(), getOsName(), numberCompletions) + suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(query, shellName, getOsName(), numberCompletions) return suggestions, err } } @@ -55,12 +55,7 @@ func getOsName() string { } } -func getShellName() string { - // TODO: Wire the real shell name in here - return "bash" -} - -func GetAiSuggestionsViaHishtoryApi(ctx context.Context, query string, numberCompletions int) ([]string, error) { +func GetAiSuggestionsViaHishtoryApi(ctx context.Context, shellName, query string, numberCompletions int) ([]string, error) { hctx.GetLogger().Infof("Running OpenAI query for %#v", query) req := ai.AiSuggestionRequest{ DeviceId: hctx.GetConf(ctx).DeviceId, @@ -68,7 +63,7 @@ func GetAiSuggestionsViaHishtoryApi(ctx context.Context, query string, numberCom Query: query, NumberCompletions: numberCompletions, OsName: getOsName(), - ShellName: getShellName(), + ShellName: shellName, } reqData, err := json.Marshal(req) if err != nil { diff --git a/client/cmd/query.go b/client/cmd/query.go index f399460..c56bf3e 100644 --- a/client/cmd/query.go +++ b/client/cmd/query.go @@ -50,7 +50,11 @@ var tqueryCmd = &cobra.Command{ DisableFlagParsing: true, Run: func(cmd *cobra.Command, args []string) { ctx := hctx.MakeContext() - lib.CheckFatalError(tui.TuiQuery(ctx, strings.Join(args, " "))) + shellName := "bash" + if os.Getenv("HISHTORY_SHELL_NAME") != "" { + shellName = os.Getenv("HISHTORY_SHELL_NAME") + } + lib.CheckFatalError(tui.TuiQuery(ctx, shellName, strings.Join(args, " "))) }, } diff --git a/client/lib/config.fish b/client/lib/config.fish index d5bcedc..b1c3529 100644 --- a/client/lib/config.fish +++ b/client/lib/config.fish @@ -29,7 +29,7 @@ end function __hishtory_on_control_r set -l tmp (mktemp -t fish.XXXXXX) set -x init_query (commandline -b) - HISHTORY_TERM_INTEGRATION=1 hishtory tquery $init_query > $tmp + HISHTORY_TERM_INTEGRATION=1 HISHTORY_SHELL_NAME=fish hishtory tquery $init_query > $tmp set -l res $status commandline -f repaint if [ -s $tmp ] diff --git a/client/lib/config.sh b/client/lib/config.sh index bf8fac7..1454107 100644 --- a/client/lib/config.sh +++ b/client/lib/config.sh @@ -46,7 +46,7 @@ PROMPT_COMMAND="__hishtory_postcommand; $PROMPT_COMMAND" export HISTTIMEFORMAT=$HISTTIMEFORMAT __history_control_r() { - READLINE_LINE=$(HISHTORY_TERM_INTEGRATION=1 hishtory tquery "$READLINE_LINE") + READLINE_LINE=$(HISHTORY_TERM_INTEGRATION=1 HISHTORY_SHELL_NAME=bash hishtory tquery "$READLINE_LINE") READLINE_POINT=0x7FFFFFFF } diff --git a/client/lib/config.zsh b/client/lib/config.zsh index acccd67..e00f44e 100644 --- a/client/lib/config.zsh +++ b/client/lib/config.zsh @@ -33,7 +33,7 @@ function _hishtory_precmd() { } _hishtory_widget() { - BUFFER=$(HISHTORY_TERM_INTEGRATION=1 hishtory tquery $BUFFER) + BUFFER=$(HISHTORY_TERM_INTEGRATION=1 HISHTORY_SHELL_NAME=zsh hishtory tquery $BUFFER) CURSOR=${#BUFFER} zle reset-prompt } diff --git a/client/tui/tui.go b/client/tui/tui.go index 3d18bb3..e92cce8 100644 --- a/client/tui/tui.go +++ b/client/tui/tui.go @@ -182,6 +182,9 @@ type model struct { // A banner from the backend to be displayed. Generally an empty string. banner string + + // The currently executing shell. Defaults to bash if not specified. Used for more precise AI suggestions. + shellName string } type doneDownloadingMsg struct{} @@ -205,7 +208,7 @@ type asyncQueryFinishedMsg struct { overriddenSearchQuery *string } -func initialModel(ctx context.Context, initialQuery string) model { +func initialModel(ctx context.Context, shellName, initialQuery string) model { s := spinner.New() s.Spinner = spinner.Dot s.Style = lipgloss.NewStyle().Foreground(lipgloss.Color("205")) @@ -231,7 +234,7 @@ func initialModel(ctx context.Context, initialQuery string) model { queryInput.SetValue(initialQuery) } CURRENT_QUERY_FOR_HIGHLIGHTING = initialQuery - return model{ctx: ctx, spinner: s, isLoading: true, table: nil, tableEntries: []*data.HistoryEntry{}, runQuery: &initialQuery, queryInput: queryInput, help: help.New()} + return model{ctx: ctx, spinner: s, isLoading: true, table: nil, tableEntries: []*data.HistoryEntry{}, runQuery: &initialQuery, queryInput: queryInput, help: help.New(), shellName: shellName} } func (m model) Init() tea.Cmd { @@ -252,7 +255,7 @@ func updateTable(m model, rows []table.Row, entries []*data.HistoryEntry, search initialCursor = m.table.Cursor() } if forceUpdateTable || m.table == nil { - t, err := makeTable(m.ctx, rows) + t, err := makeTable(m.ctx, m.shellName, rows) if err != nil { m.fatalErr = err return m @@ -299,7 +302,7 @@ func runQueryAndUpdateTable(m model, forceUpdateTable, maintainCursor bool) tea. // The default filter was cleared for this session, so don't apply it defaultFilter = "" } - rows, entries, searchErr := getRows(m.ctx, conf.DisplayedColumns, defaultFilter, query, PADDED_NUM_ENTRIES) + rows, entries, searchErr := getRows(m.ctx, conf.DisplayedColumns, m.shellName, defaultFilter, query, PADDED_NUM_ENTRIES) return asyncQueryFinishedMsg{queryId, rows, entries, searchErr, forceUpdateTable, maintainCursor, nil} } } @@ -493,8 +496,8 @@ func renderNullableTable(m model, helpText string) string { return baseStyle.Render(m.table.View()) } -func getRowsFromAiSuggestions(ctx context.Context, columnNames []string, query string) ([]table.Row, []*data.HistoryEntry, error) { - suggestions, err := ai.DebouncedGetAiSuggestions(ctx, strings.TrimPrefix(query, "?"), 5) +func getRowsFromAiSuggestions(ctx context.Context, columnNames []string, shellName, query string) ([]table.Row, []*data.HistoryEntry, error) { + suggestions, err := ai.DebouncedGetAiSuggestions(ctx, shellName, strings.TrimPrefix(query, "?"), 5) if err != nil { hctx.GetLogger().Infof("failed to get AI query suggestions: %v", err) return nil, nil, fmt.Errorf("failed to get AI query suggestions: %w", err) @@ -525,11 +528,11 @@ func getRowsFromAiSuggestions(ctx context.Context, columnNames []string, query s return rows, entries, nil } -func getRows(ctx context.Context, columnNames []string, defaultFilter, query string, numEntries int) ([]table.Row, []*data.HistoryEntry, error) { +func getRows(ctx context.Context, columnNames []string, shellName, defaultFilter, query string, numEntries int) ([]table.Row, []*data.HistoryEntry, error) { db := hctx.GetDb(ctx) config := hctx.GetConf(ctx) if config.AiCompletion && !config.IsOffline && strings.HasPrefix(query, "?") && len(query) > 1 { - return getRowsFromAiSuggestions(ctx, columnNames, query) + return getRowsFromAiSuggestions(ctx, columnNames, shellName, query) } searchResults, err := lib.Search(ctx, db, defaultFilter+" "+query, numEntries) if err != nil { @@ -588,10 +591,10 @@ func getTerminalSize() (int, int, error) { var bigQueryResults []table.Row -func makeTableColumns(ctx context.Context, columnNames []string, rows []table.Row) ([]table.Column, error) { +func makeTableColumns(ctx context.Context, shellName string, columnNames []string, rows []table.Row) ([]table.Column, error) { // Handle an initial query with no results if len(rows) == 0 || len(rows[0]) == 0 { - allRows, _, err := getRows(ctx, columnNames, hctx.GetConf(ctx).DefaultFilter, "", 25) + allRows, _, err := getRows(ctx, columnNames, shellName, hctx.GetConf(ctx).DefaultFilter, "", 25) if err != nil { return nil, err } @@ -604,7 +607,7 @@ func makeTableColumns(ctx context.Context, columnNames []string, rows []table.Ro } allRows = append(allRows, row) } - return makeTableColumns(ctx, columnNames, allRows) + return makeTableColumns(ctx, shellName, columnNames, allRows) } // Calculate the minimum amount of space that we need for each column for the current actual search @@ -617,7 +620,7 @@ func makeTableColumns(ctx context.Context, columnNames []string, rows []table.Ro // Calculate the maximum column width that is useful for each column if we search for the empty string if bigQueryResults == nil { - bigRows, _, err := getRows(ctx, columnNames, "", "", 1000) + bigRows, _, err := getRows(ctx, columnNames, shellName, "", "", 1000) if err != nil { return nil, err } @@ -678,9 +681,9 @@ func min(a, b int) int { return b } -func makeTable(ctx context.Context, rows []table.Row) (table.Model, error) { +func makeTable(ctx context.Context, shellName string, rows []table.Row) (table.Model, error) { config := hctx.GetConf(ctx) - columns, err := makeTableColumns(ctx, config.DisplayedColumns, rows) + columns, err := makeTableColumns(ctx, shellName, config.DisplayedColumns, rows) if err != nil { return table.Model{}, err } @@ -887,22 +890,22 @@ func configureColorProfile(ctx context.Context) { } } -func TuiQuery(ctx context.Context, initialQuery string) error { +func TuiQuery(ctx context.Context, shellName, initialQuery string) error { configureColorProfile(ctx) - p := tea.NewProgram(initialModel(ctx, initialQuery), tea.WithOutput(os.Stderr)) + p := tea.NewProgram(initialModel(ctx, shellName, initialQuery), tea.WithOutput(os.Stderr)) // Async: Get the initial set of rows go func() { LAST_DISPATCHED_QUERY_ID++ queryId := LAST_DISPATCHED_QUERY_ID LAST_DISPATCHED_QUERY_TIMESTAMP = time.Now() conf := hctx.GetConf(ctx) - rows, entries, err := getRows(ctx, conf.DisplayedColumns, conf.DefaultFilter, initialQuery, PADDED_NUM_ENTRIES) + rows, entries, err := getRows(ctx, conf.DisplayedColumns, shellName, conf.DefaultFilter, initialQuery, PADDED_NUM_ENTRIES) if err == nil || initialQuery == "" { p.Send(asyncQueryFinishedMsg{queryId: queryId, rows: rows, entries: entries, searchErr: err, forceUpdateTable: true, maintainCursor: false, overriddenSearchQuery: nil}) } else { // initialQuery is likely invalid in some way, let's just drop it emptyQuery := "" - rows, entries, err := getRows(ctx, hctx.GetConf(ctx).DisplayedColumns, conf.DefaultFilter, emptyQuery, PADDED_NUM_ENTRIES) + rows, entries, err := getRows(ctx, hctx.GetConf(ctx).DisplayedColumns, shellName, conf.DefaultFilter, emptyQuery, PADDED_NUM_ENTRIES) p.Send(asyncQueryFinishedMsg{queryId: queryId, rows: rows, entries: entries, searchErr: err, forceUpdateTable: true, maintainCursor: false, overriddenSearchQuery: &emptyQuery}) } }()