Mirror of https://github.com/netbirdio/netbird.git
Synced 2025-08-17 10:31:45 +02:00
184 lines · 4.9 KiB · Go
package updatemanager
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
v "github.com/hashicorp/go-version"
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
"github.com/netbirdio/netbird/client/internal/peer"
|
|
cProto "github.com/netbirdio/netbird/client/proto"
|
|
"github.com/netbirdio/netbird/version"
|
|
)
|
|
|
|
// Sentinel values used for the auto-update target and the latest known version.
const (
	// disableAutoUpdate is the target version value that turns auto-update off;
	// it is also the initial target of a new UpdateManager.
	disableAutoUpdate = "disabled"

	// latestVersion as a target means "follow the newest version reported to
	// Updated" (an empty target is treated the same way).
	latestVersion = "latest"

	// unknownVersion marks that no latest version has been recorded yet.
	unknownVersion = "Unknown"
)
|
|
|
|
type UpdateManager struct {
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
version string
|
|
latestVersion string
|
|
update *version.Update
|
|
lastTrigger time.Time
|
|
statusRecorder *peer.Status
|
|
mutex sync.Mutex
|
|
waitGroup sync.WaitGroup
|
|
}
|
|
|
|
func NewUpdateManager(ctx context.Context, statusRecorder *peer.Status) *UpdateManager {
|
|
update := version.NewUpdate("nb/client")
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
manager := &UpdateManager{
|
|
update: update,
|
|
lastTrigger: time.Now().Add(-10 * time.Minute),
|
|
statusRecorder: statusRecorder,
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
version: disableAutoUpdate,
|
|
latestVersion: unknownVersion,
|
|
}
|
|
update.SetDaemonVersion(version.NetbirdVersion())
|
|
update.SetOnUpdateListener(manager.Updated)
|
|
return manager
|
|
}
|
|
|
|
func (u *UpdateManager) SetVersion(v string) {
|
|
u.mutex.Lock()
|
|
if u.version != v {
|
|
u.mutex.Unlock()
|
|
log.Tracef("Auto-update version set to %s", v)
|
|
u.version = v
|
|
go u.Updated("N/A")
|
|
} else {
|
|
u.mutex.Unlock()
|
|
}
|
|
}
|
|
|
|
// Stop shuts the manager down. It calls StopWatch on the version watcher and
// cancels the manager context, which is the parent of the context Updated
// hands to CheckForUpdates. It then blocks until the Updated calls counted in
// waitGroup have returned.
//
// NOTE(review): Updated calls waitGroup.Add itself, so an Updated call that
// starts concurrently with Wait may not be waited for. Once the context is
// cancelled, such a call returns early without checking for updates.
func (u *UpdateManager) Stop() {
	u.update.StopWatch()
	u.cancel()
	u.waitGroup.Wait()
}
|
|
|
|
func (u *UpdateManager) Updated(latestVersion string) {
|
|
u.waitGroup.Add(1)
|
|
defer u.waitGroup.Done()
|
|
u.mutex.Lock()
|
|
defer u.mutex.Unlock()
|
|
select {
|
|
case <-u.ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
if latestVersion != "N/A" {
|
|
u.latestVersion = latestVersion
|
|
}
|
|
ctx, cancel := context.WithDeadline(u.ctx, time.Now().Add(time.Minute))
|
|
defer cancel()
|
|
u.CheckForUpdates(ctx)
|
|
}
|
|
|
|
func (u *UpdateManager) CheckForUpdates(ctx context.Context) {
|
|
if u.version == disableAutoUpdate {
|
|
log.Trace("Skipped checking for updates, auto-update is disabled")
|
|
return
|
|
}
|
|
currentVersionString := version.NetbirdVersion()
|
|
updateVersionString := u.version
|
|
if updateVersionString == latestVersion || updateVersionString == "" {
|
|
if u.latestVersion == unknownVersion {
|
|
log.Tracef("Latest version not fetched yet")
|
|
return
|
|
}
|
|
updateVersionString = u.latestVersion
|
|
}
|
|
currentVersion, err := v.NewVersion(currentVersionString)
|
|
if err != nil {
|
|
log.Errorf("Error checking for update, error parsing version `%s`: %v", currentVersionString, err)
|
|
return
|
|
}
|
|
updateVersion, err := v.NewVersion(updateVersionString)
|
|
if err != nil {
|
|
log.Errorf("Error checking for update, error parsing version `%s`: %v", updateVersionString, err)
|
|
return
|
|
}
|
|
if currentVersion.LessThan(updateVersion) {
|
|
if u.lastTrigger.Add(5 * time.Minute).Before(time.Now()) {
|
|
u.lastTrigger = time.Now()
|
|
log.Debugf("Auto-update triggered, current version: %s, target version: %s", currentVersionString, updateVersionString)
|
|
u.statusRecorder.PublishEvent(
|
|
cProto.SystemEvent_INFO,
|
|
cProto.SystemEvent_SYSTEM,
|
|
"Automatically updating client",
|
|
"Your client version is older than auto-update version set in Management, updating client now.",
|
|
nil,
|
|
)
|
|
err = u.triggerUpdate(ctx, updateVersionString)
|
|
if err != nil {
|
|
log.Errorf("Error triggering auto-update: %v", err)
|
|
}
|
|
}
|
|
} else {
|
|
log.Debugf("Current version (%s) is equal to or higher than auto-update version (%s)", currentVersionString, updateVersionString)
|
|
}
|
|
}
|
|
|
|
func downloadFileToTemporaryDir(ctx context.Context, fileURL string) (string, error) { //nolint:unused
|
|
tempDir, err := os.MkdirTemp("", "netbird-installer-*")
|
|
if err != nil {
|
|
return "", fmt.Errorf("error creating temporary directory: %w", err)
|
|
}
|
|
fileNameParts := strings.Split(fileURL, "/")
|
|
out, err := os.Create(filepath.Join(tempDir, fileNameParts[len(fileNameParts)-1]))
|
|
if err != nil {
|
|
return "", fmt.Errorf("error creating temporary file: %w", err)
|
|
}
|
|
defer func() {
|
|
if err := out.Close(); err != nil {
|
|
log.Errorf("Error closing temporary file: %v", err)
|
|
}
|
|
}()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", fileURL, nil)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error creating file download request: %w", err)
|
|
}
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error downloading file: %w", err)
|
|
}
|
|
defer func() {
|
|
if err := resp.Body.Close(); err != nil {
|
|
log.Errorf("Error closing response body: %v", err)
|
|
}
|
|
}()
|
|
|
|
_, err = io.Copy(out, resp.Body)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error downloading file: %w", err)
|
|
}
|
|
|
|
log.Tracef("Downloaded update file to %s", out.Name())
|
|
|
|
return out.Name(), nil
|
|
}
|
|
|
|
// urlWithVersionArch expands the placeholders of a download URL template.
// Every "%version" is replaced with version, and then every "%arch" with
// runtime.GOARCH. Because the substitutions run in that order, a "%arch"
// introduced by the version value is expanded too.
func urlWithVersionArch(url, version string) string { //nolint:unused
	withVersion := strings.ReplaceAll(url, "%version", version)
	return strings.ReplaceAll(withVersion, "%arch", runtime.GOARCH)
}
|