mirror of
https://github.com/netbirdio/netbird.git
synced 2025-02-23 13:41:19 +01:00
Add support for downloading Geo databases to the management service (#1626)
Adds support for downloading Geo databases to the management service. If the Geo databases are not found, the service will automatically attempt to download them during startup.
This commit is contained in:
parent
d8ce08d898
commit
b65c2f69b0
@ -199,6 +199,6 @@ jobs:
|
||||
- name: test script
|
||||
run: bash -x infrastructure_files/download-geolite2.sh
|
||||
- name: test mmdb file exists
|
||||
run: ls -l GeoLite2-City_*/GeoLite2-City.mmdb
|
||||
run: test -f GeoLite2-City.mmdb
|
||||
- name: test geonames file exists
|
||||
run: test -f geonames.db
|
||||
|
@ -43,21 +43,18 @@ download_geolite_mmdb() {
|
||||
mkdir -p "$EXTRACTION_DIR"
|
||||
tar -xzvf "$DATABASE_FILE" > /dev/null 2>&1
|
||||
|
||||
# Create a SHA256 signature file
|
||||
MMDB_FILE="GeoLite2-City.mmdb"
|
||||
cd "$EXTRACTION_DIR"
|
||||
sha256sum "$MMDB_FILE" > "$MMDB_FILE.sha256"
|
||||
echo "SHA256 signature created for $MMDB_FILE."
|
||||
cd - > /dev/null 2>&1
|
||||
cp "$EXTRACTION_DIR"/"$MMDB_FILE" $MMDB_FILE
|
||||
|
||||
# Remove downloaded files
|
||||
rm -r "$EXTRACTION_DIR"
|
||||
rm "$DATABASE_FILE" "$SIGNATURE_FILE"
|
||||
|
||||
# Done. Print next steps
|
||||
echo ""
|
||||
echo "Process completed successfully."
|
||||
echo "Now you can place $EXTRACTION_DIR/$MMDB_FILE to 'datadir' of management service."
|
||||
echo -e "Example:\n\tdocker compose cp $EXTRACTION_DIR/$MMDB_FILE management:/var/lib/netbird/"
|
||||
echo "Now you can place $MMDB_FILE to 'datadir' of management service."
|
||||
echo -e "Example:\n\tdocker compose cp $MMDB_FILE management:/var/lib/netbird/"
|
||||
}
|
||||
|
||||
|
||||
|
@ -166,7 +166,7 @@ var (
|
||||
|
||||
geo, err := geolocation.NewGeolocation(config.Datadir)
|
||||
if err != nil {
|
||||
log.Warnf("could not initialize geo location service, we proceed without geo support")
|
||||
log.Warnf("could not initialize geo location service: %v, we proceed without geo support", err)
|
||||
} else {
|
||||
log.Infof("geo location service has been initialized from %s", config.Datadir)
|
||||
}
|
||||
|
187
management/server/geolocation/database.go
Normal file
187
management/server/geolocation/database.go
Normal file
@ -0,0 +1,187 @@
|
||||
package geolocation
|
||||
|
||||
import (
|
||||
"encoding/csv"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
geoLiteCityTarGZURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz"
|
||||
geoLiteCityZipURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City-CSV/download?suffix=zip"
|
||||
geoLiteCitySha256TarURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz.sha256"
|
||||
geoLiteCitySha256ZipURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City-CSV/download?suffix=zip.sha256"
|
||||
)
|
||||
|
||||
// loadGeolocationDatabases loads the MaxMind databases.
|
||||
func loadGeolocationDatabases(dataDir string) error {
|
||||
files := []string{MMDBFileName, GeoSqliteDBFile}
|
||||
for _, file := range files {
|
||||
exists, _ := fileExists(path.Join(dataDir, file))
|
||||
if exists {
|
||||
continue
|
||||
}
|
||||
|
||||
switch file {
|
||||
case MMDBFileName:
|
||||
extractFunc := func(src string, dst string) error {
|
||||
if err := decompressTarGzFile(src, dst); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.Rename(path.Join(dst, MMDBFileName), path.Join(dataDir, MMDBFileName))
|
||||
}
|
||||
if err := loadDatabase(
|
||||
geoLiteCitySha256TarURL,
|
||||
geoLiteCityTarGZURL,
|
||||
extractFunc,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case GeoSqliteDBFile:
|
||||
extractFunc := func(src string, dst string) error {
|
||||
if err := decompressZipFile(src, dst); err != nil {
|
||||
return err
|
||||
}
|
||||
extractedCsvFile := path.Join(dst, "GeoLite2-City-Locations-en.csv")
|
||||
return importCsvToSqlite(dataDir, extractedCsvFile)
|
||||
}
|
||||
|
||||
if err := loadDatabase(
|
||||
geoLiteCitySha256ZipURL,
|
||||
geoLiteCityZipURL,
|
||||
extractFunc,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadDatabase downloads a file from the specified URL and verifies its checksum.
|
||||
// It then calls the extract function to perform additional processing on the extracted files.
|
||||
func loadDatabase(checksumURL string, fileURL string, extractFunc func(src string, dst string) error) error {
|
||||
temp, err := os.MkdirTemp(os.TempDir(), "geolite")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer os.RemoveAll(temp)
|
||||
|
||||
checksumFile := path.Join(temp, getDatabaseFileName(checksumURL))
|
||||
err = downloadFile(checksumURL, checksumFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sha256sum, err := loadChecksumFromFile(checksumFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dbFile := path.Join(temp, getDatabaseFileName(fileURL))
|
||||
err = downloadFile(fileURL, dbFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := verifyChecksum(dbFile, sha256sum); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return extractFunc(dbFile, temp)
|
||||
}
|
||||
|
||||
// importCsvToSqlite imports a CSV file into a SQLite database.
|
||||
func importCsvToSqlite(dataDir string, csvFile string) error {
|
||||
geonames, err := loadGeonamesCsv(csvFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
db, err := gorm.Open(sqlite.Open(path.Join(dataDir, GeoSqliteDBFile)), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
CreateBatchSize: 1000,
|
||||
PrepareStmt: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
sql, err := db.DB()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
sql.Close()
|
||||
}()
|
||||
|
||||
if err := db.AutoMigrate(&GeoNames{}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return db.Create(geonames).Error
|
||||
}
|
||||
|
||||
func loadGeonamesCsv(filepath string) ([]GeoNames, error) {
|
||||
f, err := os.Open(filepath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
reader := csv.NewReader(f)
|
||||
records, err := reader.ReadAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var geoNames []GeoNames
|
||||
for index, record := range records {
|
||||
if index == 0 {
|
||||
continue
|
||||
}
|
||||
geoNameID, err := strconv.Atoi(record[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
geoName := GeoNames{
|
||||
GeoNameID: geoNameID,
|
||||
LocaleCode: record[1],
|
||||
ContinentCode: record[2],
|
||||
ContinentName: record[3],
|
||||
CountryIsoCode: record[4],
|
||||
CountryName: record[5],
|
||||
Subdivision1IsoCode: record[6],
|
||||
Subdivision1Name: record[7],
|
||||
Subdivision2IsoCode: record[8],
|
||||
Subdivision2Name: record[9],
|
||||
CityName: record[10],
|
||||
MetroCode: record[11],
|
||||
TimeZone: record[12],
|
||||
IsInEuropeanUnion: record[13],
|
||||
}
|
||||
geoNames = append(geoNames, geoName)
|
||||
}
|
||||
|
||||
return geoNames, nil
|
||||
}
|
||||
|
||||
// getDatabaseFileName extracts the file name from a given URL string.
|
||||
func getDatabaseFileName(urlStr string) string {
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ext := u.Query().Get("suffix")
|
||||
fileName := fmt.Sprintf("%s.%s", path.Base(u.Path), ext)
|
||||
return fileName
|
||||
}
|
@ -2,9 +2,7 @@ package geolocation
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path"
|
||||
@ -54,20 +52,23 @@ type Country struct {
|
||||
CountryName string
|
||||
}
|
||||
|
||||
func NewGeolocation(datadir string) (*Geolocation, error) {
|
||||
mmdbPath := path.Join(datadir, MMDBFileName)
|
||||
func NewGeolocation(dataDir string) (*Geolocation, error) {
|
||||
if err := loadGeolocationDatabases(dataDir); err != nil {
|
||||
return nil, fmt.Errorf("failed to load MaxMind databases: %v", err)
|
||||
}
|
||||
|
||||
mmdbPath := path.Join(dataDir, MMDBFileName)
|
||||
db, err := openDB(mmdbPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sha256sum, err := getSha256sum(mmdbPath)
|
||||
sha256sum, err := calculateFileSHA256(mmdbPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
locationDB, err := NewSqliteStore(datadir)
|
||||
locationDB, err := NewSqliteStore(dataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -104,21 +105,6 @@ func openDB(mmdbPath string) (*maxminddb.Reader, error) {
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func getSha256sum(mmdbPath string) ([]byte, error) {
|
||||
f, err := os.Open(mmdbPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
h := sha256.New()
|
||||
if _, err := io.Copy(h, f); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return h.Sum(nil), nil
|
||||
}
|
||||
|
||||
func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) {
|
||||
gl.mux.RLock()
|
||||
defer gl.mux.RUnlock()
|
||||
@ -189,7 +175,7 @@ func (gl *Geolocation) reloader() {
|
||||
log.Errorf("geonames db reload failed: %s", err)
|
||||
}
|
||||
|
||||
newSha256sum1, err := getSha256sum(gl.mmdbPath)
|
||||
newSha256sum1, err := calculateFileSHA256(gl.mmdbPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
|
||||
continue
|
||||
@ -198,7 +184,7 @@ func (gl *Geolocation) reloader() {
|
||||
// we check sum twice just to avoid possible case when we reload during update of the file
|
||||
// considering the frequency of file update (few times a week) checking sum twice should be enough
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
newSha256sum2, err := getSha256sum(gl.mmdbPath)
|
||||
newSha256sum2, err := calculateFileSHA256(gl.mmdbPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
|
||||
continue
|
||||
|
@ -20,6 +20,27 @@ const (
|
||||
GeoSqliteDBFile = "geonames.db"
|
||||
)
|
||||
|
||||
type GeoNames struct {
|
||||
GeoNameID int `gorm:"column:geoname_id"`
|
||||
LocaleCode string `gorm:"column:locale_code"`
|
||||
ContinentCode string `gorm:"column:continent_code"`
|
||||
ContinentName string `gorm:"column:continent_name"`
|
||||
CountryIsoCode string `gorm:"column:country_iso_code"`
|
||||
CountryName string `gorm:"column:country_name"`
|
||||
Subdivision1IsoCode string `gorm:"column:subdivision_1_iso_code"`
|
||||
Subdivision1Name string `gorm:"column:subdivision_1_name"`
|
||||
Subdivision2IsoCode string `gorm:"column:subdivision_2_iso_code"`
|
||||
Subdivision2Name string `gorm:"column:subdivision_2_name"`
|
||||
CityName string `gorm:"column:city_name"`
|
||||
MetroCode string `gorm:"column:metro_code"`
|
||||
TimeZone string `gorm:"column:time_zone"`
|
||||
IsInEuropeanUnion string `gorm:"column:is_in_european_union"`
|
||||
}
|
||||
|
||||
func (*GeoNames) TableName() string {
|
||||
return "geonames"
|
||||
}
|
||||
|
||||
// SqliteStore represents a location storage backed by a Sqlite DB.
|
||||
type SqliteStore struct {
|
||||
db *gorm.DB
|
||||
@ -37,7 +58,7 @@ func NewSqliteStore(dataDir string) (*SqliteStore, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sha256sum, err := getSha256sum(file)
|
||||
sha256sum, err := calculateFileSHA256(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -60,7 +81,7 @@ func (s *SqliteStore) GetAllCountries() ([]Country, error) {
|
||||
}
|
||||
|
||||
var countries []Country
|
||||
result := s.db.Table("geonames").
|
||||
result := s.db.Model(&GeoNames{}).
|
||||
Select("country_iso_code", "country_name").
|
||||
Group("country_name").
|
||||
Scan(&countries)
|
||||
@ -81,7 +102,7 @@ func (s *SqliteStore) GetCitiesByCountry(countryISOCode string) ([]City, error)
|
||||
}
|
||||
|
||||
var cities []City
|
||||
result := s.db.Table("geonames").
|
||||
result := s.db.Model(&GeoNames{}).
|
||||
Select("geoname_id", "city_name").
|
||||
Where("country_iso_code = ?", countryISOCode).
|
||||
Group("city_name").
|
||||
@ -98,7 +119,7 @@ func (s *SqliteStore) reload() error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
newSha256sum1, err := getSha256sum(s.filePath)
|
||||
newSha256sum1, err := calculateFileSHA256(s.filePath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err)
|
||||
}
|
||||
@ -107,7 +128,7 @@ func (s *SqliteStore) reload() error {
|
||||
// we check sum twice just to avoid possible case when we reload during update of the file
|
||||
// considering the frequency of file update (few times a week) checking sum twice should be enough
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
newSha256sum2, err := getSha256sum(s.filePath)
|
||||
newSha256sum2, err := calculateFileSHA256(s.filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err)
|
||||
}
|
||||
|
176
management/server/geolocation/utils.go
Normal file
176
management/server/geolocation/utils.go
Normal file
@ -0,0 +1,176 @@
|
||||
package geolocation
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"archive/zip"
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// decompressTarGzFile decompresses a .tar.gz file.
|
||||
func decompressTarGzFile(filepath, destDir string) error {
|
||||
file, err := os.Open(filepath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
gzipReader, err := gzip.NewReader(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer gzipReader.Close()
|
||||
|
||||
tarReader := tar.NewReader(gzipReader)
|
||||
|
||||
for {
|
||||
header, err := tarReader.Next()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if header.Typeflag == tar.TypeReg {
|
||||
outFile, err := os.Create(path.Join(destDir, path.Base(header.Name)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = io.Copy(outFile, tarReader) // #nosec G110
|
||||
outFile.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// decompressZipFile decompresses a .zip file.
|
||||
func decompressZipFile(filepath, destDir string) error {
|
||||
r, err := zip.OpenReader(filepath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
for _, f := range r.File {
|
||||
if f.FileInfo().IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
outFile, err := os.Create(path.Join(destDir, path.Base(f.Name)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rc, err := f.Open()
|
||||
if err != nil {
|
||||
outFile.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = io.Copy(outFile, rc) // #nosec G110
|
||||
outFile.Close()
|
||||
rc.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateFileSHA256 calculates the SHA256 checksum of a file.
|
||||
func calculateFileSHA256(filepath string) ([]byte, error) {
|
||||
file, err := os.Open(filepath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
h := sha256.New()
|
||||
if _, err := io.Copy(h, file); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return h.Sum(nil), nil
|
||||
}
|
||||
|
||||
// loadChecksumFromFile loads the first checksum from a file.
|
||||
func loadChecksumFromFile(filepath string) (string, error) {
|
||||
file, err := os.Open(filepath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
if scanner.Scan() {
|
||||
parts := strings.Fields(scanner.Text())
|
||||
if len(parts) > 0 {
|
||||
return parts[0], nil
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// verifyChecksum compares the calculated SHA256 checksum of a file against the expected checksum.
|
||||
func verifyChecksum(filepath, expectedChecksum string) error {
|
||||
calculatedChecksum, err := calculateFileSHA256(filepath)
|
||||
|
||||
fileCheckSum := fmt.Sprintf("%x", calculatedChecksum)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if fileCheckSum != expectedChecksum {
|
||||
return fmt.Errorf("checksum mismatch: expected %s, got %s", expectedChecksum, fileCheckSum)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// downloadFile downloads a file from a URL and saves it to a local file path.
|
||||
func downloadFile(url, filepath string) error {
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("unexpected error occurred while downloading the file: %s", string(bodyBytes))
|
||||
}
|
||||
|
||||
out, err := os.Create(filepath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
_, err = io.Copy(out, bytes.NewBuffer(bodyBytes))
|
||||
return err
|
||||
}
|
Loading…
Reference in New Issue
Block a user