From bf5fee194f52e827614b20964ed6eccea352fc0b Mon Sep 17 00:00:00 2001 From: David Dworken Date: Sat, 11 Nov 2023 20:41:59 -0800 Subject: [PATCH] Add basic debouncing for AI integration + implement AI suggestions via hishtory API endpoint --- .../server/internal/server/api_handlers.go | 3 +- client/ai/ai.go | 57 ++++++++++++++ client/tui/tui.go | 2 +- scripts/aimain.go | 2 +- shared/ai/ai.go | 76 ++++--------------- shared/ai/ai_test.go | 19 ----- 6 files changed, 75 insertions(+), 84 deletions(-) create mode 100644 client/ai/ai.go delete mode 100644 shared/ai/ai_test.go diff --git a/backend/server/internal/server/api_handlers.go b/backend/server/internal/server/api_handlers.go index 9eea56c..c61816c 100644 --- a/backend/server/internal/server/api_handlers.go +++ b/backend/server/internal/server/api_handlers.go @@ -325,11 +325,12 @@ func (s *Server) aiSuggestionHandler(w http.ResponseWriter, r *http.Request) { if numDevices == 0 { panic(fmt.Errorf("rejecting OpenAI request for user_id=%#v since it does not exist", req.UserId)) } - suggestions, err := ai.GetAiSuggestionsViaOpenAiApi(req.Query, req.NumberCompletions) + suggestions, usage, err := ai.GetAiSuggestionsViaOpenAiApi(req.Query, req.NumberCompletions) if err != nil { panic(fmt.Errorf("failed to query OpenAI API: %w", err)) } s.statsd.Incr("hishtory.openai.query", []string{}, float64(req.NumberCompletions)) + s.statsd.Incr("hishtory.openai.tokens", []string{}, float64(usage.TotalTokens)) var resp ai.AiSuggestionResponse resp.Suggestions = suggestions if err := json.NewEncoder(w).Encode(resp); err != nil { diff --git a/client/ai/ai.go b/client/ai/ai.go new file mode 100644 index 0000000..38a4fa1 --- /dev/null +++ b/client/ai/ai.go @@ -0,0 +1,57 @@ +package ai + +import ( + "context" + "encoding/json" + "fmt" + "os" + "time" + + "github.com/ddworken/hishtory/client/data" + "github.com/ddworken/hishtory/client/hctx" + "github.com/ddworken/hishtory/client/lib" + "github.com/ddworken/hishtory/shared/ai" +) + +var mostRecentQuery string + +func DebouncedGetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) { + mostRecentQuery = query + time.Sleep(time.Millisecond * 300) + if mostRecentQuery == query { + return GetAiSuggestions(ctx, query, numberCompletions) + } + return nil, nil +} + +func GetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) { + if os.Getenv("OPENAI_API_KEY") == "" { + return GetAiSuggestionsViaHishtoryApi(ctx, query, numberCompletions) + } else { + suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(query, numberCompletions) + return suggestions, err + } +} + +func GetAiSuggestionsViaHishtoryApi(ctx context.Context, query string, numberCompletions int) ([]string, error) { + req := ai.AiSuggestionRequest{ + DeviceId: hctx.GetConf(ctx).DeviceId, + UserId: data.UserId(hctx.GetConf(ctx).UserSecret), + Query: query, + NumberCompletions: numberCompletions, + } + reqData, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("failed to marshal AiSuggestionRequest: %w", err) + } + respData, err := lib.ApiPost(ctx, "/api/v1/ai-suggest", "application/json", reqData) + if err != nil { + return nil, fmt.Errorf("failed to query /api/v1/ai-suggest: %w", err) + } + var resp ai.AiSuggestionResponse + err = json.Unmarshal(respData, &resp) + if err != nil { + return nil, fmt.Errorf("failed to parse /api/v1/ai-suggest response: %w", err) + } + return resp.Suggestions, nil +} diff --git a/client/tui/tui.go b/client/tui/tui.go index 7911e87..04b0190 100644 --- a/client/tui/tui.go +++ b/client/tui/tui.go @@ -17,12 +17,12 @@ import ( "github.com/charmbracelet/bubbles/textinput" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" + "github.com/ddworken/hishtory/client/ai" "github.com/ddworken/hishtory/client/data" "github.com/ddworken/hishtory/client/hctx" "github.com/ddworken/hishtory/client/lib" "github.com/ddworken/hishtory/client/table" "github.com/ddworken/hishtory/shared" - "github.com/ddworken/hishtory/shared/ai" "github.com/muesli/termenv" "golang.org/x/term" ) diff --git a/scripts/aimain.go b/scripts/aimain.go index 2736c9a..be83021 100644 --- a/scripts/aimain.go +++ b/scripts/aimain.go @@ -9,7 +9,7 @@ import ( ) func main() { - resp, err := ai.GetAiSuggestions("Find all CSV files in the current directory or subdirectories and select the first column, then prepend `foo` to each line", 3) + resp, _, err := ai.GetAiSuggestionsViaOpenAiApi("Find all CSV files in the current directory or subdirectories and select the first column, then prepend `foo` to each line", 3) if err != nil { log.Fatal(err) } diff --git a/shared/ai/ai.go b/shared/ai/ai.go index 525099a..b193377 100644 --- a/shared/ai/ai.go +++ b/shared/ai/ai.go @@ -2,16 +2,13 @@ package ai import ( "bytes" - "context" "encoding/json" "fmt" "io" "net/http" "os" - "time" "github.com/ddworken/hishtory/client/hctx" - "github.com/zmwangx/debounce" "golang.org/x/exp/slices" ) @@ -31,7 +28,7 @@ type openAiResponse struct { Object string `json:"object"` Created int `json:"created"` Model string `json:"model"` - Usage openAiUsage `json:"usage"` + Usage OpenAiUsage `json:"usage"` Choices []openAiChoice `json:"choices"` } @@ -41,67 +38,22 @@ type openAiChoice struct { FinishReason string `json:"finish_reason"` } -type openAiUsage struct { +type OpenAiUsage struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` TotalTokens int `json:"total_tokens"` } -type debouncedAiSuggestionsIn struct { - ctx context.Context - query string - numberCompletions int -} +var testOnlyOverrideAiSuggestions map[string][]string -type debouncedAiSuggestionsOut struct { - suggestions []string - err error -} - -var debouncedGetAiSuggestionsInner func(...debouncedAiSuggestionsIn) debouncedAiSuggestionsOut -var debouncedControl debounce.ControlWithReturnValue[debouncedAiSuggestionsOut] - -func init() { - openAiDebouncePeriod := time.Millisecond * 1000 - debouncedGetAiSuggestionsInner, debouncedControl = debounce.DebounceWithCustomSignature( - func(input ...debouncedAiSuggestionsIn) debouncedAiSuggestionsOut { - if len(input) != 1 { - return debouncedAiSuggestionsOut{nil, fmt.Errorf("unexpected input length: %d", len(input))} - } - suggestions, err := GetAiSuggestions(input[0].ctx, input[0].query, input[0].numberCompletions) - return debouncedAiSuggestionsOut{suggestions, err} - }, - openAiDebouncePeriod, - debounce.WithLeading(false), - debounce.WithTrailing(true), - debounce.WithMaxWait(openAiDebouncePeriod), - ) - go func() { - for { - debouncedControl.Flush() - time.Sleep(openAiDebouncePeriod) - } - }() -} - -func DebouncedGetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) { - resp := debouncedGetAiSuggestionsInner(debouncedAiSuggestionsIn{ctx, query, numberCompletions}) - return resp.suggestions, resp.err -} - -func GetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) { - if os.Getenv("OPENAI_API_KEY") == "" { - panic("TODO: Implement integration with the hishtory API") - } else { - return GetAiSuggestionsViaOpenAiApi(query, numberCompletions) +func GetAiSuggestionsViaOpenAiApi(query string, numberCompletions int) ([]string, OpenAiUsage, error) { + if results := testOnlyOverrideAiSuggestions[query]; len(results) > 0 { + return results, OpenAiUsage{}, nil } -} - -func GetAiSuggestionsViaOpenAiApi(query string, numberCompletions int) ([]string, error) { hctx.GetLogger().Infof("Running OpenAI query for %#v", query) apiKey := os.Getenv("OPENAI_API_KEY") if apiKey == "" { - return nil, fmt.Errorf("OPENAI_API_KEY environment variable is not set") + return nil, OpenAiUsage{}, fmt.Errorf("OPENAI_API_KEY environment variable is not set") } client := &http.Client{} apiReq := openAiRequest{ @@ -114,30 +66,30 @@ func GetAiSuggestionsViaOpenAiApi(query string, numberCompletions int) ([]string } apiReqStr, err := json.Marshal(apiReq) if err != nil { - return nil, fmt.Errorf("failed to serialize JSON for OpenAI API: %w", err) + return nil, OpenAiUsage{}, fmt.Errorf("failed to serialize JSON for OpenAI API: %w", err) } req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", bytes.NewBuffer(apiReqStr)) if err != nil { - return nil, fmt.Errorf("failed to create OpenAI API request: %w", err) + return nil, OpenAiUsage{}, fmt.Errorf("failed to create OpenAI API request: %w", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+apiKey) resp, err := client.Do(req) if err != nil { - return nil, fmt.Errorf("failed to query OpenAI API: %w", err) + return nil, OpenAiUsage{}, fmt.Errorf("failed to query OpenAI API: %w", err) } defer resp.Body.Close() bodyText, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("failed to read OpenAI API response: %w", err) + return nil, OpenAiUsage{}, fmt.Errorf("failed to read OpenAI API response: %w", err) } var apiResp openAiResponse err = json.Unmarshal(bodyText, &apiResp) if err != nil { - return nil, fmt.Errorf("failed to parse OpenAI API response: %w", err) + return nil, OpenAiUsage{}, fmt.Errorf("failed to parse OpenAI API response: %w", err) } if len(apiResp.Choices) == 0 { - return nil, fmt.Errorf("OpenAI API returned zero choicesm, resp=%#v", apiResp) + return nil, OpenAiUsage{}, fmt.Errorf("OpenAI API returned zero choicesm, resp=%#v", apiResp) } ret := make([]string, 0) for _, item := range apiResp.Choices { @@ -146,7 +98,7 @@ func GetAiSuggestionsViaOpenAiApi(query string, numberCompletions int) ([]string } } hctx.GetLogger().Infof("For OpenAI query=%#v ==> %#v", query, ret) - return ret, nil + return ret, apiResp.Usage, nil } type AiSuggestionRequest struct { diff --git a/shared/ai/ai_test.go b/shared/ai/ai_test.go deleted file mode 100644 index ae204ec..0000000 --- a/shared/ai/ai_test.go +++ /dev/null @@ -1,19 +0,0 @@ -package ai - -import ( - "strings" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestGetAiSuggestion(t *testing.T) { - suggestions, err := ai.GetAiSuggestions("list files by size") - require.NoError(t, err) - for _, suggestion := range suggestions { - if strings.Contains(suggestion, "ls") { - return - } - } - t.Fatalf("none of the AI suggestions %#v contain 'ls' which is suspicious", suggestions) -}