Get RemoteAddr from X-Real-Ip header

This commit is contained in:
David Dworken 2022-09-29 23:51:45 -07:00
parent f0a3caed1c
commit 98a4f002fa

View File

@ -38,6 +38,7 @@ type UsageData struct {
UserId string `json:"user_id" gorm:"not null; uniqueIndex:usageDataUniqueIndex"` UserId string `json:"user_id" gorm:"not null; uniqueIndex:usageDataUniqueIndex"`
DeviceId string `json:"device_id" gorm:"not null; uniqueIndex:usageDataUniqueIndex"` DeviceId string `json:"device_id" gorm:"not null; uniqueIndex:usageDataUniqueIndex"`
LastUsed time.Time `json:"last_used"` LastUsed time.Time `json:"last_used"`
LastIp string `json:"last_ip"`
NumEntriesHandled int `json:"num_entries_handled"` NumEntriesHandled int `json:"num_entries_handled"`
} }
@ -49,13 +50,13 @@ func getRequiredQueryParam(r *http.Request, queryParam string) string {
return val return val
} }
func updateUsageData(userId, deviceId string, numEntries int) { func updateUsageData(r *http.Request, userId, deviceId string, numEntries int) {
var usageData []UsageData var usageData []UsageData
GLOBAL_DB.Where("user_id = ? AND device_id = ?", userId, deviceId).Find(&usageData) GLOBAL_DB.Where("user_id = ? AND device_id = ?", userId, deviceId).Find(&usageData)
if len(usageData) == 0 { if len(usageData) == 0 {
GLOBAL_DB.Create(&UsageData{UserId: userId, DeviceId: deviceId, LastUsed: time.Now(), NumEntriesHandled: numEntries}) GLOBAL_DB.Create(&UsageData{UserId: userId, DeviceId: deviceId, LastUsed: time.Now(), NumEntriesHandled: numEntries})
} else { } else {
GLOBAL_DB.Model(&UsageData{}).Where("user_id = ? AND device_id = ?", userId, deviceId).Update("last_used", time.Now()) GLOBAL_DB.Model(&UsageData{}).Where("user_id = ? AND device_id = ?", userId, deviceId).Update("last_used", time.Now()).Update("last_ip", getRemoteAddr(r))
if numEntries > 0 { if numEntries > 0 {
GLOBAL_DB.Exec("UPDATE usage_data SET num_entries_handled = COALESCE(num_entries_handled, 0) + ? WHERE user_id = ? AND device_id = ?", numEntries, userId, deviceId) GLOBAL_DB.Exec("UPDATE usage_data SET num_entries_handled = COALESCE(num_entries_handled, 0) + ? WHERE user_id = ? AND device_id = ?", numEntries, userId, deviceId)
} }
@ -69,7 +70,8 @@ func usageStatsHandler(w http.ResponseWriter, r *http.Request) {
MIN(devices.registration_date) as registration_date, MIN(devices.registration_date) as registration_date,
COUNT(DISTINCT devices.device_id) as num_devices, COUNT(DISTINCT devices.device_id) as num_devices,
SUM(usage_data.num_entries_handled) as num_history_entries, SUM(usage_data.num_entries_handled) as num_history_entries,
MAX(usage_data.last_used) as last_active MAX(usage_data.last_used) as last_active,
COALESCE(STRING_AGG(DISTINCT usage_data.last_ip, ' ') FILTER (WHERE usage_data.last_ip != 'Unknown'), 'Unknown') as ip_addresses
FROM devices FROM devices
INNER JOIN usage_data ON devices.device_id = usage_data.device_id INNER JOIN usage_data ON devices.device_id = usage_data.device_id
GROUP BY devices.user_id GROUP BY devices.user_id
@ -84,11 +86,12 @@ func usageStatsHandler(w http.ResponseWriter, r *http.Request) {
var numDevices int var numDevices int
var numEntries int var numEntries int
var lastUsedDate time.Time var lastUsedDate time.Time
err = rows.Scan(&registrationDate, &numDevices, &numEntries, &lastUsedDate) var ipAddresses string
err = rows.Scan(&registrationDate, &numDevices, &numEntries, &lastUsedDate, &ipAddresses)
if err != nil { if err != nil {
panic(err) panic(err)
} }
w.Write([]byte(fmt.Sprintf("Registered: %s\tNumDevices: %d\tNumEntries: %d\tLastUsed: %s\n", registrationDate.Format("2006-01-02"), numDevices, numEntries, lastUsedDate.Format("2006-01-02")))) w.Write([]byte(fmt.Sprintf("Registered: %s\tNumDevices: %d\tNumEntries: %d\tLastUsed: %s\tIP: %s\n", registrationDate.Format("2006-01-02"), numDevices, numEntries, lastUsedDate.Format("2006-01-02"), ipAddresses)))
} }
} }
@ -104,7 +107,7 @@ func apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
} }
fmt.Printf("apiSubmitHandler: received request containg %d EncHistoryEntry\n", len(entries)) fmt.Printf("apiSubmitHandler: received request containg %d EncHistoryEntry\n", len(entries))
for _, entry := range entries { for _, entry := range entries {
updateUsageData(entry.UserId, entry.DeviceId, 1) updateUsageData(r, entry.UserId, entry.DeviceId, 1)
tx := GLOBAL_DB.Where("user_id = ?", entry.UserId) tx := GLOBAL_DB.Where("user_id = ?", entry.UserId)
var devices []*shared.Device var devices []*shared.Device
result := tx.Find(&devices) result := tx.Find(&devices)
@ -128,7 +131,7 @@ func apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
func apiBootstrapHandler(w http.ResponseWriter, r *http.Request) { func apiBootstrapHandler(w http.ResponseWriter, r *http.Request) {
userId := getRequiredQueryParam(r, "user_id") userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id") deviceId := getRequiredQueryParam(r, "device_id")
updateUsageData(userId, deviceId, 0) updateUsageData(r, userId, deviceId, 0)
tx := GLOBAL_DB.Where("user_id = ?", userId) tx := GLOBAL_DB.Where("user_id = ?", userId)
var historyEntries []*shared.EncHistoryEntry var historyEntries []*shared.EncHistoryEntry
result := tx.Find(&historyEntries) result := tx.Find(&historyEntries)
@ -145,7 +148,7 @@ func apiBootstrapHandler(w http.ResponseWriter, r *http.Request) {
func apiQueryHandler(w http.ResponseWriter, r *http.Request) { func apiQueryHandler(w http.ResponseWriter, r *http.Request) {
userId := getRequiredQueryParam(r, "user_id") userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id") deviceId := getRequiredQueryParam(r, "device_id")
updateUsageData(userId, deviceId, 0) updateUsageData(r, userId, deviceId, 0)
// Increment the count // Increment the count
result := GLOBAL_DB.Exec("UPDATE enc_history_entries SET read_count = read_count + 1 WHERE device_id = ?", deviceId) result := GLOBAL_DB.Exec("UPDATE enc_history_entries SET read_count = read_count + 1 WHERE device_id = ?", deviceId)
if result.Error != nil { if result.Error != nil {
@ -177,6 +180,14 @@ func apiQueryHandler(w http.ResponseWriter, r *http.Request) {
w.Write(resp) w.Write(resp)
} }
func getRemoteAddr(r *http.Request) string {
addr, ok := r.Header["X-Real-Ip"]
if !ok || len(addr) == 0 {
return "Unknown"
}
return addr[0]
}
func apiRegisterHandler(w http.ResponseWriter, r *http.Request) { func apiRegisterHandler(w http.ResponseWriter, r *http.Request) {
userId := getRequiredQueryParam(r, "user_id") userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id") deviceId := getRequiredQueryParam(r, "device_id")
@ -186,12 +197,11 @@ func apiRegisterHandler(w http.ResponseWriter, r *http.Request) {
if result.Error != nil { if result.Error != nil {
panic(result.Error) panic(result.Error)
} }
// TODO: r.RemoteAddr isn't using the proxy protocol and isn't the actual device IP GLOBAL_DB.Create(&shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: getRemoteAddr(r), RegistrationDate: time.Now()})
GLOBAL_DB.Create(&shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: r.RemoteAddr, RegistrationDate: time.Now()})
if existingDevicesCount > 0 { if existingDevicesCount > 0 {
GLOBAL_DB.Create(&shared.DumpRequest{UserId: userId, RequestingDeviceId: deviceId, RequestTime: time.Now()}) GLOBAL_DB.Create(&shared.DumpRequest{UserId: userId, RequestingDeviceId: deviceId, RequestTime: time.Now()})
} }
updateUsageData(userId, deviceId, 0) updateUsageData(r, userId, deviceId, 0)
} }
func apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) { func apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) {
@ -244,7 +254,7 @@ func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) {
if result.Error != nil { if result.Error != nil {
panic(fmt.Errorf("failed to clear the dump request: %v", err)) panic(fmt.Errorf("failed to clear the dump request: %v", err))
} }
updateUsageData(userId, srcDeviceId, len(entries)) updateUsageData(r, userId, srcDeviceId, len(entries))
} }
func apiBannerHandler(w http.ResponseWriter, r *http.Request) { func apiBannerHandler(w http.ResponseWriter, r *http.Request) {