Add initial version of AI searching, but with a broken implementation of debouncing

This commit is contained in:
David Dworken 2023-11-11 17:41:24 -08:00
parent af079cd4c9
commit eb835fe52c
9 changed files with 273 additions and 1 deletions

View File

@ -12,6 +12,7 @@ import (
"time"
"github.com/ddworken/hishtory/shared"
"github.com/ddworken/hishtory/shared/ai"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
)
@ -272,6 +273,7 @@ func (s *Server) addDeletionRequestHandler(w http.ResponseWriter, r *http.Reques
func (s *Server) slsaStatusHandler(w http.ResponseWriter, r *http.Request) {
// returns "OK" unless there is a current SLSA bug
v := getHishtoryVersion(r)
// TODO: Migrate this to a version parsing library
if !strings.Contains(v, "v0.") {
w.Write([]byte("OK"))
return
@ -306,6 +308,35 @@ func (s *Server) feedbackHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
func (s *Server) aiSuggestionHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var req ai.AiSuggestionRequest
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
panic(fmt.Errorf("failed to decode: %w", err))
}
if req.NumberCompletions > 10 {
panic(fmt.Errorf("request for %d completions is greater than max allowed", req.NumberCompletions))
}
numDevices, err := s.db.CountDevicesForUser(ctx, req.UserId)
if err != nil {
panic(fmt.Errorf("failed to count devices for user: %w", err))
}
if numDevices == 0 {
panic(fmt.Errorf("rejecting OpenAI request for user_id=%#v since it does not exist", req.UserId))
}
suggestions, err := ai.GetAiSuggestionsViaOpenAiApi(req.Query, req.NumberCompletions)
if err != nil {
panic(fmt.Errorf("failed to query OpenAI API: %w", err))
}
s.statsd.Incr("hishtory.openai.query", []string{}, float64(req.NumberCompletions))
var resp ai.AiSuggestionResponse
resp.Suggestions = suggestions
if err := json.NewEncoder(w).Encode(resp); err != nil {
panic(fmt.Errorf("failed to JSON marshall the API response: %w", err))
}
}
func (s *Server) pingHandler(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}

View File

@ -112,6 +112,7 @@ func (s *Server) Run(ctx context.Context, addr string) error {
mux.Handle("/api/v1/add-deletion-request", middlewares(http.HandlerFunc(s.addDeletionRequestHandler)))
mux.Handle("/api/v1/slsa-status", middlewares(http.HandlerFunc(s.slsaStatusHandler)))
mux.Handle("/api/v1/feedback", middlewares(http.HandlerFunc(s.feedbackHandler)))
mux.Handle("/api/v1/ai-suggest", middlewares(http.HandlerFunc(s.aiSuggestionHandler)))
mux.Handle("/api/v1/ping", middlewares(http.HandlerFunc(s.pingHandler)))
mux.Handle("/healthcheck", middlewares(http.HandlerFunc(s.healthCheckHandler)))
mux.Handle("/internal/api/v1/usage-stats", middlewares(http.HandlerFunc(s.usageStatsHandler)))

View File

@ -94,7 +94,11 @@ func BuildTableRow(ctx context.Context, columnNames []string, entry data.History
case "CWD", "cwd":
row = append(row, entry.CurrentWorkingDirectory)
case "Timestamp", "timestamp":
row = append(row, entry.StartTime.Local().Format(hctx.GetConf(ctx).TimestampFormat))
if entry.StartTime.UnixMilli() == 0 {
row = append(row, "N/A")
} else {
row = append(row, entry.StartTime.Local().Format(hctx.GetConf(ctx).TimestampFormat))
}
case "Runtime", "runtime":
if entry.EndTime.UnixMilli() == 0 {
// An EndTime of zero means this is a pre-saved entry that never finished

View File

@ -22,6 +22,7 @@ import (
"github.com/ddworken/hishtory/client/lib"
"github.com/ddworken/hishtory/client/table"
"github.com/ddworken/hishtory/shared"
"github.com/ddworken/hishtory/shared/ai"
"github.com/muesli/termenv"
"golang.org/x/term"
)
@ -436,9 +437,44 @@ func renderNullableTable(m model) string {
return baseStyle.Render(m.table.View())
}
func getRowsFromAiSuggestions(ctx context.Context, columnNames []string, query string) ([]table.Row, []*data.HistoryEntry, error) {
// TODO: Add debouncing here so we don't waste API queries on half-typed search queries
suggestions, err := ai.DebouncedGetAiSuggestions(ctx, strings.TrimPrefix(query, "?"), 5)
if err != nil {
return nil, nil, fmt.Errorf("failed to get AI query suggestions: %w", err)
}
var rows []table.Row
var entries []*data.HistoryEntry
for _, suggestion := range suggestions {
entry := data.HistoryEntry{
LocalUsername: "OpenAI",
Hostname: "OpenAI",
Command: suggestion,
CurrentWorkingDirectory: "N/A",
HomeDirectory: "N/A",
ExitCode: 0,
StartTime: time.Unix(0, 0).UTC(),
EndTime: time.Unix(0, 0).UTC(),
DeviceId: "OpenAI",
EntryId: "OpenAI",
}
entries = append(entries, &entry)
row, err := lib.BuildTableRow(ctx, columnNames, entry)
if err != nil {
return nil, nil, fmt.Errorf("failed to build row for entry=%#v: %w", entry, err)
}
rows = append(rows, row)
}
hctx.GetLogger().Infof("getRowsFromAiSuggestions(%#v) ==> %#v", query, suggestions)
return rows, entries, nil
}
func getRows(ctx context.Context, columnNames []string, query string, numEntries int) ([]table.Row, []*data.HistoryEntry, error) {
db := hctx.GetDb(ctx)
config := hctx.GetConf(ctx)
if config.BetaMode && strings.HasPrefix(query, "?") && len(query) > 1 {
return getRowsFromAiSuggestions(ctx, columnNames, query)
}
searchResults, err := lib.Search(ctx, db, query, numEntries)
if err != nil {
return nil, nil, err

1
go.mod
View File

@ -244,6 +244,7 @@ require (
github.com/vbatts/tar-split v0.11.2 // indirect
github.com/xanzy/go-gitlab v0.73.1 // indirect
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 // indirect
github.com/zmwangx/debounce v1.0.0 // indirect
go.etcd.io/bbolt v1.3.6 // indirect
go.etcd.io/etcd/api/v3 v3.6.0-alpha.0 // indirect
go.etcd.io/etcd/client/pkg/v3 v3.6.0-alpha.0 // indirect

2
go.sum
View File

@ -1538,6 +1538,8 @@ github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/zalando/go-keyring v0.1.0/go.mod h1:RaxNwUITJaHVdQ0VC7pELPZ3tOWn13nr0gZMZEhpVU0=
github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
github.com/zmwangx/debounce v1.0.0 h1:Dyf+WfLESjc2bqFKHgI1dZTW9oh6CJm8SBDkhXrwLB4=
github.com/zmwangx/debounce v1.0.0/go.mod h1:U+/QHt+bSMdUh8XKOb6U+MQV5Ew4eS8M3ua5WJ7Ns6I=
go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=

17
scripts/aimain.go Normal file
View File

@ -0,0 +1,17 @@
package main
import (
"fmt"
"log"
"strings"
"github.com/ddworken/hishtory/shared/ai"
)
func main() {
resp, err := ai.GetAiSuggestions("Find all CSV files in the current directory or subdirectories and select the first column, then prepend `foo` to each line", 3)
if err != nil {
log.Fatal(err)
}
fmt.Println(strings.Join(resp, "\n"))
}

161
shared/ai/ai.go Normal file
View File

@ -0,0 +1,161 @@
package ai
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"time"
"github.com/ddworken/hishtory/client/hctx"
"github.com/zmwangx/debounce"
"golang.org/x/exp/slices"
)
type openAiRequest struct {
Model string `json:"model"`
Messages []openAiMessage `json:"messages"`
NumberCompletions int `json:"n"`
}
type openAiMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type openAiResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
Model string `json:"model"`
Usage openAiUsage `json:"usage"`
Choices []openAiChoice `json:"choices"`
}
type openAiChoice struct {
Index int `json:"index"`
Message openAiMessage `json:"message"`
FinishReason string `json:"finish_reason"`
}
type openAiUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type debouncedAiSuggestionsIn struct {
ctx context.Context
query string
numberCompletions int
}
type debouncedAiSuggestionsOut struct {
suggestions []string
err error
}
var debouncedGetAiSuggestionsInner func(...debouncedAiSuggestionsIn) debouncedAiSuggestionsOut
var debouncedControl debounce.ControlWithReturnValue[debouncedAiSuggestionsOut]
func init() {
openAiDebouncePeriod := time.Millisecond * 1000
debouncedGetAiSuggestionsInner, debouncedControl = debounce.DebounceWithCustomSignature(
func(input ...debouncedAiSuggestionsIn) debouncedAiSuggestionsOut {
if len(input) != 1 {
return debouncedAiSuggestionsOut{nil, fmt.Errorf("unexpected input length: %d", len(input))}
}
suggestions, err := GetAiSuggestions(input[0].ctx, input[0].query, input[0].numberCompletions)
return debouncedAiSuggestionsOut{suggestions, err}
},
openAiDebouncePeriod,
debounce.WithLeading(false),
debounce.WithTrailing(true),
debounce.WithMaxWait(openAiDebouncePeriod),
)
go func() {
for {
debouncedControl.Flush()
time.Sleep(openAiDebouncePeriod)
}
}()
}
func DebouncedGetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) {
resp := debouncedGetAiSuggestionsInner(debouncedAiSuggestionsIn{ctx, query, numberCompletions})
return resp.suggestions, resp.err
}
func GetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) {
if os.Getenv("OPENAI_API_KEY") == "" {
panic("TODO: Implement integration with the hishtory API")
} else {
return GetAiSuggestionsViaOpenAiApi(query, numberCompletions)
}
}
func GetAiSuggestionsViaOpenAiApi(query string, numberCompletions int) ([]string, error) {
hctx.GetLogger().Infof("Running OpenAI query for %#v", query)
apiKey := os.Getenv("OPENAI_API_KEY")
if apiKey == "" {
return nil, fmt.Errorf("OPENAI_API_KEY environment variable is not set")
}
client := &http.Client{}
apiReq := openAiRequest{
Model: "gpt-3.5-turbo",
NumberCompletions: numberCompletions,
Messages: []openAiMessage{
{Role: "system", Content: "You are an expert programmer that loves to help people with writing shell commands. You always reply with just a shell command and no additional context or information. Your replies will be directly executed in bash, so ensure that they are correct and do not contain anything other than a bash command."},
{Role: "user", Content: query},
},
}
apiReqStr, err := json.Marshal(apiReq)
if err != nil {
return nil, fmt.Errorf("failed to serialize JSON for OpenAI API: %w", err)
}
req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", bytes.NewBuffer(apiReqStr))
if err != nil {
return nil, fmt.Errorf("failed to create OpenAI API request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to query OpenAI API: %w", err)
}
defer resp.Body.Close()
bodyText, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read OpenAI API response: %w", err)
}
var apiResp openAiResponse
err = json.Unmarshal(bodyText, &apiResp)
if err != nil {
return nil, fmt.Errorf("failed to parse OpenAI API response: %w", err)
}
if len(apiResp.Choices) == 0 {
return nil, fmt.Errorf("OpenAI API returned zero choicesm, resp=%#v", apiResp)
}
ret := make([]string, 0)
for _, item := range apiResp.Choices {
if !slices.Contains(ret, item.Message.Content) {
ret = append(ret, item.Message.Content)
}
}
hctx.GetLogger().Infof("For OpenAI query=%#v ==> %#v", query, ret)
return ret, nil
}
type AiSuggestionRequest struct {
DeviceId string `json:"device_id"`
UserId string `json:"user_id"`
Query string `json:"query"`
NumberCompletions int `json:"number_completions"`
}
type AiSuggestionResponse struct {
Suggestions []string `json:"suggestions"`
}

19
shared/ai/ai_test.go Normal file
View File

@ -0,0 +1,19 @@
package ai
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestGetAiSuggestion(t *testing.T) {
suggestions, err := ai.GetAiSuggestions("list files by size")
require.NoError(t, err)
for _, suggestion := range suggestions {
if strings.Contains(suggestion, "ls") {
return
}
}
t.Fatalf("none of the AI suggestions %#v contain 'ls' which is suspicious", suggestions)
}