Add basic debouncing for AI integration + implement AI suggestions via hishtory API endpoint

David Dworken 2023-11-11 20:41:59 -08:00
parent eb835fe52c
commit 0ea3ce2399
6 changed files with 75 additions and 84 deletions

View File

@@ -325,11 +325,12 @@ func (s *Server) aiSuggestionHandler(w http.ResponseWriter, r *http.Request) {
if numDevices == 0 {
panic(fmt.Errorf("rejecting OpenAI request for user_id=%#v since it does not exist", req.UserId))
}
suggestions, err := ai.GetAiSuggestionsViaOpenAiApi(req.Query, req.NumberCompletions)
suggestions, usage, err := ai.GetAiSuggestionsViaOpenAiApi(req.Query, req.NumberCompletions)
if err != nil {
panic(fmt.Errorf("failed to query OpenAI API: %w", err))
}
s.statsd.Incr("hishtory.openai.query", []string{}, float64(req.NumberCompletions))
s.statsd.Incr("hishtory.openai.tokens", []string{}, float64(usage.TotalTokens))
var resp ai.AiSuggestionResponse
resp.Suggestions = suggestions
if err := json.NewEncoder(w).Encode(resp); err != nil {

client/ai/ai.go (new file, 57 lines added)
View File

@@ -0,0 +1,57 @@
package ai
import (
"context"
"encoding/json"
"fmt"
"os"
"time"
"github.com/ddworken/hishtory/client/data"
"github.com/ddworken/hishtory/client/hctx"
"github.com/ddworken/hishtory/client/lib"
"github.com/ddworken/hishtory/shared/ai"
)
var mostRecentQuery string
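// DebouncedGetAiSuggestions is a basic debounce: it records query as the most
// recent input, sleeps for 300ms, and only issues the request if no newer
// query has arrived in the meantime; superseded queries return nil, nil.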
func DebouncedGetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) {
mostRecentQuery = query
time.Sleep(time.Millisecond * 300)
if mostRecentQuery == query {
return GetAiSuggestions(ctx, query, numberCompletions)
}
return nil, nil
}
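// GetAiSuggestions calls the OpenAI API directly when OPENAI_API_KEY is set,
// and otherwise routes the request through the hishtory server.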
func GetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) {
if os.Getenv("OPENAI_API_KEY") == "" {
return GetAiSuggestionsViaHishtoryApi(ctx, query, numberCompletions)
} else {
suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(query, numberCompletions)
return suggestions, err
}
}
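// GetAiSuggestionsViaHishtoryApi requests suggestions from the hishtory
// backend's /api/v1/ai-suggest endpoint, so end users do not need their own
// OpenAI API key.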
func GetAiSuggestionsViaHishtoryApi(ctx context.Context, query string, numberCompletions int) ([]string, error) {
req := ai.AiSuggestionRequest{
DeviceId: hctx.GetConf(ctx).DeviceId,
UserId: data.UserId(hctx.GetConf(ctx).UserSecret),
Query: query,
NumberCompletions: numberCompletions,
}
reqData, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal AiSuggestionRequest: %w", err)
}
respData, err := lib.ApiPost(ctx, "/api/v1/ai-suggest", "application/json", reqData)
if err != nil {
return nil, fmt.Errorf("failed to query /api/v1/ai-suggest: %w", err)
}
var resp ai.AiSuggestionResponse
err = json.Unmarshal(respData, &resp)
if err != nil {
return nil, fmt.Errorf("failed to parse /api/v1/ai-suggest response: %w", err)
}
return resp.Suggestions, nil
}
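
Note (not part of the commit): a minimal sketch of how a caller such as the TUI might use the new debounced helper. The showSuggestions wrapper and the way the context is obtained are hypothetical; the context is assumed to carry the hishtory config, since GetAiSuggestionsViaHishtoryApi reads the device and user IDs from it.

package demo

import (
	"context"
	"fmt"

	"github.com/ddworken/hishtory/client/ai"
)

// showSuggestions is a hypothetical caller of the debounced helper, roughly
// what the TUI would do on each keystroke.
func showSuggestions(ctx context.Context, query string) {
	suggestions, err := ai.DebouncedGetAiSuggestions(ctx, query, 3)
	if err != nil {
		fmt.Printf("failed to fetch AI suggestions: %v\n", err)
		return
	}
	if suggestions == nil {
		// A superseded (stale) query returns nil, nil and is simply dropped.
		return
	}
	for _, suggestion := range suggestions {
		fmt.Println(suggestion)
	}
}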

View File

@@ -17,12 +17,12 @@ import (
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/ddworken/hishtory/client/ai"
"github.com/ddworken/hishtory/client/data"
"github.com/ddworken/hishtory/client/hctx"
"github.com/ddworken/hishtory/client/lib"
"github.com/ddworken/hishtory/client/table"
"github.com/ddworken/hishtory/shared"
"github.com/ddworken/hishtory/shared/ai"
"github.com/muesli/termenv"
"golang.org/x/term"
)

View File

@@ -9,7 +9,7 @@ import (
)
func main() {
resp, err := ai.GetAiSuggestions("Find all CSV files in the current directory or subdirectories and select the first column, then prepend `foo` to each line", 3)
resp, _, err := ai.GetAiSuggestionsViaOpenAiApi("Find all CSV files in the current directory or subdirectories and select the first column, then prepend `foo` to each line", 3)
if err != nil {
log.Fatal(err)
}

View File

@@ -2,16 +2,13 @@ package ai
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"time"
"github.com/ddworken/hishtory/client/hctx"
"github.com/zmwangx/debounce"
"golang.org/x/exp/slices"
)
@@ -31,7 +28,7 @@ type openAiResponse struct {
Object string `json:"object"`
Created int `json:"created"`
Model string `json:"model"`
Usage openAiUsage `json:"usage"`
Usage OpenAiUsage `json:"usage"`
Choices []openAiChoice `json:"choices"`
}
@@ -41,67 +38,22 @@ type openAiChoice struct {
FinishReason string `json:"finish_reason"`
}
type openAiUsage struct {
type OpenAiUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type debouncedAiSuggestionsIn struct {
ctx context.Context
query string
numberCompletions int
}
var testOnlyOverrideAiSuggestions map[string][]string
type debouncedAiSuggestionsOut struct {
suggestions []string
err error
}
var debouncedGetAiSuggestionsInner func(...debouncedAiSuggestionsIn) debouncedAiSuggestionsOut
var debouncedControl debounce.ControlWithReturnValue[debouncedAiSuggestionsOut]
func init() {
openAiDebouncePeriod := time.Millisecond * 1000
debouncedGetAiSuggestionsInner, debouncedControl = debounce.DebounceWithCustomSignature(
func(input ...debouncedAiSuggestionsIn) debouncedAiSuggestionsOut {
if len(input) != 1 {
return debouncedAiSuggestionsOut{nil, fmt.Errorf("unexpected input length: %d", len(input))}
}
suggestions, err := GetAiSuggestions(input[0].ctx, input[0].query, input[0].numberCompletions)
return debouncedAiSuggestionsOut{suggestions, err}
},
openAiDebouncePeriod,
debounce.WithLeading(false),
debounce.WithTrailing(true),
debounce.WithMaxWait(openAiDebouncePeriod),
)
go func() {
for {
debouncedControl.Flush()
time.Sleep(openAiDebouncePeriod)
}
}()
}
func DebouncedGetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) {
resp := debouncedGetAiSuggestionsInner(debouncedAiSuggestionsIn{ctx, query, numberCompletions})
return resp.suggestions, resp.err
}
func GetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) {
if os.Getenv("OPENAI_API_KEY") == "" {
panic("TODO: Implement integration with the hishtory API")
} else {
return GetAiSuggestionsViaOpenAiApi(query, numberCompletions)
func GetAiSuggestionsViaOpenAiApi(query string, numberCompletions int) ([]string, OpenAiUsage, error) {
if results := testOnlyOverrideAiSuggestions[query]; len(results) > 0 {
return results, OpenAiUsage{}, nil
}
}
func GetAiSuggestionsViaOpenAiApi(query string, numberCompletions int) ([]string, error) {
hctx.GetLogger().Infof("Running OpenAI query for %#v", query)
apiKey := os.Getenv("OPENAI_API_KEY")
if apiKey == "" {
return nil, fmt.Errorf("OPENAI_API_KEY environment variable is not set")
return nil, OpenAiUsage{}, fmt.Errorf("OPENAI_API_KEY environment variable is not set")
}
client := &http.Client{}
apiReq := openAiRequest{
@@ -114,30 +66,30 @@ func GetAiSuggestionsViaOpenAiApi(query string, numberCompletions int) ([]string
}
apiReqStr, err := json.Marshal(apiReq)
if err != nil {
return nil, fmt.Errorf("failed to serialize JSON for OpenAI API: %w", err)
return nil, OpenAiUsage{}, fmt.Errorf("failed to serialize JSON for OpenAI API: %w", err)
}
req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", bytes.NewBuffer(apiReqStr))
if err != nil {
return nil, fmt.Errorf("failed to create OpenAI API request: %w", err)
return nil, OpenAiUsage{}, fmt.Errorf("failed to create OpenAI API request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to query OpenAI API: %w", err)
return nil, OpenAiUsage{}, fmt.Errorf("failed to query OpenAI API: %w", err)
}
defer resp.Body.Close()
bodyText, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read OpenAI API response: %w", err)
return nil, OpenAiUsage{}, fmt.Errorf("failed to read OpenAI API response: %w", err)
}
var apiResp openAiResponse
err = json.Unmarshal(bodyText, &apiResp)
if err != nil {
return nil, fmt.Errorf("failed to parse OpenAI API response: %w", err)
return nil, OpenAiUsage{}, fmt.Errorf("failed to parse OpenAI API response: %w", err)
}
if len(apiResp.Choices) == 0 {
return nil, fmt.Errorf("OpenAI API returned zero choicesm, resp=%#v", apiResp)
return nil, OpenAiUsage{}, fmt.Errorf("OpenAI API returned zero choicesm, resp=%#v", apiResp)
}
ret := make([]string, 0)
for _, item := range apiResp.Choices {
@@ -146,7 +98,7 @@ func GetAiSuggestionsViaOpenAiApi(query string, numberCompletions int) ([]string
}
}
hctx.GetLogger().Infof("For OpenAI query=%#v ==> %#v", query, ret)
return ret, nil
return ret, apiResp.Usage, nil
}
type AiSuggestionRequest struct {

View File

@@ -1,19 +0,0 @@
package ai
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestGetAiSuggestion(t *testing.T) {
suggestions, err := ai.GetAiSuggestions("list files by size")
require.NoError(t, err)
for _, suggestion := range suggestions {
if strings.Contains(suggestion, "ls") {
return
}
}
t.Fatalf("none of the AI suggestions %#v contain 'ls' which is suspicious", suggestions)
}
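
Note (not part of the commit): the test deleted above exercised the old suggestion helper directly. Below is a minimal sketch of a replacement that uses the testOnlyOverrideAiSuggestions hook added in this commit, so no real OpenAI call is made. The test name and the expected command are hypothetical, and it assumes the test sits in the same package as shared/ai/ai.go.

package ai

import (
	"testing"

	"github.com/stretchr/testify/require"
)

func TestGetAiSuggestionsViaOpenAiApiOverride(t *testing.T) {
	// Seed the test-only override so no request is sent to OpenAI.
	testOnlyOverrideAiSuggestions = map[string][]string{
		"list files by size": {"ls -Sl"},
	}
	defer func() { testOnlyOverrideAiSuggestions = nil }()

	suggestions, usage, err := GetAiSuggestionsViaOpenAiApi("list files by size", 1)
	require.NoError(t, err)
	require.Equal(t, []string{"ls -Sl"}, suggestions)
	// The override path reports no token usage.
	require.Equal(t, OpenAiUsage{}, usage)
}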