2023-11-12 02:41:24 +01:00
package ai
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"time"

	"github.com/ddworken/hishtory/client/hctx"
	"golang.org/x/exp/slices"
)
// openAiRequest is the JSON body POSTed to the OpenAI chat completions
// endpoint by GetAiSuggestionsViaOpenAiApi.
type openAiRequest struct {
	// Model names the OpenAI model to query.
	Model string `json:"model"`
	// Messages is the conversation sent to the model, in order.
	Messages []openAiMessage `json:"messages"`
	// NumberCompletions is serialized as "n": the number of completions requested.
	NumberCompletions int `json:"n"`
}
// openAiMessage is a single chat message, used both in requests and in the
// choices of a response.
type openAiMessage struct {
	// Role identifies the speaker; this file sends "system" and "user".
	Role string `json:"role"`
	// Content is the message text.
	Content string `json:"content"`
}
type openAiResponse struct {
Id string ` json:"id" `
Object string ` json:"object" `
Created int ` json:"created" `
Model string ` json:"model" `
2023-11-12 05:41:59 +01:00
Usage OpenAiUsage ` json:"usage" `
2023-11-12 02:41:24 +01:00
Choices [ ] openAiChoice ` json:"choices" `
}
// openAiChoice is one completion inside an openAiResponse.
type openAiChoice struct {
	// Index is the position of this choice as reported by the API.
	Index int `json:"index"`
	// Message carries the generated text in its Content field.
	Message openAiMessage `json:"message"`
	// FinishReason is the API-reported reason generation stopped.
	FinishReason string `json:"finish_reason"`
}
// OpenAiUsage is the token accounting reported by the OpenAI API for a
// single request. GetAiSuggestionsViaOpenAiApi returns it to callers and
// returns the zero value when no API call was made or the call failed.
type OpenAiUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
// TestOnlyOverrideAiSuggestionRequest pairs a query with canned suggestions.
// NOTE(review): presumably used to populate TestOnlyOverrideAiSuggestions in
// tests; verify against callers.
type TestOnlyOverrideAiSuggestionRequest struct {
	Query       string   `json:"query"`
	Suggestions []string `json:"suggestions"`
}
// TestOnlyOverrideAiSuggestions maps a query to canned suggestions. When it
// holds a non-empty slice for a query, GetAiSuggestionsViaOpenAiApi returns
// that slice instead of calling the OpenAI API. It is a plain map, so callers
// must not write to it concurrently with lookups.
var TestOnlyOverrideAiSuggestions = make(map[string][]string)
func GetAiSuggestionsViaOpenAiApi ( query string , numberCompletions int ) ( [ ] string , OpenAiUsage , error ) {
2023-11-12 05:59:45 +01:00
if results := TestOnlyOverrideAiSuggestions [ query ] ; len ( results ) > 0 {
2023-11-12 05:41:59 +01:00
return results , OpenAiUsage { } , nil
2023-11-12 02:41:24 +01:00
}
hctx . GetLogger ( ) . Infof ( "Running OpenAI query for %#v" , query )
apiKey := os . Getenv ( "OPENAI_API_KEY" )
if apiKey == "" {
2023-11-12 05:41:59 +01:00
return nil , OpenAiUsage { } , fmt . Errorf ( "OPENAI_API_KEY environment variable is not set" )
2023-11-12 02:41:24 +01:00
}
client := & http . Client { }
apiReq := openAiRequest {
Model : "gpt-3.5-turbo" ,
NumberCompletions : numberCompletions ,
Messages : [ ] openAiMessage {
{ Role : "system" , Content : "You are an expert programmer that loves to help people with writing shell commands. You always reply with just a shell command and no additional context or information. Your replies will be directly executed in bash, so ensure that they are correct and do not contain anything other than a bash command." } ,
{ Role : "user" , Content : query } ,
} ,
}
apiReqStr , err := json . Marshal ( apiReq )
if err != nil {
2023-11-12 05:41:59 +01:00
return nil , OpenAiUsage { } , fmt . Errorf ( "failed to serialize JSON for OpenAI API: %w" , err )
2023-11-12 02:41:24 +01:00
}
req , err := http . NewRequest ( "POST" , "https://api.openai.com/v1/chat/completions" , bytes . NewBuffer ( apiReqStr ) )
if err != nil {
2023-11-12 05:41:59 +01:00
return nil , OpenAiUsage { } , fmt . Errorf ( "failed to create OpenAI API request: %w" , err )
2023-11-12 02:41:24 +01:00
}
req . Header . Set ( "Content-Type" , "application/json" )
req . Header . Set ( "Authorization" , "Bearer " + apiKey )
resp , err := client . Do ( req )
if err != nil {
2023-11-12 05:41:59 +01:00
return nil , OpenAiUsage { } , fmt . Errorf ( "failed to query OpenAI API: %w" , err )
2023-11-12 02:41:24 +01:00
}
defer resp . Body . Close ( )
bodyText , err := io . ReadAll ( resp . Body )
if err != nil {
2023-11-12 05:41:59 +01:00
return nil , OpenAiUsage { } , fmt . Errorf ( "failed to read OpenAI API response: %w" , err )
2023-11-12 02:41:24 +01:00
}
var apiResp openAiResponse
err = json . Unmarshal ( bodyText , & apiResp )
if err != nil {
2023-11-12 06:44:21 +01:00
return nil , OpenAiUsage { } , fmt . Errorf ( "failed to parse OpenAI API response=%#v: %w" , bodyText , err )
2023-11-12 02:41:24 +01:00
}
if len ( apiResp . Choices ) == 0 {
2023-11-12 06:44:21 +01:00
return nil , OpenAiUsage { } , fmt . Errorf ( "OpenAI API returned zero choices, resp=%#v" , apiResp )
2023-11-12 02:41:24 +01:00
}
ret := make ( [ ] string , 0 )
for _ , item := range apiResp . Choices {
if ! slices . Contains ( ret , item . Message . Content ) {
ret = append ( ret , item . Message . Content )
}
}
hctx . GetLogger ( ) . Infof ( "For OpenAI query=%#v ==> %#v" , query , ret )
2023-11-12 05:41:59 +01:00
return ret , apiResp . Usage , nil
2023-11-12 02:41:24 +01:00
}
// AiSuggestionRequest is a JSON request for AI command suggestions that
// identifies the requesting device and user.
// NOTE(review): the endpoint that consumes this is not visible here; verify against callers.
type AiSuggestionRequest struct {
	DeviceId string `json:"device_id"`
	UserId   string `json:"user_id"`
	// Query is the natural-language description of the desired command.
	Query string `json:"query"`
	// NumberCompletions is the number of suggestions requested.
	NumberCompletions int `json:"number_completions"`
}
// AiSuggestionResponse is the JSON reply carrying suggested commands.
// A nil Suggestions slice encodes as null rather than [].
type AiSuggestionResponse struct {
	Suggestions []string `json:"suggestions"`
}