building, before doing the refactor to make device ID just another random ID

This commit is contained in:
David Dworken 2022-04-03 20:55:37 -07:00
parent 2a3887b9ed
commit 32e74eb3a1
4 changed files with 25 additions and 26 deletions

View File

@ -1,12 +1,12 @@
package main package main
import ( import (
"fmt"
"os"
"strings"
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"os"
"strings"
"github.com/ddworken/hishtory/shared" "github.com/ddworken/hishtory/shared"
) )
@ -25,10 +25,10 @@ func main() {
export() export()
case "init": case "init":
shared.CheckFatalError(shared.Setup(0, os.Args)) shared.CheckFatalError(shared.Setup(0, os.Args))
// TODO: Call ebootstrap here // TODO: Call ebootstrap here
case "install": case "install":
shared.CheckFatalError(shared.Install()) shared.CheckFatalError(shared.Install())
// TODO: Call ebootstrap here // TODO: Call ebootstrap here
case "enable": case "enable":
shared.CheckFatalError(shared.Enable()) shared.CheckFatalError(shared.Enable())
case "disable": case "disable":
@ -72,15 +72,14 @@ func saveHistoryEntry() {
entry, err := shared.BuildHistoryEntry(os.Args) entry, err := shared.BuildHistoryEntry(os.Args)
shared.CheckFatalError(err) shared.CheckFatalError(err)
// Persist it locally // Persist it locally
db, err := shared.OpenLocalSqliteDb() db, err := shared.OpenLocalSqliteDb()
shared.CheckFatalError(err) shared.CheckFatalError(err)
err = db.Create(entry) result := db.Create(entry)
shared.CheckFatalError(err) shared.CheckFatalError(result.Error)
// Persist it remotely // Persist it remotely
// TODO: This is encrypting one to this device, this is wrong. We want to encrypt it to every device except this one. encEntry, err := shared.EncryptHistoryEntry(config.UserSecret, *entry)
encEntry, err := shared.EncryptHistoryEntry(config.UserSecret ,config.DeviceId, *entry)
shared.CheckFatalError(err) shared.CheckFatalError(err)
jsonValue, err := json.Marshal(encEntry) jsonValue, err := json.Marshal(encEntry)
shared.CheckFatalError(err) shared.CheckFatalError(err)
@ -98,4 +97,4 @@ func export() {
for _, entry := range data { for _, entry := range data {
fmt.Println(entry) fmt.Println(entry)
} }
} }

View File

@ -36,19 +36,19 @@ func apiESubmitHandler(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
panic(err) panic(err)
} }
GLOBAL_DB.Where("user_id = ?", ) GLOBAL_DB.Where("user_id = ?")
for _, entry := range entries { for _, entry := range entries {
tx := GLOBAL_DB.Where("user_id = ?", entry.UserId) tx := GLOBAL_DB.Where("user_id = ?", entry.UserId)
var devices []*shared.Device; var devices []*shared.Device
result := tx.Find(&devices) result := tx.Find(&devices)
if result.Error != nil { if result.Error != nil {
panic(fmt.Errorf("DB query error: %v", result.Error)) panic(fmt.Errorf("DB query error: %v", result.Error))
} }
if len(devices) == 0{ if len(devices) == 0 {
panic(fmt.Errorf("Found no devices associated with user_id=%s, can't save history entry!", entry.UserId)) panic(fmt.Errorf("Found no devices associated with user_id=%s, can't save history entry!", entry.UserId))
} }
for _, device := range devices { for _, device := range devices {
entry.DeviceId = device.DeviceId; entry.DeviceId = device.DeviceId
GLOBAL_DB.Create(&entry) GLOBAL_DB.Create(&entry)
} }
} }
@ -56,7 +56,7 @@ func apiESubmitHandler(w http.ResponseWriter, r *http.Request) {
func apiEQueryHandler(w http.ResponseWriter, r *http.Request) { func apiEQueryHandler(w http.ResponseWriter, r *http.Request) {
deviceId := r.URL.Query().Get("device_id") deviceId := r.URL.Query().Get("device_id")
// Increment the count // Increment the count
GLOBAL_DB.Exec("UPDATE enc_history_entries SET read_count = read_count + 1 WHERE device_id = ?", deviceId) GLOBAL_DB.Exec("UPDATE enc_history_entries SET read_count = read_count + 1 WHERE device_id = ?", deviceId)
// Then retrieve, to avoid a race condition // Then retrieve, to avoid a race condition

View File

@ -158,7 +158,7 @@ func DisplayResults(results []*HistoryEntry, displayHostname bool) {
type ClientConfig struct { type ClientConfig struct {
UserSecret string `json:"user_secret"` UserSecret string `json:"user_secret"`
IsEnabled bool `json:"is_enabled"` IsEnabled bool `json:"is_enabled"`
DeviceId int `json:"device_id"` DeviceId int `json:"device_id"`
} }
func GetConfig() (ClientConfig, error) { func GetConfig() (ClientConfig, error) {

View File

@ -39,12 +39,12 @@ type EncHistoryEntry struct {
DeviceId string `json:"device_id"` DeviceId string `json:"device_id"`
UserId string `json:"user_id"` UserId string `json:"user_id"`
Date time.Time `json:"time"` Date time.Time `json:"time"`
EncryptedId string `json:"id"` EncryptedId string `json:"id"`
ReadCount int `json:"read_count"` ReadCount int `json:"read_count"`
} }
type Device struct { type Device struct {
UserId string `json:"user_id"` UserId string `json:"user_id"`
DeviceId string `json:"device_id"` DeviceId string `json:"device_id"`
} }
@ -57,7 +57,7 @@ type Device struct {
// } // }
const ( const (
HISHTORY_PATH = ".hishtory" HISHTORY_PATH = ".hishtory"
DB_PATH = ".hishtory.db" DB_PATH = ".hishtory.db"
KDF_USER_ID = "user_id" KDF_USER_ID = "user_id"
KDF_DEVICE_ID = "device_id" KDF_DEVICE_ID = "device_id"
@ -145,8 +145,8 @@ func EncryptHistoryEntry(userSecret string, entry HistoryEntry) (EncHistoryEntry
Nonce: nonce, Nonce: nonce,
UserId: UserId(userSecret), UserId: UserId(userSecret),
Date: time.Now(), Date: time.Now(),
EncryptedId: uuid.Must(uuid.NewRandom()).String(), EncryptedId: uuid.Must(uuid.NewRandom()).String(),
ReadCount: 0, ReadCount: 0,
}, nil }, nil
} }
@ -182,15 +182,15 @@ func OpenLocalSqliteDb() (*gorm.DB, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create ~/.hishtory dir: %v", err) return nil, fmt.Errorf("failed to create ~/.hishtory dir: %v", err)
} }
db, err := gorm.Open(sqlite.Open(path.Join(homedir, HISHTORY_PATH, DB_PATH)), &gorm.Config{SkipDefaultTransaction: true,}) db, err := gorm.Open(sqlite.Open(path.Join(homedir, HISHTORY_PATH, DB_PATH)), &gorm.Config{SkipDefaultTransaction: true})
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to connect to the DB: %v", err) return nil, fmt.Errorf("failed to connect to the DB: %v", err)
} }
tx, err := db.DB() tx, err := db.DB()
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = tx.Ping() err = tx.Ping()
if err != nil { if err != nil {
return nil, err return nil, err
} }