Add basic test for AI queries

This commit is contained in:
David Dworken 2023-11-11 20:59:45 -08:00
parent 0ea3ce2399
commit b0f3107da2
2 changed files with 19 additions and 2 deletions

View File

@ -22,6 +22,7 @@ import (
"github.com/ddworken/hishtory/client/hctx"
"github.com/ddworken/hishtory/client/lib"
"github.com/ddworken/hishtory/shared"
"github.com/ddworken/hishtory/shared/ai"
"github.com/ddworken/hishtory/shared/testutils"
"github.com/stretchr/testify/require"
)
@ -105,6 +106,7 @@ func TestParam(t *testing.T) {
t.Run("testTui/delete", testTui_delete)
t.Run("testTui/color", testTui_color)
t.Run("testTui/errors", testTui_errors)
t.Run("testTui/ai", testTui_ai)
// Assert there are no leaked connections
assertNoLeakedConnections(t)
@ -1929,6 +1931,21 @@ func testTui_errors(t *testing.T) {
testutils.CompareGoldens(t, out, "TestTui-OfflineInvalid")
}
func testTui_ai(t *testing.T) {
// Setup
defer testutils.BackupAndRestore(t)()
tester, _, _ := setupTestTui(t, Online)
// Test running an AI query
ai.TestOnlyOverrideAiSuggestions["myQuery"] = []string{"result 1", "result 2", "longer result 3"}
out := captureTerminalOutput(t, tester, []string{
"hishtory SPACE tquery ENTER",
"?myQuery",
})
out = strings.TrimSpace(strings.Split(out, "hishtory tquery")[1])
testutils.CompareGoldens(t, out, "TestTui-AiQuery")
}
func testControlR(t *testing.T, tester shellTester, shellName string, onlineStatus OnlineStatus) {
// Setup
defer testutils.BackupAndRestore(t)()

View File

@ -44,10 +44,10 @@ type OpenAiUsage struct {
TotalTokens int `json:"total_tokens"`
}
var testOnlyOverrideAiSuggestions map[string][]string
var TestOnlyOverrideAiSuggestions map[string][]string = make(map[string][]string)
func GetAiSuggestionsViaOpenAiApi(query string, numberCompletions int) ([]string, OpenAiUsage, error) {
if results := testOnlyOverrideAiSuggestions[query]; len(results) > 0 {
if results := TestOnlyOverrideAiSuggestions[query]; len(results) > 0 {
return results, OpenAiUsage{}, nil
}
hctx.GetLogger().Infof("Running OpenAI query for %#v", query)