mirror of
https://github.com/ddworken/hishtory.git
synced 2024-12-23 23:39:02 +01:00
Add initial version of AI searching, but with a broken implementation of debouncing
This commit is contained in:
parent
af079cd4c9
commit
eb835fe52c
@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ddworken/hishtory/shared"
|
||||
"github.com/ddworken/hishtory/shared/ai"
|
||||
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
||||
)
|
||||
|
||||
@ -272,6 +273,7 @@ func (s *Server) addDeletionRequestHandler(w http.ResponseWriter, r *http.Reques
|
||||
func (s *Server) slsaStatusHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// returns "OK" unless there is a current SLSA bug
|
||||
v := getHishtoryVersion(r)
|
||||
// TODO: Migrate this to a version parsing library
|
||||
if !strings.Contains(v, "v0.") {
|
||||
w.Write([]byte("OK"))
|
||||
return
|
||||
@ -306,6 +308,35 @@ func (s *Server) feedbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (s *Server) aiSuggestionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
var req ai.AiSuggestionRequest
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to decode: %w", err))
|
||||
}
|
||||
if req.NumberCompletions > 10 {
|
||||
panic(fmt.Errorf("request for %d completions is greater than max allowed", req.NumberCompletions))
|
||||
}
|
||||
numDevices, err := s.db.CountDevicesForUser(ctx, req.UserId)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to count devices for user: %w", err))
|
||||
}
|
||||
if numDevices == 0 {
|
||||
panic(fmt.Errorf("rejecting OpenAI request for user_id=%#v since it does not exist", req.UserId))
|
||||
}
|
||||
suggestions, err := ai.GetAiSuggestionsViaOpenAiApi(req.Query, req.NumberCompletions)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to query OpenAI API: %w", err))
|
||||
}
|
||||
s.statsd.Incr("hishtory.openai.query", []string{}, float64(req.NumberCompletions))
|
||||
var resp ai.AiSuggestionResponse
|
||||
resp.Suggestions = suggestions
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
panic(fmt.Errorf("failed to JSON marshall the API response: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) pingHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("OK"))
|
||||
}
|
||||
|
@ -112,6 +112,7 @@ func (s *Server) Run(ctx context.Context, addr string) error {
|
||||
mux.Handle("/api/v1/add-deletion-request", middlewares(http.HandlerFunc(s.addDeletionRequestHandler)))
|
||||
mux.Handle("/api/v1/slsa-status", middlewares(http.HandlerFunc(s.slsaStatusHandler)))
|
||||
mux.Handle("/api/v1/feedback", middlewares(http.HandlerFunc(s.feedbackHandler)))
|
||||
mux.Handle("/api/v1/ai-suggest", middlewares(http.HandlerFunc(s.aiSuggestionHandler)))
|
||||
mux.Handle("/api/v1/ping", middlewares(http.HandlerFunc(s.pingHandler)))
|
||||
mux.Handle("/healthcheck", middlewares(http.HandlerFunc(s.healthCheckHandler)))
|
||||
mux.Handle("/internal/api/v1/usage-stats", middlewares(http.HandlerFunc(s.usageStatsHandler)))
|
||||
|
@ -94,7 +94,11 @@ func BuildTableRow(ctx context.Context, columnNames []string, entry data.History
|
||||
case "CWD", "cwd":
|
||||
row = append(row, entry.CurrentWorkingDirectory)
|
||||
case "Timestamp", "timestamp":
|
||||
row = append(row, entry.StartTime.Local().Format(hctx.GetConf(ctx).TimestampFormat))
|
||||
if entry.StartTime.UnixMilli() == 0 {
|
||||
row = append(row, "N/A")
|
||||
} else {
|
||||
row = append(row, entry.StartTime.Local().Format(hctx.GetConf(ctx).TimestampFormat))
|
||||
}
|
||||
case "Runtime", "runtime":
|
||||
if entry.EndTime.UnixMilli() == 0 {
|
||||
// An EndTime of zero means this is a pre-saved entry that never finished
|
||||
|
@ -22,6 +22,7 @@ import (
|
||||
"github.com/ddworken/hishtory/client/lib"
|
||||
"github.com/ddworken/hishtory/client/table"
|
||||
"github.com/ddworken/hishtory/shared"
|
||||
"github.com/ddworken/hishtory/shared/ai"
|
||||
"github.com/muesli/termenv"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
@ -436,9 +437,44 @@ func renderNullableTable(m model) string {
|
||||
return baseStyle.Render(m.table.View())
|
||||
}
|
||||
|
||||
func getRowsFromAiSuggestions(ctx context.Context, columnNames []string, query string) ([]table.Row, []*data.HistoryEntry, error) {
|
||||
// TODO: Add debouncing here so we don't waste API queries on half-typed search queries
|
||||
suggestions, err := ai.DebouncedGetAiSuggestions(ctx, strings.TrimPrefix(query, "?"), 5)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to get AI query suggestions: %w", err)
|
||||
}
|
||||
var rows []table.Row
|
||||
var entries []*data.HistoryEntry
|
||||
for _, suggestion := range suggestions {
|
||||
entry := data.HistoryEntry{
|
||||
LocalUsername: "OpenAI",
|
||||
Hostname: "OpenAI",
|
||||
Command: suggestion,
|
||||
CurrentWorkingDirectory: "N/A",
|
||||
HomeDirectory: "N/A",
|
||||
ExitCode: 0,
|
||||
StartTime: time.Unix(0, 0).UTC(),
|
||||
EndTime: time.Unix(0, 0).UTC(),
|
||||
DeviceId: "OpenAI",
|
||||
EntryId: "OpenAI",
|
||||
}
|
||||
entries = append(entries, &entry)
|
||||
row, err := lib.BuildTableRow(ctx, columnNames, entry)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to build row for entry=%#v: %w", entry, err)
|
||||
}
|
||||
rows = append(rows, row)
|
||||
}
|
||||
hctx.GetLogger().Infof("getRowsFromAiSuggestions(%#v) ==> %#v", query, suggestions)
|
||||
return rows, entries, nil
|
||||
}
|
||||
|
||||
func getRows(ctx context.Context, columnNames []string, query string, numEntries int) ([]table.Row, []*data.HistoryEntry, error) {
|
||||
db := hctx.GetDb(ctx)
|
||||
config := hctx.GetConf(ctx)
|
||||
if config.BetaMode && strings.HasPrefix(query, "?") && len(query) > 1 {
|
||||
return getRowsFromAiSuggestions(ctx, columnNames, query)
|
||||
}
|
||||
searchResults, err := lib.Search(ctx, db, query, numEntries)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
1
go.mod
1
go.mod
@ -244,6 +244,7 @@ require (
|
||||
github.com/vbatts/tar-split v0.11.2 // indirect
|
||||
github.com/xanzy/go-gitlab v0.73.1 // indirect
|
||||
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 // indirect
|
||||
github.com/zmwangx/debounce v1.0.0 // indirect
|
||||
go.etcd.io/bbolt v1.3.6 // indirect
|
||||
go.etcd.io/etcd/api/v3 v3.6.0-alpha.0 // indirect
|
||||
go.etcd.io/etcd/client/pkg/v3 v3.6.0-alpha.0 // indirect
|
||||
|
2
go.sum
2
go.sum
@ -1538,6 +1538,8 @@ github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
github.com/zalando/go-keyring v0.1.0/go.mod h1:RaxNwUITJaHVdQ0VC7pELPZ3tOWn13nr0gZMZEhpVU0=
|
||||
github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
|
||||
github.com/zmwangx/debounce v1.0.0 h1:Dyf+WfLESjc2bqFKHgI1dZTW9oh6CJm8SBDkhXrwLB4=
|
||||
github.com/zmwangx/debounce v1.0.0/go.mod h1:U+/QHt+bSMdUh8XKOb6U+MQV5Ew4eS8M3ua5WJ7Ns6I=
|
||||
go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
|
||||
go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
|
||||
go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
|
||||
|
17
scripts/aimain.go
Normal file
17
scripts/aimain.go
Normal file
@ -0,0 +1,17 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/ddworken/hishtory/shared/ai"
|
||||
)
|
||||
|
||||
func main() {
|
||||
resp, err := ai.GetAiSuggestions("Find all CSV files in the current directory or subdirectories and select the first column, then prepend `foo` to each line", 3)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Println(strings.Join(resp, "\n"))
|
||||
}
|
161
shared/ai/ai.go
Normal file
161
shared/ai/ai.go
Normal file
@ -0,0 +1,161 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/ddworken/hishtory/client/hctx"
|
||||
"github.com/zmwangx/debounce"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
type openAiRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []openAiMessage `json:"messages"`
|
||||
NumberCompletions int `json:"n"`
|
||||
}
|
||||
|
||||
type openAiMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type openAiResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Usage openAiUsage `json:"usage"`
|
||||
Choices []openAiChoice `json:"choices"`
|
||||
}
|
||||
|
||||
type openAiChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message openAiMessage `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type openAiUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type debouncedAiSuggestionsIn struct {
|
||||
ctx context.Context
|
||||
query string
|
||||
numberCompletions int
|
||||
}
|
||||
|
||||
type debouncedAiSuggestionsOut struct {
|
||||
suggestions []string
|
||||
err error
|
||||
}
|
||||
|
||||
var debouncedGetAiSuggestionsInner func(...debouncedAiSuggestionsIn) debouncedAiSuggestionsOut
|
||||
var debouncedControl debounce.ControlWithReturnValue[debouncedAiSuggestionsOut]
|
||||
|
||||
func init() {
|
||||
openAiDebouncePeriod := time.Millisecond * 1000
|
||||
debouncedGetAiSuggestionsInner, debouncedControl = debounce.DebounceWithCustomSignature(
|
||||
func(input ...debouncedAiSuggestionsIn) debouncedAiSuggestionsOut {
|
||||
if len(input) != 1 {
|
||||
return debouncedAiSuggestionsOut{nil, fmt.Errorf("unexpected input length: %d", len(input))}
|
||||
}
|
||||
suggestions, err := GetAiSuggestions(input[0].ctx, input[0].query, input[0].numberCompletions)
|
||||
return debouncedAiSuggestionsOut{suggestions, err}
|
||||
},
|
||||
openAiDebouncePeriod,
|
||||
debounce.WithLeading(false),
|
||||
debounce.WithTrailing(true),
|
||||
debounce.WithMaxWait(openAiDebouncePeriod),
|
||||
)
|
||||
go func() {
|
||||
for {
|
||||
debouncedControl.Flush()
|
||||
time.Sleep(openAiDebouncePeriod)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func DebouncedGetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) {
|
||||
resp := debouncedGetAiSuggestionsInner(debouncedAiSuggestionsIn{ctx, query, numberCompletions})
|
||||
return resp.suggestions, resp.err
|
||||
}
|
||||
|
||||
func GetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) {
|
||||
if os.Getenv("OPENAI_API_KEY") == "" {
|
||||
panic("TODO: Implement integration with the hishtory API")
|
||||
} else {
|
||||
return GetAiSuggestionsViaOpenAiApi(query, numberCompletions)
|
||||
}
|
||||
}
|
||||
|
||||
func GetAiSuggestionsViaOpenAiApi(query string, numberCompletions int) ([]string, error) {
|
||||
hctx.GetLogger().Infof("Running OpenAI query for %#v", query)
|
||||
apiKey := os.Getenv("OPENAI_API_KEY")
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("OPENAI_API_KEY environment variable is not set")
|
||||
}
|
||||
client := &http.Client{}
|
||||
apiReq := openAiRequest{
|
||||
Model: "gpt-3.5-turbo",
|
||||
NumberCompletions: numberCompletions,
|
||||
Messages: []openAiMessage{
|
||||
{Role: "system", Content: "You are an expert programmer that loves to help people with writing shell commands. You always reply with just a shell command and no additional context or information. Your replies will be directly executed in bash, so ensure that they are correct and do not contain anything other than a bash command."},
|
||||
{Role: "user", Content: query},
|
||||
},
|
||||
}
|
||||
apiReqStr, err := json.Marshal(apiReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to serialize JSON for OpenAI API: %w", err)
|
||||
}
|
||||
req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", bytes.NewBuffer(apiReqStr))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OpenAI API request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query OpenAI API: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
bodyText, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read OpenAI API response: %w", err)
|
||||
}
|
||||
var apiResp openAiResponse
|
||||
err = json.Unmarshal(bodyText, &apiResp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse OpenAI API response: %w", err)
|
||||
}
|
||||
if len(apiResp.Choices) == 0 {
|
||||
return nil, fmt.Errorf("OpenAI API returned zero choicesm, resp=%#v", apiResp)
|
||||
}
|
||||
ret := make([]string, 0)
|
||||
for _, item := range apiResp.Choices {
|
||||
if !slices.Contains(ret, item.Message.Content) {
|
||||
ret = append(ret, item.Message.Content)
|
||||
}
|
||||
}
|
||||
hctx.GetLogger().Infof("For OpenAI query=%#v ==> %#v", query, ret)
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
type AiSuggestionRequest struct {
|
||||
DeviceId string `json:"device_id"`
|
||||
UserId string `json:"user_id"`
|
||||
Query string `json:"query"`
|
||||
NumberCompletions int `json:"number_completions"`
|
||||
}
|
||||
|
||||
type AiSuggestionResponse struct {
|
||||
Suggestions []string `json:"suggestions"`
|
||||
}
|
19
shared/ai/ai_test.go
Normal file
19
shared/ai/ai_test.go
Normal file
@ -0,0 +1,19 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetAiSuggestion(t *testing.T) {
|
||||
suggestions, err := ai.GetAiSuggestions("list files by size")
|
||||
require.NoError(t, err)
|
||||
for _, suggestion := range suggestions {
|
||||
if strings.Contains(suggestion, "ls") {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Fatalf("none of the AI suggestions %#v contain 'ls' which is suspicious", suggestions)
|
||||
}
|
Loading…
Reference in New Issue
Block a user