From eb835fe52c6b322d71f4676da16e288a5804cf35 Mon Sep 17 00:00:00 2001 From: David Dworken Date: Sat, 11 Nov 2023 17:41:24 -0800 Subject: [PATCH] Add initial version of AI searching, but with a broken implementation of debouncing --- .../server/internal/server/api_handlers.go | 31 ++++ backend/server/internal/server/srv.go | 1 + client/lib/lib.go | 6 +- client/tui/tui.go | 36 ++++ go.mod | 1 + go.sum | 2 + scripts/aimain.go | 17 ++ shared/ai/ai.go | 161 ++++++++++++++++++ shared/ai/ai_test.go | 19 +++ 9 files changed, 273 insertions(+), 1 deletion(-) create mode 100644 scripts/aimain.go create mode 100644 shared/ai/ai.go create mode 100644 shared/ai/ai_test.go diff --git a/backend/server/internal/server/api_handlers.go b/backend/server/internal/server/api_handlers.go index aa71aa2..9eea56c 100644 --- a/backend/server/internal/server/api_handlers.go +++ b/backend/server/internal/server/api_handlers.go @@ -12,6 +12,7 @@ import ( "time" "github.com/ddworken/hishtory/shared" + "github.com/ddworken/hishtory/shared/ai" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" ) @@ -272,6 +273,7 @@ func (s *Server) addDeletionRequestHandler(w http.ResponseWriter, r *http.Reques func (s *Server) slsaStatusHandler(w http.ResponseWriter, r *http.Request) { // returns "OK" unless there is a current SLSA bug v := getHishtoryVersion(r) + // TODO: Migrate this to a version parsing library if !strings.Contains(v, "v0.") { w.Write([]byte("OK")) return @@ -306,6 +308,35 @@ func (s *Server) feedbackHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } +func (s *Server) aiSuggestionHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + var req ai.AiSuggestionRequest + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + panic(fmt.Errorf("failed to decode: %w", err)) + } + if req.NumberCompletions > 10 { + panic(fmt.Errorf("request for %d completions is greater than max allowed", req.NumberCompletions)) + } + numDevices, err := 
s.db.CountDevicesForUser(ctx, req.UserId) + if err != nil { + panic(fmt.Errorf("failed to count devices for user: %w", err)) + } + if numDevices == 0 { + panic(fmt.Errorf("rejecting OpenAI request for user_id=%#v since it does not exist", req.UserId)) + } + suggestions, err := ai.GetAiSuggestionsViaOpenAiApi(req.Query, req.NumberCompletions) + if err != nil { + panic(fmt.Errorf("failed to query OpenAI API: %w", err)) + } + s.statsd.Incr("hishtory.openai.query", []string{}, float64(req.NumberCompletions)) + var resp ai.AiSuggestionResponse + resp.Suggestions = suggestions + if err := json.NewEncoder(w).Encode(resp); err != nil { + panic(fmt.Errorf("failed to JSON marshall the API response: %w", err)) + } +} + func (s *Server) pingHandler(w http.ResponseWriter, r *http.Request) { w.Write([]byte("OK")) } diff --git a/backend/server/internal/server/srv.go b/backend/server/internal/server/srv.go index 723c2d7..5ac1c2b 100644 --- a/backend/server/internal/server/srv.go +++ b/backend/server/internal/server/srv.go @@ -112,6 +112,7 @@ func (s *Server) Run(ctx context.Context, addr string) error { mux.Handle("/api/v1/add-deletion-request", middlewares(http.HandlerFunc(s.addDeletionRequestHandler))) mux.Handle("/api/v1/slsa-status", middlewares(http.HandlerFunc(s.slsaStatusHandler))) mux.Handle("/api/v1/feedback", middlewares(http.HandlerFunc(s.feedbackHandler))) + mux.Handle("/api/v1/ai-suggest", middlewares(http.HandlerFunc(s.aiSuggestionHandler))) mux.Handle("/api/v1/ping", middlewares(http.HandlerFunc(s.pingHandler))) mux.Handle("/healthcheck", middlewares(http.HandlerFunc(s.healthCheckHandler))) mux.Handle("/internal/api/v1/usage-stats", middlewares(http.HandlerFunc(s.usageStatsHandler))) diff --git a/client/lib/lib.go b/client/lib/lib.go index b139c30..46fc978 100644 --- a/client/lib/lib.go +++ b/client/lib/lib.go @@ -94,7 +94,11 @@ func BuildTableRow(ctx context.Context, columnNames []string, entry data.History case "CWD", "cwd": row = append(row, 
entry.CurrentWorkingDirectory) case "Timestamp", "timestamp": - row = append(row, entry.StartTime.Local().Format(hctx.GetConf(ctx).TimestampFormat)) + if entry.StartTime.UnixMilli() == 0 { + row = append(row, "N/A") + } else { + row = append(row, entry.StartTime.Local().Format(hctx.GetConf(ctx).TimestampFormat)) + } case "Runtime", "runtime": if entry.EndTime.UnixMilli() == 0 { // An EndTime of zero means this is a pre-saved entry that never finished diff --git a/client/tui/tui.go b/client/tui/tui.go index 54ac53b..7911e87 100644 --- a/client/tui/tui.go +++ b/client/tui/tui.go @@ -22,6 +22,7 @@ import ( "github.com/ddworken/hishtory/client/lib" "github.com/ddworken/hishtory/client/table" "github.com/ddworken/hishtory/shared" + "github.com/ddworken/hishtory/shared/ai" "github.com/muesli/termenv" "golang.org/x/term" ) @@ -436,9 +437,44 @@ func renderNullableTable(m model) string { return baseStyle.Render(m.table.View()) } +func getRowsFromAiSuggestions(ctx context.Context, columnNames []string, query string) ([]table.Row, []*data.HistoryEntry, error) { + // TODO: Add debouncing here so we don't waste API queries on half-typed search queries + suggestions, err := ai.DebouncedGetAiSuggestions(ctx, strings.TrimPrefix(query, "?"), 5) + if err != nil { + return nil, nil, fmt.Errorf("failed to get AI query suggestions: %w", err) + } + var rows []table.Row + var entries []*data.HistoryEntry + for _, suggestion := range suggestions { + entry := data.HistoryEntry{ + LocalUsername: "OpenAI", + Hostname: "OpenAI", + Command: suggestion, + CurrentWorkingDirectory: "N/A", + HomeDirectory: "N/A", + ExitCode: 0, + StartTime: time.Unix(0, 0).UTC(), + EndTime: time.Unix(0, 0).UTC(), + DeviceId: "OpenAI", + EntryId: "OpenAI", + } + entries = append(entries, &entry) + row, err := lib.BuildTableRow(ctx, columnNames, entry) + if err != nil { + return nil, nil, fmt.Errorf("failed to build row for entry=%#v: %w", entry, err) + } + rows = append(rows, row) + } + 
hctx.GetLogger().Infof("getRowsFromAiSuggestions(%#v) ==> %#v", query, suggestions) + return rows, entries, nil +} + func getRows(ctx context.Context, columnNames []string, query string, numEntries int) ([]table.Row, []*data.HistoryEntry, error) { db := hctx.GetDb(ctx) config := hctx.GetConf(ctx) + if config.BetaMode && strings.HasPrefix(query, "?") && len(query) > 1 { + return getRowsFromAiSuggestions(ctx, columnNames, query) + } searchResults, err := lib.Search(ctx, db, query, numEntries) if err != nil { return nil, nil, err diff --git a/go.mod b/go.mod index a0111e2..5414551 100644 --- a/go.mod +++ b/go.mod @@ -244,6 +244,7 @@ require ( github.com/vbatts/tar-split v0.11.2 // indirect github.com/xanzy/go-gitlab v0.73.1 // indirect github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 // indirect + github.com/zmwangx/debounce v1.0.0 // indirect go.etcd.io/bbolt v1.3.6 // indirect go.etcd.io/etcd/api/v3 v3.6.0-alpha.0 // indirect go.etcd.io/etcd/client/pkg/v3 v3.6.0-alpha.0 // indirect diff --git a/go.sum b/go.sum index 088bd08..9153907 100644 --- a/go.sum +++ b/go.sum @@ -1538,6 +1538,8 @@ github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1 github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/zalando/go-keyring v0.1.0/go.mod h1:RaxNwUITJaHVdQ0VC7pELPZ3tOWn13nr0gZMZEhpVU0= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= +github.com/zmwangx/debounce v1.0.0 h1:Dyf+WfLESjc2bqFKHgI1dZTW9oh6CJm8SBDkhXrwLB4= +github.com/zmwangx/debounce v1.0.0/go.mod h1:U+/QHt+bSMdUh8XKOb6U+MQV5Ew4eS8M3ua5WJ7Ns6I= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= diff --git a/scripts/aimain.go b/scripts/aimain.go new file mode 100644 index 0000000..2736c9a --- /dev/null +++ 
b/scripts/aimain.go @@ -0,0 +1,17 @@ +package main + +import ( + "fmt" + "log" + "strings" + + "github.com/ddworken/hishtory/shared/ai" +) + +func main() { + resp, err := ai.GetAiSuggestions("Find all CSV files in the current directory or subdirectories and select the first column, then prepend `foo` to each line", 3) + if err != nil { + log.Fatal(err) + } + fmt.Println(strings.Join(resp, "\n")) +} diff --git a/shared/ai/ai.go b/shared/ai/ai.go new file mode 100644 index 0000000..525099a --- /dev/null +++ b/shared/ai/ai.go @@ -0,0 +1,161 @@ +package ai + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "time" + + "github.com/ddworken/hishtory/client/hctx" + "github.com/zmwangx/debounce" + "golang.org/x/exp/slices" +) + +type openAiRequest struct { + Model string `json:"model"` + Messages []openAiMessage `json:"messages"` + NumberCompletions int `json:"n"` +} + +type openAiMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type openAiResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + Model string `json:"model"` + Usage openAiUsage `json:"usage"` + Choices []openAiChoice `json:"choices"` +} + +type openAiChoice struct { + Index int `json:"index"` + Message openAiMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +type openAiUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type debouncedAiSuggestionsIn struct { + ctx context.Context + query string + numberCompletions int +} + +type debouncedAiSuggestionsOut struct { + suggestions []string + err error +} + +var debouncedGetAiSuggestionsInner func(...debouncedAiSuggestionsIn) debouncedAiSuggestionsOut +var debouncedControl debounce.ControlWithReturnValue[debouncedAiSuggestionsOut] + +func init() { + openAiDebouncePeriod := time.Millisecond * 1000 + 
debouncedGetAiSuggestionsInner, debouncedControl = debounce.DebounceWithCustomSignature( + func(input ...debouncedAiSuggestionsIn) debouncedAiSuggestionsOut { + if len(input) != 1 { + return debouncedAiSuggestionsOut{nil, fmt.Errorf("unexpected input length: %d", len(input))} + } + suggestions, err := GetAiSuggestions(input[0].ctx, input[0].query, input[0].numberCompletions) + return debouncedAiSuggestionsOut{suggestions, err} + }, + openAiDebouncePeriod, + debounce.WithLeading(false), + debounce.WithTrailing(true), + debounce.WithMaxWait(openAiDebouncePeriod), + ) + go func() { + for { + debouncedControl.Flush() + time.Sleep(openAiDebouncePeriod) + } + }() +} + +func DebouncedGetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) { + resp := debouncedGetAiSuggestionsInner(debouncedAiSuggestionsIn{ctx, query, numberCompletions}) + return resp.suggestions, resp.err +} + +func GetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) { + if os.Getenv("OPENAI_API_KEY") == "" { + panic("TODO: Implement integration with the hishtory API") + } else { + return GetAiSuggestionsViaOpenAiApi(query, numberCompletions) + } +} + +func GetAiSuggestionsViaOpenAiApi(query string, numberCompletions int) ([]string, error) { + hctx.GetLogger().Infof("Running OpenAI query for %#v", query) + apiKey := os.Getenv("OPENAI_API_KEY") + if apiKey == "" { + return nil, fmt.Errorf("OPENAI_API_KEY environment variable is not set") + } + client := &http.Client{} + apiReq := openAiRequest{ + Model: "gpt-3.5-turbo", + NumberCompletions: numberCompletions, + Messages: []openAiMessage{ + {Role: "system", Content: "You are an expert programmer that loves to help people with writing shell commands. You always reply with just a shell command and no additional context or information. 
Your replies will be directly executed in bash, so ensure that they are correct and do not contain anything other than a bash command."}, + {Role: "user", Content: query}, + }, + } + apiReqStr, err := json.Marshal(apiReq) + if err != nil { + return nil, fmt.Errorf("failed to serialize JSON for OpenAI API: %w", err) + } + req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", bytes.NewBuffer(apiReqStr)) + if err != nil { + return nil, fmt.Errorf("failed to create OpenAI API request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to query OpenAI API: %w", err) + } + defer resp.Body.Close() + bodyText, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read OpenAI API response: %w", err) + } + var apiResp openAiResponse + err = json.Unmarshal(bodyText, &apiResp) + if err != nil { + return nil, fmt.Errorf("failed to parse OpenAI API response: %w", err) + } + if len(apiResp.Choices) == 0 { + return nil, fmt.Errorf("OpenAI API returned zero choices, resp=%#v", apiResp) + } + ret := make([]string, 0) + for _, item := range apiResp.Choices { + if !slices.Contains(ret, item.Message.Content) { + ret = append(ret, item.Message.Content) + } + } + hctx.GetLogger().Infof("For OpenAI query=%#v ==> %#v", query, ret) + return ret, nil +} + +type AiSuggestionRequest struct { + DeviceId string `json:"device_id"` + UserId string `json:"user_id"` + Query string `json:"query"` + NumberCompletions int `json:"number_completions"` +} + +type AiSuggestionResponse struct { + Suggestions []string `json:"suggestions"` +} diff --git a/shared/ai/ai_test.go b/shared/ai/ai_test.go new file mode 100644 index 0000000..ae204ec --- /dev/null +++ b/shared/ai/ai_test.go @@ -0,0 +1,19 @@ +package ai + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func 
TestGetAiSuggestion(t *testing.T) { + suggestions, err := GetAiSuggestionsViaOpenAiApi("list files by size", 3) + require.NoError(t, err) + for _, suggestion := range suggestions { + if strings.Contains(suggestion, "ls") { + return + } + } + t.Fatalf("none of the AI suggestions %#v contain 'ls' which is suspicious", suggestions) +}