mirror of
https://github.com/ddworken/hishtory.git
synced 2025-02-16 18:41:03 +01:00
Add basic debouncing for AI integration + implement AI suggestions via hishtory API endpoint
This commit is contained in:
parent
eb835fe52c
commit
0ea3ce2399
@ -325,11 +325,12 @@ func (s *Server) aiSuggestionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if numDevices == 0 {
|
||||
panic(fmt.Errorf("rejecting OpenAI request for user_id=%#v since it does not exist", req.UserId))
|
||||
}
|
||||
suggestions, err := ai.GetAiSuggestionsViaOpenAiApi(req.Query, req.NumberCompletions)
|
||||
suggestions, usage, err := ai.GetAiSuggestionsViaOpenAiApi(req.Query, req.NumberCompletions)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to query OpenAI API: %w", err))
|
||||
}
|
||||
s.statsd.Incr("hishtory.openai.query", []string{}, float64(req.NumberCompletions))
|
||||
s.statsd.Incr("hishtory.openai.tokens", []string{}, float64(usage.TotalTokens))
|
||||
var resp ai.AiSuggestionResponse
|
||||
resp.Suggestions = suggestions
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
|
57
client/ai/ai.go
Normal file
57
client/ai/ai.go
Normal file
@ -0,0 +1,57 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/ddworken/hishtory/client/data"
|
||||
"github.com/ddworken/hishtory/client/hctx"
|
||||
"github.com/ddworken/hishtory/client/lib"
|
||||
"github.com/ddworken/hishtory/shared/ai"
|
||||
)
|
||||
|
||||
var mostRecentQuery string
|
||||
|
||||
func DebouncedGetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) {
|
||||
mostRecentQuery = query
|
||||
time.Sleep(time.Millisecond * 300)
|
||||
if mostRecentQuery == query {
|
||||
return GetAiSuggestions(ctx, query, numberCompletions)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func GetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) {
|
||||
if os.Getenv("OPENAI_API_KEY") == "" {
|
||||
return GetAiSuggestionsViaHishtoryApi(ctx, query, numberCompletions)
|
||||
} else {
|
||||
suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(query, numberCompletions)
|
||||
return suggestions, err
|
||||
}
|
||||
}
|
||||
|
||||
func GetAiSuggestionsViaHishtoryApi(ctx context.Context, query string, numberCompletions int) ([]string, error) {
|
||||
req := ai.AiSuggestionRequest{
|
||||
DeviceId: hctx.GetConf(ctx).DeviceId,
|
||||
UserId: data.UserId(hctx.GetConf(ctx).UserSecret),
|
||||
Query: query,
|
||||
NumberCompletions: numberCompletions,
|
||||
}
|
||||
reqData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal AiSuggestionRequest: %w", err)
|
||||
}
|
||||
respData, err := lib.ApiPost(ctx, "/api/v1/ai-suggest", "application/json", reqData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query /api/v1/ai-suggest: %w", err)
|
||||
}
|
||||
var resp ai.AiSuggestionResponse
|
||||
err = json.Unmarshal(respData, &resp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse /api/v1/ai-suggest response: %w", err)
|
||||
}
|
||||
return resp.Suggestions, nil
|
||||
}
|
@ -17,12 +17,12 @@ import (
|
||||
"github.com/charmbracelet/bubbles/textinput"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/ddworken/hishtory/client/ai"
|
||||
"github.com/ddworken/hishtory/client/data"
|
||||
"github.com/ddworken/hishtory/client/hctx"
|
||||
"github.com/ddworken/hishtory/client/lib"
|
||||
"github.com/ddworken/hishtory/client/table"
|
||||
"github.com/ddworken/hishtory/shared"
|
||||
"github.com/ddworken/hishtory/shared/ai"
|
||||
"github.com/muesli/termenv"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
func main() {
|
||||
resp, err := ai.GetAiSuggestions("Find all CSV files in the current directory or subdirectories and select the first column, then prepend `foo` to each line", 3)
|
||||
resp, _, err := ai.GetAiSuggestionsViaOpenAiApi("Find all CSV files in the current directory or subdirectories and select the first column, then prepend `foo` to each line", 3)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -2,16 +2,13 @@ package ai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/ddworken/hishtory/client/hctx"
|
||||
"github.com/zmwangx/debounce"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
@ -31,7 +28,7 @@ type openAiResponse struct {
|
||||
Object string `json:"object"`
|
||||
Created int `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Usage openAiUsage `json:"usage"`
|
||||
Usage OpenAiUsage `json:"usage"`
|
||||
Choices []openAiChoice `json:"choices"`
|
||||
}
|
||||
|
||||
@ -41,67 +38,22 @@ type openAiChoice struct {
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type openAiUsage struct {
|
||||
type OpenAiUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type debouncedAiSuggestionsIn struct {
|
||||
ctx context.Context
|
||||
query string
|
||||
numberCompletions int
|
||||
}
|
||||
var testOnlyOverrideAiSuggestions map[string][]string
|
||||
|
||||
type debouncedAiSuggestionsOut struct {
|
||||
suggestions []string
|
||||
err error
|
||||
}
|
||||
|
||||
var debouncedGetAiSuggestionsInner func(...debouncedAiSuggestionsIn) debouncedAiSuggestionsOut
|
||||
var debouncedControl debounce.ControlWithReturnValue[debouncedAiSuggestionsOut]
|
||||
|
||||
func init() {
|
||||
openAiDebouncePeriod := time.Millisecond * 1000
|
||||
debouncedGetAiSuggestionsInner, debouncedControl = debounce.DebounceWithCustomSignature(
|
||||
func(input ...debouncedAiSuggestionsIn) debouncedAiSuggestionsOut {
|
||||
if len(input) != 1 {
|
||||
return debouncedAiSuggestionsOut{nil, fmt.Errorf("unexpected input length: %d", len(input))}
|
||||
}
|
||||
suggestions, err := GetAiSuggestions(input[0].ctx, input[0].query, input[0].numberCompletions)
|
||||
return debouncedAiSuggestionsOut{suggestions, err}
|
||||
},
|
||||
openAiDebouncePeriod,
|
||||
debounce.WithLeading(false),
|
||||
debounce.WithTrailing(true),
|
||||
debounce.WithMaxWait(openAiDebouncePeriod),
|
||||
)
|
||||
go func() {
|
||||
for {
|
||||
debouncedControl.Flush()
|
||||
time.Sleep(openAiDebouncePeriod)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func DebouncedGetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) {
|
||||
resp := debouncedGetAiSuggestionsInner(debouncedAiSuggestionsIn{ctx, query, numberCompletions})
|
||||
return resp.suggestions, resp.err
|
||||
}
|
||||
|
||||
func GetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) {
|
||||
if os.Getenv("OPENAI_API_KEY") == "" {
|
||||
panic("TODO: Implement integration with the hishtory API")
|
||||
} else {
|
||||
return GetAiSuggestionsViaOpenAiApi(query, numberCompletions)
|
||||
func GetAiSuggestionsViaOpenAiApi(query string, numberCompletions int) ([]string, OpenAiUsage, error) {
|
||||
if results := testOnlyOverrideAiSuggestions[query]; len(results) > 0 {
|
||||
return results, OpenAiUsage{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func GetAiSuggestionsViaOpenAiApi(query string, numberCompletions int) ([]string, error) {
|
||||
hctx.GetLogger().Infof("Running OpenAI query for %#v", query)
|
||||
apiKey := os.Getenv("OPENAI_API_KEY")
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("OPENAI_API_KEY environment variable is not set")
|
||||
return nil, OpenAiUsage{}, fmt.Errorf("OPENAI_API_KEY environment variable is not set")
|
||||
}
|
||||
client := &http.Client{}
|
||||
apiReq := openAiRequest{
|
||||
@ -114,30 +66,30 @@ func GetAiSuggestionsViaOpenAiApi(query string, numberCompletions int) ([]string
|
||||
}
|
||||
apiReqStr, err := json.Marshal(apiReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to serialize JSON for OpenAI API: %w", err)
|
||||
return nil, OpenAiUsage{}, fmt.Errorf("failed to serialize JSON for OpenAI API: %w", err)
|
||||
}
|
||||
req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", bytes.NewBuffer(apiReqStr))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OpenAI API request: %w", err)
|
||||
return nil, OpenAiUsage{}, fmt.Errorf("failed to create OpenAI API request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query OpenAI API: %w", err)
|
||||
return nil, OpenAiUsage{}, fmt.Errorf("failed to query OpenAI API: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
bodyText, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read OpenAI API response: %w", err)
|
||||
return nil, OpenAiUsage{}, fmt.Errorf("failed to read OpenAI API response: %w", err)
|
||||
}
|
||||
var apiResp openAiResponse
|
||||
err = json.Unmarshal(bodyText, &apiResp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse OpenAI API response: %w", err)
|
||||
return nil, OpenAiUsage{}, fmt.Errorf("failed to parse OpenAI API response: %w", err)
|
||||
}
|
||||
if len(apiResp.Choices) == 0 {
|
||||
return nil, fmt.Errorf("OpenAI API returned zero choicesm, resp=%#v", apiResp)
|
||||
return nil, OpenAiUsage{}, fmt.Errorf("OpenAI API returned zero choicesm, resp=%#v", apiResp)
|
||||
}
|
||||
ret := make([]string, 0)
|
||||
for _, item := range apiResp.Choices {
|
||||
@ -146,7 +98,7 @@ func GetAiSuggestionsViaOpenAiApi(query string, numberCompletions int) ([]string
|
||||
}
|
||||
}
|
||||
hctx.GetLogger().Infof("For OpenAI query=%#v ==> %#v", query, ret)
|
||||
return ret, nil
|
||||
return ret, apiResp.Usage, nil
|
||||
}
|
||||
|
||||
type AiSuggestionRequest struct {
|
||||
|
@ -1,19 +0,0 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetAiSuggestion(t *testing.T) {
|
||||
suggestions, err := ai.GetAiSuggestions("list files by size")
|
||||
require.NoError(t, err)
|
||||
for _, suggestion := range suggestions {
|
||||
if strings.Contains(suggestion, "ls") {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Fatalf("none of the AI suggestions %#v contain 'ls' which is suspicious", suggestions)
|
||||
}
|
Loading…
Reference in New Issue
Block a user