Wire through the shell name into AI suggestions so that we can get more precise AI suggestions for the current shell

This commit is contained in:
David Dworken 2024-02-19 12:12:04 -08:00
parent 339da47636
commit 0787840a10
No known key found for this signature in database
6 changed files with 36 additions and 34 deletions

View File

@ -17,20 +17,20 @@ import (
var mostRecentQuery string var mostRecentQuery string
func DebouncedGetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) { func DebouncedGetAiSuggestions(ctx context.Context, shellName, query string, numberCompletions int) ([]string, error) {
mostRecentQuery = query mostRecentQuery = query
time.Sleep(time.Millisecond * 300) time.Sleep(time.Millisecond * 300)
if mostRecentQuery == query { if mostRecentQuery == query {
return GetAiSuggestions(ctx, query, numberCompletions) return GetAiSuggestions(ctx, shellName, query, numberCompletions)
} }
return nil, nil return nil, nil
} }
func GetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) { func GetAiSuggestions(ctx context.Context, shellName, query string, numberCompletions int) ([]string, error) {
if os.Getenv("OPENAI_API_KEY") == "" { if os.Getenv("OPENAI_API_KEY") == "" {
return GetAiSuggestionsViaHishtoryApi(ctx, query, numberCompletions) return GetAiSuggestionsViaHishtoryApi(ctx, shellName, query, numberCompletions)
} else { } else {
suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(query, getShellName(), getOsName(), numberCompletions) suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(query, shellName, getOsName(), numberCompletions)
return suggestions, err return suggestions, err
} }
} }
@ -55,12 +55,7 @@ func getOsName() string {
} }
} }
func getShellName() string { func GetAiSuggestionsViaHishtoryApi(ctx context.Context, shellName, query string, numberCompletions int) ([]string, error) {
// TODO: Wire the real shell name in here
return "bash"
}
func GetAiSuggestionsViaHishtoryApi(ctx context.Context, query string, numberCompletions int) ([]string, error) {
hctx.GetLogger().Infof("Running OpenAI query for %#v", query) hctx.GetLogger().Infof("Running OpenAI query for %#v", query)
req := ai.AiSuggestionRequest{ req := ai.AiSuggestionRequest{
DeviceId: hctx.GetConf(ctx).DeviceId, DeviceId: hctx.GetConf(ctx).DeviceId,
@ -68,7 +63,7 @@ func GetAiSuggestionsViaHishtoryApi(ctx context.Context, query string, numberCom
Query: query, Query: query,
NumberCompletions: numberCompletions, NumberCompletions: numberCompletions,
OsName: getOsName(), OsName: getOsName(),
ShellName: getShellName(), ShellName: shellName,
} }
reqData, err := json.Marshal(req) reqData, err := json.Marshal(req)
if err != nil { if err != nil {

View File

@ -50,7 +50,11 @@ var tqueryCmd = &cobra.Command{
DisableFlagParsing: true, DisableFlagParsing: true,
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
ctx := hctx.MakeContext() ctx := hctx.MakeContext()
lib.CheckFatalError(tui.TuiQuery(ctx, strings.Join(args, " "))) shellName := "bash"
if os.Getenv("HISHTORY_SHELL_NAME") != "" {
shellName = os.Getenv("HISHTORY_SHELL_NAME")
}
lib.CheckFatalError(tui.TuiQuery(ctx, shellName, strings.Join(args, " ")))
}, },
} }

View File

@ -29,7 +29,7 @@ end
function __hishtory_on_control_r function __hishtory_on_control_r
set -l tmp (mktemp -t fish.XXXXXX) set -l tmp (mktemp -t fish.XXXXXX)
set -x init_query (commandline -b) set -x init_query (commandline -b)
HISHTORY_TERM_INTEGRATION=1 hishtory tquery $init_query > $tmp HISHTORY_TERM_INTEGRATION=1 HISHTORY_SHELL_NAME=fish hishtory tquery $init_query > $tmp
set -l res $status set -l res $status
commandline -f repaint commandline -f repaint
if [ -s $tmp ] if [ -s $tmp ]

View File

@ -46,7 +46,7 @@ PROMPT_COMMAND="__hishtory_postcommand; $PROMPT_COMMAND"
export HISTTIMEFORMAT=$HISTTIMEFORMAT export HISTTIMEFORMAT=$HISTTIMEFORMAT
__history_control_r() { __history_control_r() {
READLINE_LINE=$(HISHTORY_TERM_INTEGRATION=1 hishtory tquery "$READLINE_LINE") READLINE_LINE=$(HISHTORY_TERM_INTEGRATION=1 HISHTORY_SHELL_NAME=bash hishtory tquery "$READLINE_LINE")
READLINE_POINT=0x7FFFFFFF READLINE_POINT=0x7FFFFFFF
} }

View File

@ -33,7 +33,7 @@ function _hishtory_precmd() {
} }
_hishtory_widget() { _hishtory_widget() {
BUFFER=$(HISHTORY_TERM_INTEGRATION=1 hishtory tquery $BUFFER) BUFFER=$(HISHTORY_TERM_INTEGRATION=1 HISHTORY_SHELL_NAME=zsh hishtory tquery $BUFFER)
CURSOR=${#BUFFER} CURSOR=${#BUFFER}
zle reset-prompt zle reset-prompt
} }

View File

@ -182,6 +182,9 @@ type model struct {
// A banner from the backend to be displayed. Generally an empty string. // A banner from the backend to be displayed. Generally an empty string.
banner string banner string
// The currently executing shell. Defaults to bash if not specified. Used for more precise AI suggestions.
shellName string
} }
type doneDownloadingMsg struct{} type doneDownloadingMsg struct{}
@ -205,7 +208,7 @@ type asyncQueryFinishedMsg struct {
overriddenSearchQuery *string overriddenSearchQuery *string
} }
func initialModel(ctx context.Context, initialQuery string) model { func initialModel(ctx context.Context, shellName, initialQuery string) model {
s := spinner.New() s := spinner.New()
s.Spinner = spinner.Dot s.Spinner = spinner.Dot
s.Style = lipgloss.NewStyle().Foreground(lipgloss.Color("205")) s.Style = lipgloss.NewStyle().Foreground(lipgloss.Color("205"))
@ -231,7 +234,7 @@ func initialModel(ctx context.Context, initialQuery string) model {
queryInput.SetValue(initialQuery) queryInput.SetValue(initialQuery)
} }
CURRENT_QUERY_FOR_HIGHLIGHTING = initialQuery CURRENT_QUERY_FOR_HIGHLIGHTING = initialQuery
return model{ctx: ctx, spinner: s, isLoading: true, table: nil, tableEntries: []*data.HistoryEntry{}, runQuery: &initialQuery, queryInput: queryInput, help: help.New()} return model{ctx: ctx, spinner: s, isLoading: true, table: nil, tableEntries: []*data.HistoryEntry{}, runQuery: &initialQuery, queryInput: queryInput, help: help.New(), shellName: shellName}
} }
func (m model) Init() tea.Cmd { func (m model) Init() tea.Cmd {
@ -252,7 +255,7 @@ func updateTable(m model, rows []table.Row, entries []*data.HistoryEntry, search
initialCursor = m.table.Cursor() initialCursor = m.table.Cursor()
} }
if forceUpdateTable || m.table == nil { if forceUpdateTable || m.table == nil {
t, err := makeTable(m.ctx, rows) t, err := makeTable(m.ctx, m.shellName, rows)
if err != nil { if err != nil {
m.fatalErr = err m.fatalErr = err
return m return m
@ -299,7 +302,7 @@ func runQueryAndUpdateTable(m model, forceUpdateTable, maintainCursor bool) tea.
// The default filter was cleared for this session, so don't apply it // The default filter was cleared for this session, so don't apply it
defaultFilter = "" defaultFilter = ""
} }
rows, entries, searchErr := getRows(m.ctx, conf.DisplayedColumns, defaultFilter, query, PADDED_NUM_ENTRIES) rows, entries, searchErr := getRows(m.ctx, conf.DisplayedColumns, m.shellName, defaultFilter, query, PADDED_NUM_ENTRIES)
return asyncQueryFinishedMsg{queryId, rows, entries, searchErr, forceUpdateTable, maintainCursor, nil} return asyncQueryFinishedMsg{queryId, rows, entries, searchErr, forceUpdateTable, maintainCursor, nil}
} }
} }
@ -493,8 +496,8 @@ func renderNullableTable(m model, helpText string) string {
return baseStyle.Render(m.table.View()) return baseStyle.Render(m.table.View())
} }
func getRowsFromAiSuggestions(ctx context.Context, columnNames []string, query string) ([]table.Row, []*data.HistoryEntry, error) { func getRowsFromAiSuggestions(ctx context.Context, columnNames []string, shellName, query string) ([]table.Row, []*data.HistoryEntry, error) {
suggestions, err := ai.DebouncedGetAiSuggestions(ctx, strings.TrimPrefix(query, "?"), 5) suggestions, err := ai.DebouncedGetAiSuggestions(ctx, shellName, strings.TrimPrefix(query, "?"), 5)
if err != nil { if err != nil {
hctx.GetLogger().Infof("failed to get AI query suggestions: %v", err) hctx.GetLogger().Infof("failed to get AI query suggestions: %v", err)
return nil, nil, fmt.Errorf("failed to get AI query suggestions: %w", err) return nil, nil, fmt.Errorf("failed to get AI query suggestions: %w", err)
@ -525,11 +528,11 @@ func getRowsFromAiSuggestions(ctx context.Context, columnNames []string, query s
return rows, entries, nil return rows, entries, nil
} }
func getRows(ctx context.Context, columnNames []string, defaultFilter, query string, numEntries int) ([]table.Row, []*data.HistoryEntry, error) { func getRows(ctx context.Context, columnNames []string, shellName, defaultFilter, query string, numEntries int) ([]table.Row, []*data.HistoryEntry, error) {
db := hctx.GetDb(ctx) db := hctx.GetDb(ctx)
config := hctx.GetConf(ctx) config := hctx.GetConf(ctx)
if config.AiCompletion && !config.IsOffline && strings.HasPrefix(query, "?") && len(query) > 1 { if config.AiCompletion && !config.IsOffline && strings.HasPrefix(query, "?") && len(query) > 1 {
return getRowsFromAiSuggestions(ctx, columnNames, query) return getRowsFromAiSuggestions(ctx, columnNames, shellName, query)
} }
searchResults, err := lib.Search(ctx, db, defaultFilter+" "+query, numEntries) searchResults, err := lib.Search(ctx, db, defaultFilter+" "+query, numEntries)
if err != nil { if err != nil {
@ -588,10 +591,10 @@ func getTerminalSize() (int, int, error) {
var bigQueryResults []table.Row var bigQueryResults []table.Row
func makeTableColumns(ctx context.Context, columnNames []string, rows []table.Row) ([]table.Column, error) { func makeTableColumns(ctx context.Context, shellName string, columnNames []string, rows []table.Row) ([]table.Column, error) {
// Handle an initial query with no results // Handle an initial query with no results
if len(rows) == 0 || len(rows[0]) == 0 { if len(rows) == 0 || len(rows[0]) == 0 {
allRows, _, err := getRows(ctx, columnNames, hctx.GetConf(ctx).DefaultFilter, "", 25) allRows, _, err := getRows(ctx, columnNames, shellName, hctx.GetConf(ctx).DefaultFilter, "", 25)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -604,7 +607,7 @@ func makeTableColumns(ctx context.Context, columnNames []string, rows []table.Ro
} }
allRows = append(allRows, row) allRows = append(allRows, row)
} }
return makeTableColumns(ctx, columnNames, allRows) return makeTableColumns(ctx, shellName, columnNames, allRows)
} }
// Calculate the minimum amount of space that we need for each column for the current actual search // Calculate the minimum amount of space that we need for each column for the current actual search
@ -617,7 +620,7 @@ func makeTableColumns(ctx context.Context, columnNames []string, rows []table.Ro
// Calculate the maximum column width that is useful for each column if we search for the empty string // Calculate the maximum column width that is useful for each column if we search for the empty string
if bigQueryResults == nil { if bigQueryResults == nil {
bigRows, _, err := getRows(ctx, columnNames, "", "", 1000) bigRows, _, err := getRows(ctx, columnNames, shellName, "", "", 1000)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -678,9 +681,9 @@ func min(a, b int) int {
return b return b
} }
func makeTable(ctx context.Context, rows []table.Row) (table.Model, error) { func makeTable(ctx context.Context, shellName string, rows []table.Row) (table.Model, error) {
config := hctx.GetConf(ctx) config := hctx.GetConf(ctx)
columns, err := makeTableColumns(ctx, config.DisplayedColumns, rows) columns, err := makeTableColumns(ctx, shellName, config.DisplayedColumns, rows)
if err != nil { if err != nil {
return table.Model{}, err return table.Model{}, err
} }
@ -887,22 +890,22 @@ func configureColorProfile(ctx context.Context) {
} }
} }
func TuiQuery(ctx context.Context, initialQuery string) error { func TuiQuery(ctx context.Context, shellName, initialQuery string) error {
configureColorProfile(ctx) configureColorProfile(ctx)
p := tea.NewProgram(initialModel(ctx, initialQuery), tea.WithOutput(os.Stderr)) p := tea.NewProgram(initialModel(ctx, shellName, initialQuery), tea.WithOutput(os.Stderr))
// Async: Get the initial set of rows // Async: Get the initial set of rows
go func() { go func() {
LAST_DISPATCHED_QUERY_ID++ LAST_DISPATCHED_QUERY_ID++
queryId := LAST_DISPATCHED_QUERY_ID queryId := LAST_DISPATCHED_QUERY_ID
LAST_DISPATCHED_QUERY_TIMESTAMP = time.Now() LAST_DISPATCHED_QUERY_TIMESTAMP = time.Now()
conf := hctx.GetConf(ctx) conf := hctx.GetConf(ctx)
rows, entries, err := getRows(ctx, conf.DisplayedColumns, conf.DefaultFilter, initialQuery, PADDED_NUM_ENTRIES) rows, entries, err := getRows(ctx, conf.DisplayedColumns, shellName, conf.DefaultFilter, initialQuery, PADDED_NUM_ENTRIES)
if err == nil || initialQuery == "" { if err == nil || initialQuery == "" {
p.Send(asyncQueryFinishedMsg{queryId: queryId, rows: rows, entries: entries, searchErr: err, forceUpdateTable: true, maintainCursor: false, overriddenSearchQuery: nil}) p.Send(asyncQueryFinishedMsg{queryId: queryId, rows: rows, entries: entries, searchErr: err, forceUpdateTable: true, maintainCursor: false, overriddenSearchQuery: nil})
} else { } else {
// initialQuery is likely invalid in some way, let's just drop it // initialQuery is likely invalid in some way, let's just drop it
emptyQuery := "" emptyQuery := ""
rows, entries, err := getRows(ctx, hctx.GetConf(ctx).DisplayedColumns, conf.DefaultFilter, emptyQuery, PADDED_NUM_ENTRIES) rows, entries, err := getRows(ctx, hctx.GetConf(ctx).DisplayedColumns, shellName, conf.DefaultFilter, emptyQuery, PADDED_NUM_ENTRIES)
p.Send(asyncQueryFinishedMsg{queryId: queryId, rows: rows, entries: entries, searchErr: err, forceUpdateTable: true, maintainCursor: false, overriddenSearchQuery: &emptyQuery}) p.Send(asyncQueryFinishedMsg{queryId: queryId, rows: rows, entries: entries, searchErr: err, forceUpdateTable: true, maintainCursor: false, overriddenSearchQuery: &emptyQuery})
} }
}() }()