mirror of
https://github.com/ddworken/hishtory.git
synced 2025-06-25 14:32:14 +02:00
Add basic debouncing for AI integration + implement AI suggestions via hishtory API endpoint
This commit is contained in:
parent
a735ceca85
commit
bf5fee194f
@ -325,11 +325,12 @@ func (s *Server) aiSuggestionHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
if numDevices == 0 {
|
if numDevices == 0 {
|
||||||
panic(fmt.Errorf("rejecting OpenAI request for user_id=%#v since it does not exist", req.UserId))
|
panic(fmt.Errorf("rejecting OpenAI request for user_id=%#v since it does not exist", req.UserId))
|
||||||
}
|
}
|
||||||
suggestions, err := ai.GetAiSuggestionsViaOpenAiApi(req.Query, req.NumberCompletions)
|
suggestions, usage, err := ai.GetAiSuggestionsViaOpenAiApi(req.Query, req.NumberCompletions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Errorf("failed to query OpenAI API: %w", err))
|
panic(fmt.Errorf("failed to query OpenAI API: %w", err))
|
||||||
}
|
}
|
||||||
s.statsd.Incr("hishtory.openai.query", []string{}, float64(req.NumberCompletions))
|
s.statsd.Incr("hishtory.openai.query", []string{}, float64(req.NumberCompletions))
|
||||||
|
s.statsd.Incr("hishtory.openai.tokens", []string{}, float64(usage.TotalTokens))
|
||||||
var resp ai.AiSuggestionResponse
|
var resp ai.AiSuggestionResponse
|
||||||
resp.Suggestions = suggestions
|
resp.Suggestions = suggestions
|
||||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||||
|
57
client/ai/ai.go
Normal file
57
client/ai/ai.go
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ddworken/hishtory/client/data"
|
||||||
|
"github.com/ddworken/hishtory/client/hctx"
|
||||||
|
"github.com/ddworken/hishtory/client/lib"
|
||||||
|
"github.com/ddworken/hishtory/shared/ai"
|
||||||
|
)
|
||||||
|
|
||||||
|
var mostRecentQuery string
|
||||||
|
|
||||||
|
func DebouncedGetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) {
|
||||||
|
mostRecentQuery = query
|
||||||
|
time.Sleep(time.Millisecond * 300)
|
||||||
|
if mostRecentQuery == query {
|
||||||
|
return GetAiSuggestions(ctx, query, numberCompletions)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) {
|
||||||
|
if os.Getenv("OPENAI_API_KEY") == "" {
|
||||||
|
return GetAiSuggestionsViaHishtoryApi(ctx, query, numberCompletions)
|
||||||
|
} else {
|
||||||
|
suggestions, _, err := ai.GetAiSuggestionsViaOpenAiApi(query, numberCompletions)
|
||||||
|
return suggestions, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetAiSuggestionsViaHishtoryApi(ctx context.Context, query string, numberCompletions int) ([]string, error) {
|
||||||
|
req := ai.AiSuggestionRequest{
|
||||||
|
DeviceId: hctx.GetConf(ctx).DeviceId,
|
||||||
|
UserId: data.UserId(hctx.GetConf(ctx).UserSecret),
|
||||||
|
Query: query,
|
||||||
|
NumberCompletions: numberCompletions,
|
||||||
|
}
|
||||||
|
reqData, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal AiSuggestionRequest: %w", err)
|
||||||
|
}
|
||||||
|
respData, err := lib.ApiPost(ctx, "/api/v1/ai-suggest", "application/json", reqData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query /api/v1/ai-suggest: %w", err)
|
||||||
|
}
|
||||||
|
var resp ai.AiSuggestionResponse
|
||||||
|
err = json.Unmarshal(respData, &resp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse /api/v1/ai-suggest response: %w", err)
|
||||||
|
}
|
||||||
|
return resp.Suggestions, nil
|
||||||
|
}
|
@ -17,12 +17,12 @@ import (
|
|||||||
"github.com/charmbracelet/bubbles/textinput"
|
"github.com/charmbracelet/bubbles/textinput"
|
||||||
tea "github.com/charmbracelet/bubbletea"
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
"github.com/charmbracelet/lipgloss"
|
"github.com/charmbracelet/lipgloss"
|
||||||
|
"github.com/ddworken/hishtory/client/ai"
|
||||||
"github.com/ddworken/hishtory/client/data"
|
"github.com/ddworken/hishtory/client/data"
|
||||||
"github.com/ddworken/hishtory/client/hctx"
|
"github.com/ddworken/hishtory/client/hctx"
|
||||||
"github.com/ddworken/hishtory/client/lib"
|
"github.com/ddworken/hishtory/client/lib"
|
||||||
"github.com/ddworken/hishtory/client/table"
|
"github.com/ddworken/hishtory/client/table"
|
||||||
"github.com/ddworken/hishtory/shared"
|
"github.com/ddworken/hishtory/shared"
|
||||||
"github.com/ddworken/hishtory/shared/ai"
|
|
||||||
"github.com/muesli/termenv"
|
"github.com/muesli/termenv"
|
||||||
"golang.org/x/term"
|
"golang.org/x/term"
|
||||||
)
|
)
|
||||||
|
@ -9,7 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
resp, err := ai.GetAiSuggestions("Find all CSV files in the current directory or subdirectories and select the first column, then prepend `foo` to each line", 3)
|
resp, _, err := ai.GetAiSuggestionsViaOpenAiApi("Find all CSV files in the current directory or subdirectories and select the first column, then prepend `foo` to each line", 3)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -2,16 +2,13 @@ package ai
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/ddworken/hishtory/client/hctx"
|
"github.com/ddworken/hishtory/client/hctx"
|
||||||
"github.com/zmwangx/debounce"
|
|
||||||
"golang.org/x/exp/slices"
|
"golang.org/x/exp/slices"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -31,7 +28,7 @@ type openAiResponse struct {
|
|||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
Created int `json:"created"`
|
Created int `json:"created"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Usage openAiUsage `json:"usage"`
|
Usage OpenAiUsage `json:"usage"`
|
||||||
Choices []openAiChoice `json:"choices"`
|
Choices []openAiChoice `json:"choices"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -41,67 +38,22 @@ type openAiChoice struct {
|
|||||||
FinishReason string `json:"finish_reason"`
|
FinishReason string `json:"finish_reason"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type openAiUsage struct {
|
type OpenAiUsage struct {
|
||||||
PromptTokens int `json:"prompt_tokens"`
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
CompletionTokens int `json:"completion_tokens"`
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
TotalTokens int `json:"total_tokens"`
|
TotalTokens int `json:"total_tokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type debouncedAiSuggestionsIn struct {
|
var testOnlyOverrideAiSuggestions map[string][]string
|
||||||
ctx context.Context
|
|
||||||
query string
|
|
||||||
numberCompletions int
|
|
||||||
}
|
|
||||||
|
|
||||||
type debouncedAiSuggestionsOut struct {
|
func GetAiSuggestionsViaOpenAiApi(query string, numberCompletions int) ([]string, OpenAiUsage, error) {
|
||||||
suggestions []string
|
if results := testOnlyOverrideAiSuggestions[query]; len(results) > 0 {
|
||||||
err error
|
return results, OpenAiUsage{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var debouncedGetAiSuggestionsInner func(...debouncedAiSuggestionsIn) debouncedAiSuggestionsOut
|
|
||||||
var debouncedControl debounce.ControlWithReturnValue[debouncedAiSuggestionsOut]
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
openAiDebouncePeriod := time.Millisecond * 1000
|
|
||||||
debouncedGetAiSuggestionsInner, debouncedControl = debounce.DebounceWithCustomSignature(
|
|
||||||
func(input ...debouncedAiSuggestionsIn) debouncedAiSuggestionsOut {
|
|
||||||
if len(input) != 1 {
|
|
||||||
return debouncedAiSuggestionsOut{nil, fmt.Errorf("unexpected input length: %d", len(input))}
|
|
||||||
}
|
|
||||||
suggestions, err := GetAiSuggestions(input[0].ctx, input[0].query, input[0].numberCompletions)
|
|
||||||
return debouncedAiSuggestionsOut{suggestions, err}
|
|
||||||
},
|
|
||||||
openAiDebouncePeriod,
|
|
||||||
debounce.WithLeading(false),
|
|
||||||
debounce.WithTrailing(true),
|
|
||||||
debounce.WithMaxWait(openAiDebouncePeriod),
|
|
||||||
)
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
debouncedControl.Flush()
|
|
||||||
time.Sleep(openAiDebouncePeriod)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func DebouncedGetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) {
|
|
||||||
resp := debouncedGetAiSuggestionsInner(debouncedAiSuggestionsIn{ctx, query, numberCompletions})
|
|
||||||
return resp.suggestions, resp.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetAiSuggestions(ctx context.Context, query string, numberCompletions int) ([]string, error) {
|
|
||||||
if os.Getenv("OPENAI_API_KEY") == "" {
|
|
||||||
panic("TODO: Implement integration with the hishtory API")
|
|
||||||
} else {
|
|
||||||
return GetAiSuggestionsViaOpenAiApi(query, numberCompletions)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetAiSuggestionsViaOpenAiApi(query string, numberCompletions int) ([]string, error) {
|
|
||||||
hctx.GetLogger().Infof("Running OpenAI query for %#v", query)
|
hctx.GetLogger().Infof("Running OpenAI query for %#v", query)
|
||||||
apiKey := os.Getenv("OPENAI_API_KEY")
|
apiKey := os.Getenv("OPENAI_API_KEY")
|
||||||
if apiKey == "" {
|
if apiKey == "" {
|
||||||
return nil, fmt.Errorf("OPENAI_API_KEY environment variable is not set")
|
return nil, OpenAiUsage{}, fmt.Errorf("OPENAI_API_KEY environment variable is not set")
|
||||||
}
|
}
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
apiReq := openAiRequest{
|
apiReq := openAiRequest{
|
||||||
@ -114,30 +66,30 @@ func GetAiSuggestionsViaOpenAiApi(query string, numberCompletions int) ([]string
|
|||||||
}
|
}
|
||||||
apiReqStr, err := json.Marshal(apiReq)
|
apiReqStr, err := json.Marshal(apiReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to serialize JSON for OpenAI API: %w", err)
|
return nil, OpenAiUsage{}, fmt.Errorf("failed to serialize JSON for OpenAI API: %w", err)
|
||||||
}
|
}
|
||||||
req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", bytes.NewBuffer(apiReqStr))
|
req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", bytes.NewBuffer(apiReqStr))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create OpenAI API request: %w", err)
|
return nil, OpenAiUsage{}, fmt.Errorf("failed to create OpenAI API request: %w", err)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to query OpenAI API: %w", err)
|
return nil, OpenAiUsage{}, fmt.Errorf("failed to query OpenAI API: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
bodyText, err := io.ReadAll(resp.Body)
|
bodyText, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to read OpenAI API response: %w", err)
|
return nil, OpenAiUsage{}, fmt.Errorf("failed to read OpenAI API response: %w", err)
|
||||||
}
|
}
|
||||||
var apiResp openAiResponse
|
var apiResp openAiResponse
|
||||||
err = json.Unmarshal(bodyText, &apiResp)
|
err = json.Unmarshal(bodyText, &apiResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse OpenAI API response: %w", err)
|
return nil, OpenAiUsage{}, fmt.Errorf("failed to parse OpenAI API response: %w", err)
|
||||||
}
|
}
|
||||||
if len(apiResp.Choices) == 0 {
|
if len(apiResp.Choices) == 0 {
|
||||||
return nil, fmt.Errorf("OpenAI API returned zero choicesm, resp=%#v", apiResp)
|
return nil, OpenAiUsage{}, fmt.Errorf("OpenAI API returned zero choicesm, resp=%#v", apiResp)
|
||||||
}
|
}
|
||||||
ret := make([]string, 0)
|
ret := make([]string, 0)
|
||||||
for _, item := range apiResp.Choices {
|
for _, item := range apiResp.Choices {
|
||||||
@ -146,7 +98,7 @@ func GetAiSuggestionsViaOpenAiApi(query string, numberCompletions int) ([]string
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
hctx.GetLogger().Infof("For OpenAI query=%#v ==> %#v", query, ret)
|
hctx.GetLogger().Infof("For OpenAI query=%#v ==> %#v", query, ret)
|
||||||
return ret, nil
|
return ret, apiResp.Usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type AiSuggestionRequest struct {
|
type AiSuggestionRequest struct {
|
||||||
|
@ -1,19 +0,0 @@
|
|||||||
package ai
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGetAiSuggestion(t *testing.T) {
|
|
||||||
suggestions, err := ai.GetAiSuggestions("list files by size")
|
|
||||||
require.NoError(t, err)
|
|
||||||
for _, suggestion := range suggestions {
|
|
||||||
if strings.Contains(suggestion, "ls") {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
t.Fatalf("none of the AI suggestions %#v contain 'ls' which is suspicious", suggestions)
|
|
||||||
}
|
|
Loading…
x
Reference in New Issue
Block a user