mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-15 19:31:06 +01:00
765aba2c1c
propagate context from all the API calls and log request ID, account ID and peer ID --------- Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
243 lines
5.6 KiB
Go
243 lines
5.6 KiB
Go
package geolocation
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"path"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/oschwald/maxminddb-golang"
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// MMDBFileName is the file name of the MaxMind GeoLite2 City database that
// NewGeolocation opens inside the data directory.
const MMDBFileName = "GeoLite2-City.mmdb"
|
|
|
|
type Geolocation struct {
|
|
mmdbPath string
|
|
mux sync.RWMutex
|
|
sha256sum []byte
|
|
db *maxminddb.Reader
|
|
locationDB *SqliteStore
|
|
stopCh chan struct{}
|
|
reloadCheckInterval time.Duration
|
|
}
|
|
|
|
// Record is the part of a MaxMind database entry that Lookup decodes. The
// struct tags name the database keys each field is read from.
type Record struct {
	// City identifies the city and carries its English name.
	City struct {
		GeonameID uint `maxminddb:"geoname_id"`
		Names     struct {
			En string `maxminddb:"en"`
		} `maxminddb:"names"`
	} `maxminddb:"city"`

	// Continent identifies the continent by geoname ID and code.
	Continent struct {
		GeonameID uint   `maxminddb:"geoname_id"`
		Code      string `maxminddb:"code"`
	} `maxminddb:"continent"`

	// Country identifies the country by geoname ID and ISO code.
	Country struct {
		GeonameID uint   `maxminddb:"geoname_id"`
		ISOCode   string `maxminddb:"iso_code"`
	} `maxminddb:"country"`
}
|
|
|
|
// City is one city entry as returned by GetCitiesByCountry.
type City struct {
	// GeoNameID is mapped to the "geoname_id" column.
	GeoNameID int `gorm:"column:geoname_id"`
	// CityName is the city's name; entries with an empty name are filtered
	// out by GetCitiesByCountry.
	CityName string
}
|
|
|
|
// Country is one country entry as returned by GetAllCountries.
type Country struct {
	// CountryISOCode is mapped to the "country_iso_code" column.
	CountryISOCode string `gorm:"column:country_iso_code"`
	// CountryName is the country's name; entries with an empty name are
	// filtered out by GetAllCountries.
	CountryName string
}
|
|
|
|
func NewGeolocation(ctx context.Context, dataDir string) (*Geolocation, error) {
|
|
if err := loadGeolocationDatabases(dataDir); err != nil {
|
|
return nil, fmt.Errorf("failed to load MaxMind databases: %v", err)
|
|
}
|
|
|
|
mmdbPath := path.Join(dataDir, MMDBFileName)
|
|
db, err := openDB(mmdbPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sha256sum, err := calculateFileSHA256(mmdbPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
locationDB, err := NewSqliteStore(ctx, dataDir)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
geo := &Geolocation{
|
|
mmdbPath: mmdbPath,
|
|
mux: sync.RWMutex{},
|
|
sha256sum: sha256sum,
|
|
db: db,
|
|
locationDB: locationDB,
|
|
reloadCheckInterval: 300 * time.Second, // TODO: make configurable
|
|
stopCh: make(chan struct{}),
|
|
}
|
|
|
|
go geo.reloader(ctx)
|
|
|
|
return geo, nil
|
|
}
|
|
|
|
func openDB(mmdbPath string) (*maxminddb.Reader, error) {
|
|
_, err := os.Stat(mmdbPath)
|
|
|
|
if os.IsNotExist(err) {
|
|
return nil, fmt.Errorf("%v does not exist", mmdbPath)
|
|
} else if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
db, err := maxminddb.Open(mmdbPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%v could not be opened: %w", mmdbPath, err)
|
|
}
|
|
|
|
return db, nil
|
|
}
|
|
|
|
func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) {
|
|
gl.mux.RLock()
|
|
defer gl.mux.RUnlock()
|
|
|
|
var record Record
|
|
err := gl.db.Lookup(ip, &record)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &record, nil
|
|
}
|
|
|
|
// GetAllCountries retrieves a list of all countries.
|
|
func (gl *Geolocation) GetAllCountries() ([]Country, error) {
|
|
allCountries, err := gl.locationDB.GetAllCountries()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
countries := make([]Country, 0)
|
|
for _, country := range allCountries {
|
|
if country.CountryName != "" {
|
|
countries = append(countries, country)
|
|
}
|
|
}
|
|
return countries, nil
|
|
}
|
|
|
|
// GetCitiesByCountry retrieves a list of cities in a specific country based on the country's ISO code.
|
|
func (gl *Geolocation) GetCitiesByCountry(countryISOCode string) ([]City, error) {
|
|
allCities, err := gl.locationDB.GetCitiesByCountry(countryISOCode)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
cities := make([]City, 0)
|
|
for _, city := range allCities {
|
|
if city.CityName != "" {
|
|
cities = append(cities, city)
|
|
}
|
|
}
|
|
return cities, nil
|
|
}
|
|
|
|
func (gl *Geolocation) Stop() error {
|
|
close(gl.stopCh)
|
|
if gl.db != nil {
|
|
if err := gl.db.Close(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if gl.locationDB != nil {
|
|
if err := gl.locationDB.close(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (gl *Geolocation) reloader(ctx context.Context) {
|
|
for {
|
|
select {
|
|
case <-gl.stopCh:
|
|
return
|
|
case <-time.After(gl.reloadCheckInterval):
|
|
if err := gl.locationDB.reload(ctx); err != nil {
|
|
log.WithContext(ctx).Errorf("geonames db reload failed: %s", err)
|
|
}
|
|
|
|
newSha256sum1, err := calculateFileSHA256(gl.mmdbPath)
|
|
if err != nil {
|
|
log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
|
|
continue
|
|
}
|
|
if !bytes.Equal(gl.sha256sum, newSha256sum1) {
|
|
// we check sum twice just to avoid possible case when we reload during update of the file
|
|
// considering the frequency of file update (few times a week) checking sum twice should be enough
|
|
time.Sleep(50 * time.Millisecond)
|
|
newSha256sum2, err := calculateFileSHA256(gl.mmdbPath)
|
|
if err != nil {
|
|
log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
|
|
continue
|
|
}
|
|
if !bytes.Equal(newSha256sum1, newSha256sum2) {
|
|
log.WithContext(ctx).Errorf("sha256 sum changed during reloading of '%s'", gl.mmdbPath)
|
|
continue
|
|
}
|
|
err = gl.reload(ctx, newSha256sum2)
|
|
if err != nil {
|
|
log.WithContext(ctx).Errorf("mmdb reload failed: %s", err)
|
|
}
|
|
} else {
|
|
log.WithContext(ctx).Tracef("No changes in '%s', no need to reload. Next check is in %.0f seconds.",
|
|
gl.mmdbPath, gl.reloadCheckInterval.Seconds())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (gl *Geolocation) reload(ctx context.Context, newSha256sum []byte) error {
|
|
gl.mux.Lock()
|
|
defer gl.mux.Unlock()
|
|
|
|
log.WithContext(ctx).Infof("Reloading '%s'", gl.mmdbPath)
|
|
|
|
err := gl.db.Close()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
db, err := openDB(gl.mmdbPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
gl.db = db
|
|
gl.sha256sum = newSha256sum
|
|
|
|
log.WithContext(ctx).Infof("Successfully reloaded '%s'", gl.mmdbPath)
|
|
|
|
return nil
|
|
}
|
|
|
|
// fileExists reports whether filePath can be stat'ed.
//
// It returns (true, nil) when the path exists. A missing path yields false
// together with a "does not exist" error, and any other stat failure yields
// false together with that failure.
func fileExists(filePath string) (bool, error) {
	_, statErr := os.Stat(filePath)

	switch {
	case statErr == nil:
		return true, nil
	case os.IsNotExist(statErr):
		return false, fmt.Errorf("%v does not exist", filePath)
	default:
		return false, statErr
	}
}
|