hishtory/shared/ai/ai.go

package ai

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"strconv"

	"github.com/ddworken/hishtory/client/hctx"
	"golang.org/x/exp/slices"
)
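
// DefaultOpenAiEndpoint is the OpenAI chat completions endpoint used when the
// caller has not supplied a custom apiEndpoint.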
const DefaultOpenAiEndpoint = "https://api.openai.com/v1/chat/completions"
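
// openAiRequest is the JSON request body sent to the chat completions API.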
type openAiRequest struct {
	Model             string          `json:"model"`
	Messages          []openAiMessage `json:"messages"`
	NumberCompletions int             `json:"n"`
}
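
// openAiMessage is a single chat message with a role ("system" or "user") and its content.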
type openAiMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
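
// openAiResponse is the subset of the chat completions response that we parse.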
type openAiResponse struct {
	Id      string         `json:"id"`
	Object  string         `json:"object"`
	Created int            `json:"created"`
	Model   string         `json:"model"`
	Usage   OpenAiUsage    `json:"usage"`
	Choices []openAiChoice `json:"choices"`
}
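
// openAiChoice is a single completion returned by the API.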
type openAiChoice struct {
	Index        int           `json:"index"`
	Message      openAiMessage `json:"message"`
	FinishReason string        `json:"finish_reason"`
}
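
// OpenAiUsage reports the token usage of a single OpenAI API request.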
type OpenAiUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
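
// TestOnlyOverrideAiSuggestionRequest is used by tests to register canned
// suggestions for a given query.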
type TestOnlyOverrideAiSuggestionRequest struct {
	Query       string   `json:"query"`
	Suggestions []string `json:"suggestions"`
}
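
// TestOnlyOverrideAiSuggestions maps a query to canned suggestions that are
// returned instead of calling the OpenAI API. It is only used in tests.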
var TestOnlyOverrideAiSuggestions = make(map[string][]string)
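
// GetAiSuggestionsViaOpenAiApi queries the given OpenAI-compatible chat completions
// endpoint for shell command suggestions matching query and returns the deduplicated
// suggestions along with the token usage reported by the API. When apiEndpoint is
// DefaultOpenAiEndpoint, the OPENAI_API_KEY environment variable must be set. For
// example (with an illustrative query):
//
//	suggestions, usage, err := GetAiSuggestionsViaOpenAiApi(
//		DefaultOpenAiEndpoint, "list files larger than 1GB", "bash", "Linux", "", 3)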
func GetAiSuggestionsViaOpenAiApi(apiEndpoint, query, shellName, osName, overriddenOpenAiModel string, numberCompletions int) ([]string, OpenAiUsage, error) {
	// Tests can pre-register canned suggestions to avoid calling the real API.
	if results := TestOnlyOverrideAiSuggestions[query]; len(results) > 0 {
		return results, OpenAiUsage{}, nil
	}
	hctx.GetLogger().Infof("Running OpenAI query for %#v", query)
	apiKey := os.Getenv("OPENAI_API_KEY")
	if apiKey == "" && apiEndpoint == DefaultOpenAiEndpoint {
		return nil, OpenAiUsage{}, fmt.Errorf("OPENAI_API_KEY environment variable is not set")
	}
	apiReqStr, err := json.Marshal(createOpenAiRequest(query, shellName, osName, overriddenOpenAiModel, numberCompletions))
	if err != nil {
		return nil, OpenAiUsage{}, fmt.Errorf("failed to serialize JSON for OpenAI API: %w", err)
	}
	req, err := http.NewRequest(http.MethodPost, apiEndpoint, bytes.NewBuffer(apiReqStr))
	if err != nil {
		return nil, OpenAiUsage{}, fmt.Errorf("failed to create OpenAI API request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	if apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+apiKey)
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, OpenAiUsage{}, fmt.Errorf("failed to query OpenAI API: %w", err)
	}
	defer resp.Body.Close()
	bodyText, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, OpenAiUsage{}, fmt.Errorf("failed to read OpenAI API response: %w", err)
	}
	if resp.StatusCode == 429 {
		return nil, OpenAiUsage{}, fmt.Errorf("received 429 error code from OpenAI (is your API key valid?)")
	}
	var apiResp openAiResponse
	err = json.Unmarshal(bodyText, &apiResp)
	if err != nil {
		return nil, OpenAiUsage{}, fmt.Errorf("failed to parse OpenAI API response=%#v: %w", string(bodyText), err)
	}
	if len(apiResp.Choices) == 0 {
		return nil, OpenAiUsage{}, fmt.Errorf("OpenAI API returned zero choices, parsed resp=%#v, resp body=%#v, resp.StatusCode=%d", apiResp, string(bodyText), resp.StatusCode)
	}
	// Deduplicate the returned completions while preserving their order.
	ret := make([]string, 0)
	for _, item := range apiResp.Choices {
		if !slices.Contains(ret, item.Message.Content) {
			ret = append(ret, item.Message.Content)
		}
	}
	hctx.GetLogger().Infof("For OpenAI query=%#v ==> %#v", query, ret)
	return ret, apiResp.Usage, nil
}
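
// AiSuggestionRequest is the shared request type for requesting AI suggestions
// for a query on behalf of a given device and user.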
type AiSuggestionRequest struct {
	DeviceId          string `json:"device_id"`
	UserId            string `json:"user_id"`
	Query             string `json:"query"`
	NumberCompletions int    `json:"number_completions"`
	ShellName         string `json:"shell_name"`
	OsName            string `json:"os_name"`
	Model             string `json:"model"`
}
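
// AiSuggestionResponse is the shared response type containing the generated suggestions.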
type AiSuggestionResponse struct {
	Suggestions []string `json:"suggestions"`
}
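
// createOpenAiRequest builds the chat completions request for query. The model
// defaults to gpt-4o-mini and can be overridden via the OPENAI_API_MODEL
// environment variable or the overriddenOpenAiModel argument; the number of
// completions can be overridden via OPENAI_API_NUMBER_COMPLETIONS, and the
// system prompt via OPENAI_API_SYSTEM_PROMPT.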
func createOpenAiRequest(query, shellName, osName, overriddenOpenAiModel string, numberCompletions int) openAiRequest {
	if osName == "" {
		osName = "Linux"
	}
	if shellName == "" {
		shellName = "bash"
	}
	// According to https://platform.openai.com/docs/models gpt-4o-mini is the best model
	// by performance/price ratio.
	model := "gpt-4o-mini"
	if envModel := os.Getenv("OPENAI_API_MODEL"); envModel != "" {
		model = envModel
	}
	if overriddenOpenAiModel != "" {
		model = overriddenOpenAiModel
	}
	if envNumberCompletions := os.Getenv("OPENAI_API_NUMBER_COMPLETIONS"); envNumberCompletions != "" {
		n, err := strconv.Atoi(envNumberCompletions)
		if err == nil {
			numberCompletions = n
		}
	}
	defaultSystemPrompt := "You are an expert programmer that loves to help people with writing shell commands. " +
		"You always reply with just a shell command and no additional context, information, or formatting. " +
		"Your replies will be directly executed in " + shellName + " on " + osName +
		", so ensure that they are correct and do not contain anything other than a shell command."
	if systemPrompt := os.Getenv("OPENAI_API_SYSTEM_PROMPT"); systemPrompt != "" {
		defaultSystemPrompt = systemPrompt
	}
	return openAiRequest{
		Model:             model,
		NumberCompletions: numberCompletions,
		Messages: []openAiMessage{
			{Role: "system", Content: defaultSystemPrompt},
			{Role: "user", Content: query},
		},
	}
}