mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-25 09:33:24 +01:00
Use the conn state of peer on proper way (#717)
The ConnStatus is a custom type based on iota like an enum. The problem was nowhere used to the benefits of this implementation. All ConnStatus instances has been compared with strings. I suppose the reason to do it to avoid a circle dependency. In this commit the separated status package has been moved to peer package. Remove unused, exported functions from engine
This commit is contained in:
parent
e914adb5cd
commit
337d3edcc4
@ -13,8 +13,8 @@ import (
|
|||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
nbStatus "github.com/netbirdio/netbird/client/status"
|
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -94,7 +94,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
var cancel context.CancelFunc
|
var cancel context.CancelFunc
|
||||||
ctx, cancel = context.WithCancel(ctx)
|
ctx, cancel = context.WithCancel(ctx)
|
||||||
SetupCloseHandler(ctx, cancel)
|
SetupCloseHandler(ctx, cancel)
|
||||||
return internal.RunClient(ctx, config, nbStatus.NewRecorder())
|
return internal.RunClient(ctx, config, peer.NewRecorder())
|
||||||
}
|
}
|
||||||
|
|
||||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||||
|
@ -12,8 +12,8 @@ import (
|
|||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
nbStatus "github.com/netbirdio/netbird/client/status"
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
mgm "github.com/netbirdio/netbird/management/client"
|
mgm "github.com/netbirdio/netbird/management/client"
|
||||||
@ -22,7 +22,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// RunClient with main logic.
|
// RunClient with main logic.
|
||||||
func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Status) error {
|
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
|
||||||
backOff := &backoff.ExponentialBackOff{
|
backOff := &backoff.ExponentialBackOff{
|
||||||
InitialInterval: time.Second,
|
InitialInterval: time.Second,
|
||||||
RandomizationFactor: 1,
|
RandomizationFactor: 1,
|
||||||
@ -103,7 +103,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
|
|||||||
}
|
}
|
||||||
statusRecorder.MarkManagementConnected(managementURL)
|
statusRecorder.MarkManagementConnected(managementURL)
|
||||||
|
|
||||||
localPeerState := nbStatus.LocalPeerState{
|
localPeerState := peer.LocalPeerState{
|
||||||
IP: loginResp.GetPeerConfig().GetAddress(),
|
IP: loginResp.GetPeerConfig().GetAddress(),
|
||||||
PubKey: myPrivateKey.PublicKey().String(),
|
PubKey: myPrivateKey.PublicKey().String(),
|
||||||
KernelInterface: iface.WireguardModuleIsLoaded(),
|
KernelInterface: iface.WireguardModuleIsLoaded(),
|
||||||
|
@ -12,24 +12,23 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
mgm "github.com/netbirdio/netbird/management/client"
|
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
|
||||||
signal "github.com/netbirdio/netbird/signal/client"
|
|
||||||
sProto "github.com/netbirdio/netbird/signal/proto"
|
|
||||||
"github.com/netbirdio/netbird/util"
|
|
||||||
"github.com/pion/ice/v2"
|
"github.com/pion/ice/v2"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
mgm "github.com/netbirdio/netbird/management/client"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
|
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
|
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
|
||||||
@ -109,7 +108,7 @@ type Engine struct {
|
|||||||
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
|
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
|
||||||
sshServer nbssh.Server
|
sshServer nbssh.Server
|
||||||
|
|
||||||
statusRecorder *nbstatus.Status
|
statusRecorder *peer.Status
|
||||||
|
|
||||||
routeManager routemanager.Manager
|
routeManager routemanager.Manager
|
||||||
|
|
||||||
@ -126,7 +125,7 @@ type Peer struct {
|
|||||||
func NewEngine(
|
func NewEngine(
|
||||||
ctx context.Context, cancel context.CancelFunc,
|
ctx context.Context, cancel context.CancelFunc,
|
||||||
signalClient signal.Client, mgmClient mgm.Client,
|
signalClient signal.Client, mgmClient mgm.Client,
|
||||||
config *EngineConfig, statusRecorder *nbstatus.Status,
|
config *EngineConfig, statusRecorder *peer.Status,
|
||||||
) *Engine {
|
) *Engine {
|
||||||
return &Engine{
|
return &Engine{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
@ -338,42 +337,6 @@ func (e *Engine) removePeer(peerKey string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPeerConnectionStatus returns a connection Status or nil if peer connection wasn't found
|
|
||||||
func (e *Engine) GetPeerConnectionStatus(peerKey string) peer.ConnStatus {
|
|
||||||
conn, exists := e.peerConns[peerKey]
|
|
||||||
if exists && conn != nil {
|
|
||||||
return conn.Status()
|
|
||||||
}
|
|
||||||
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *Engine) GetPeers() []string {
|
|
||||||
e.syncMsgMux.Lock()
|
|
||||||
defer e.syncMsgMux.Unlock()
|
|
||||||
|
|
||||||
peers := []string{}
|
|
||||||
for s := range e.peerConns {
|
|
||||||
peers = append(peers, s)
|
|
||||||
}
|
|
||||||
return peers
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetConnectedPeers returns a connection Status or nil if peer connection wasn't found
|
|
||||||
func (e *Engine) GetConnectedPeers() []string {
|
|
||||||
e.syncMsgMux.Lock()
|
|
||||||
defer e.syncMsgMux.Unlock()
|
|
||||||
|
|
||||||
peers := []string{}
|
|
||||||
for s, conn := range e.peerConns {
|
|
||||||
if conn.Status() == peer.StatusConnected {
|
|
||||||
peers = append(peers, s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return peers
|
|
||||||
}
|
|
||||||
|
|
||||||
func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client) error {
|
func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client) error {
|
||||||
err := s.Send(&sProto.Message{
|
err := s.Send(&sProto.Message{
|
||||||
Key: myKey.PublicKey().String(),
|
Key: myKey.PublicKey().String(),
|
||||||
@ -509,7 +472,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
e.statusRecorder.UpdateLocalPeerState(nbstatus.LocalPeerState{
|
e.statusRecorder.UpdateLocalPeerState(peer.LocalPeerState{
|
||||||
IP: e.config.WgAddr,
|
IP: e.config.WgAddr,
|
||||||
PubKey: e.config.WgPrivateKey.PublicKey().String(),
|
PubKey: e.config.WgPrivateKey.PublicKey().String(),
|
||||||
KernelInterface: iface.WireguardModuleIsLoaded(),
|
KernelInterface: iface.WireguardModuleIsLoaded(),
|
||||||
|
@ -3,16 +3,6 @@ package internal
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
@ -23,18 +13,29 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/keepalive"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
mgmt "github.com/netbirdio/netbird/management/client"
|
mgmt "github.com/netbirdio/netbird/management/client"
|
||||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
signal "github.com/netbirdio/netbird/signal/client"
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
"github.com/netbirdio/netbird/signal/proto"
|
"github.com/netbirdio/netbird/signal/proto"
|
||||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
"google.golang.org/grpc"
|
|
||||||
"google.golang.org/grpc/keepalive"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -71,7 +72,7 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
WgAddr: "100.64.0.1/24",
|
WgAddr: "100.64.0.1/24",
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
}, nbstatus.NewRecorder())
|
}, peer.NewRecorder())
|
||||||
|
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||||
@ -205,7 +206,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
WgAddr: "100.64.0.1/24",
|
WgAddr: "100.64.0.1/24",
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
}, nbstatus.NewRecorder())
|
}, peer.NewRecorder())
|
||||||
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU)
|
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU)
|
||||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder)
|
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder)
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
@ -389,7 +390,7 @@ func TestEngine_Sync(t *testing.T) {
|
|||||||
WgAddr: "100.64.0.1/24",
|
WgAddr: "100.64.0.1/24",
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
}, nbstatus.NewRecorder())
|
}, peer.NewRecorder())
|
||||||
|
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||||
@ -439,7 +440,7 @@ func TestEngine_Sync(t *testing.T) {
|
|||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(engine.GetPeers()) == 3 && engine.networkSerial == 10 {
|
if getPeers(engine) == 3 && engine.networkSerial == 10 {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -547,7 +548,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
WgAddr: wgAddr,
|
WgAddr: wgAddr,
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
}, nbstatus.NewRecorder())
|
}, peer.NewRecorder())
|
||||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
|
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
input := struct {
|
input := struct {
|
||||||
@ -712,7 +713,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
WgAddr: wgAddr,
|
WgAddr: wgAddr,
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
}, nbstatus.NewRecorder())
|
}, peer.NewRecorder())
|
||||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
|
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
@ -846,7 +847,7 @@ loop:
|
|||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
totalConnected := 0
|
totalConnected := 0
|
||||||
for _, engine := range engines {
|
for _, engine := range engines {
|
||||||
totalConnected = totalConnected + len(engine.GetConnectedPeers())
|
totalConnected = totalConnected + getConnectedPeers(engine)
|
||||||
}
|
}
|
||||||
if totalConnected == expectedConnected {
|
if totalConnected == expectedConnected {
|
||||||
log.Infof("total connected=%d", totalConnected)
|
log.Infof("total connected=%d", totalConnected)
|
||||||
@ -977,7 +978,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
|||||||
WgPort: wgPort,
|
WgPort: wgPort,
|
||||||
}
|
}
|
||||||
|
|
||||||
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, nbstatus.NewRecorder()), nil
|
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, peer.NewRecorder()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func startSignal() (*grpc.Server, string, error) {
|
func startSignal() (*grpc.Server, string, error) {
|
||||||
@ -1044,3 +1045,23 @@ func startManagement(dataDir string) (*grpc.Server, string, error) {
|
|||||||
|
|
||||||
return s, lis.Addr().String(), nil
|
return s, lis.Addr().String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getConnectedPeers returns a connection Status or nil if peer connection wasn't found
|
||||||
|
func getConnectedPeers(e *Engine) int {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
defer e.syncMsgMux.Unlock()
|
||||||
|
i := 0
|
||||||
|
for _, conn := range e.peerConns {
|
||||||
|
if conn.Status() == peer.StatusConnected {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPeers(e *Engine) int {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
return len(e.peerConns)
|
||||||
|
}
|
||||||
|
@ -7,13 +7,13 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
|
||||||
nbStatus "github.com/netbirdio/netbird/client/status"
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
"github.com/pion/ice/v2"
|
"github.com/pion/ice/v2"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||||
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnConfig is a peer Connection configuration
|
// ConnConfig is a peer Connection configuration
|
||||||
@ -83,7 +83,7 @@ type Conn struct {
|
|||||||
agent *ice.Agent
|
agent *ice.Agent
|
||||||
status ConnStatus
|
status ConnStatus
|
||||||
|
|
||||||
statusRecorder *nbStatus.Status
|
statusRecorder *Status
|
||||||
|
|
||||||
proxy proxy.Proxy
|
proxy proxy.Proxy
|
||||||
}
|
}
|
||||||
@ -100,7 +100,7 @@ func (conn *Conn) UpdateConf(conf ConnConfig) {
|
|||||||
|
|
||||||
// NewConn creates a new not opened Conn to the remote peer.
|
// NewConn creates a new not opened Conn to the remote peer.
|
||||||
// To establish a connection run Conn.Open
|
// To establish a connection run Conn.Open
|
||||||
func NewConn(config ConnConfig, statusRecorder *nbStatus.Status) (*Conn, error) {
|
func NewConn(config ConnConfig, statusRecorder *Status) (*Conn, error) {
|
||||||
return &Conn{
|
return &Conn{
|
||||||
config: config,
|
config: config,
|
||||||
mu: sync.Mutex{},
|
mu: sync.Mutex{},
|
||||||
@ -190,11 +190,11 @@ func (conn *Conn) reCreateAgent() error {
|
|||||||
func (conn *Conn) Open() error {
|
func (conn *Conn) Open() error {
|
||||||
log.Debugf("trying to connect to peer %s", conn.config.Key)
|
log.Debugf("trying to connect to peer %s", conn.config.Key)
|
||||||
|
|
||||||
peerState := nbStatus.PeerState{PubKey: conn.config.Key}
|
peerState := State{PubKey: conn.config.Key}
|
||||||
|
|
||||||
peerState.IP = strings.Split(conn.config.ProxyConfig.AllowedIps, "/")[0]
|
peerState.IP = strings.Split(conn.config.ProxyConfig.AllowedIps, "/")[0]
|
||||||
peerState.ConnStatusUpdate = time.Now()
|
peerState.ConnStatusUpdate = time.Now()
|
||||||
peerState.ConnStatus = conn.status.String()
|
peerState.ConnStatus = conn.status
|
||||||
|
|
||||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -250,9 +250,9 @@ func (conn *Conn) Open() error {
|
|||||||
defer conn.notifyDisconnected()
|
defer conn.notifyDisconnected()
|
||||||
conn.mu.Unlock()
|
conn.mu.Unlock()
|
||||||
|
|
||||||
peerState = nbStatus.PeerState{PubKey: conn.config.Key}
|
peerState = State{PubKey: conn.config.Key}
|
||||||
|
|
||||||
peerState.ConnStatus = conn.status.String()
|
peerState.ConnStatus = conn.status
|
||||||
peerState.ConnStatusUpdate = time.Now()
|
peerState.ConnStatusUpdate = time.Now()
|
||||||
err = conn.statusRecorder.UpdatePeerState(peerState)
|
err = conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -359,7 +359,7 @@ func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
peerState := nbStatus.PeerState{PubKey: conn.config.Key}
|
peerState := State{PubKey: conn.config.Key}
|
||||||
useProxy := shouldUseProxy(pair)
|
useProxy := shouldUseProxy(pair)
|
||||||
var p proxy.Proxy
|
var p proxy.Proxy
|
||||||
if useProxy {
|
if useProxy {
|
||||||
@ -377,7 +377,7 @@ func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
|
|||||||
|
|
||||||
conn.status = StatusConnected
|
conn.status = StatusConnected
|
||||||
|
|
||||||
peerState.ConnStatus = conn.status.String()
|
peerState.ConnStatus = conn.status
|
||||||
peerState.ConnStatusUpdate = time.Now()
|
peerState.ConnStatusUpdate = time.Now()
|
||||||
peerState.LocalIceCandidateType = pair.Local.Type().String()
|
peerState.LocalIceCandidateType = pair.Local.Type().String()
|
||||||
peerState.RemoteIceCandidateType = pair.Remote.Type().String()
|
peerState.RemoteIceCandidateType = pair.Remote.Type().String()
|
||||||
@ -422,8 +422,8 @@ func (conn *Conn) cleanup() error {
|
|||||||
|
|
||||||
conn.status = StatusDisconnected
|
conn.status = StatusDisconnected
|
||||||
|
|
||||||
peerState := nbStatus.PeerState{PubKey: conn.config.Key}
|
peerState := State{PubKey: conn.config.Key}
|
||||||
peerState.ConnStatus = conn.status.String()
|
peerState.ConnStatus = conn.status
|
||||||
peerState.ConnStatusUpdate = time.Now()
|
peerState.ConnStatusUpdate = time.Now()
|
||||||
|
|
||||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||||
|
29
client/internal/peer/conn_status.go
Normal file
29
client/internal/peer/conn_status.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
const (
|
||||||
|
// StatusConnected indicate the peer is in connected state
|
||||||
|
StatusConnected ConnStatus = iota
|
||||||
|
// StatusConnecting indicate the peer is in connecting state
|
||||||
|
StatusConnecting
|
||||||
|
// StatusDisconnected indicate the peer is in disconnected state
|
||||||
|
StatusDisconnected
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConnStatus describe the status of a peer's connection
|
||||||
|
type ConnStatus int
|
||||||
|
|
||||||
|
func (s ConnStatus) String() string {
|
||||||
|
switch s {
|
||||||
|
case StatusConnecting:
|
||||||
|
return "Connecting"
|
||||||
|
case StatusConnected:
|
||||||
|
return "Connected"
|
||||||
|
case StatusDisconnected:
|
||||||
|
return "Disconnected"
|
||||||
|
default:
|
||||||
|
log.Errorf("unknown status: %d", s)
|
||||||
|
return "INVALID_PEER_CONNECTION_STATUS"
|
||||||
|
}
|
||||||
|
}
|
27
client/internal/peer/conn_status_test.go
Normal file
27
client/internal/peer/conn_status_test.go
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/magiconair/properties/assert"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConnStatus_String(t *testing.T) {
|
||||||
|
|
||||||
|
tables := []struct {
|
||||||
|
name string
|
||||||
|
status ConnStatus
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"StatusConnected", StatusConnected, "Connected"},
|
||||||
|
{"StatusDisconnected", StatusDisconnected, "Disconnected"},
|
||||||
|
{"StatusConnecting", StatusConnecting, "Connecting"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, table := range tables {
|
||||||
|
t.Run(table.name, func(t *testing.T) {
|
||||||
|
got := table.status.String()
|
||||||
|
assert.Equal(t, got, table.want, "they should be equal")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -1,14 +1,15 @@
|
|||||||
package peer
|
package peer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/magiconair/properties/assert"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
"github.com/pion/ice/v2"
|
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/magiconair/properties/assert"
|
||||||
|
"github.com/pion/ice/v2"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
var connConf = ConnConfig{
|
var connConf = ConnConfig{
|
||||||
@ -46,7 +47,7 @@ func TestConn_GetKey(t *testing.T) {
|
|||||||
|
|
||||||
func TestConn_OnRemoteOffer(t *testing.T) {
|
func TestConn_OnRemoteOffer(t *testing.T) {
|
||||||
|
|
||||||
conn, err := NewConn(connConf, nbstatus.NewRecorder())
|
conn, err := NewConn(connConf, NewRecorder())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -80,7 +81,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
|||||||
|
|
||||||
func TestConn_OnRemoteAnswer(t *testing.T) {
|
func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||||
|
|
||||||
conn, err := NewConn(connConf, nbstatus.NewRecorder())
|
conn, err := NewConn(connConf, NewRecorder())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -113,7 +114,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
func TestConn_Status(t *testing.T) {
|
func TestConn_Status(t *testing.T) {
|
||||||
|
|
||||||
conn, err := NewConn(connConf, nbstatus.NewRecorder())
|
conn, err := NewConn(connConf, NewRecorder())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -140,7 +141,7 @@ func TestConn_Status(t *testing.T) {
|
|||||||
|
|
||||||
func TestConn_Close(t *testing.T) {
|
func TestConn_Close(t *testing.T) {
|
||||||
|
|
||||||
conn, err := NewConn(connConf, nbstatus.NewRecorder())
|
conn, err := NewConn(connConf, NewRecorder())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -1,25 +1,241 @@
|
|||||||
package peer
|
package peer
|
||||||
|
|
||||||
import log "github.com/sirupsen/logrus"
|
import (
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
type ConnStatus int
|
// State contains the latest state of a peer
|
||||||
|
type State struct {
|
||||||
|
IP string
|
||||||
|
PubKey string
|
||||||
|
FQDN string
|
||||||
|
ConnStatus ConnStatus
|
||||||
|
ConnStatusUpdate time.Time
|
||||||
|
Relayed bool
|
||||||
|
Direct bool
|
||||||
|
LocalIceCandidateType string
|
||||||
|
RemoteIceCandidateType string
|
||||||
|
}
|
||||||
|
|
||||||
func (s ConnStatus) String() string {
|
// LocalPeerState contains the latest state of the local peer
|
||||||
switch s {
|
type LocalPeerState struct {
|
||||||
case StatusConnecting:
|
IP string
|
||||||
return "Connecting"
|
PubKey string
|
||||||
case StatusConnected:
|
KernelInterface bool
|
||||||
return "Connected"
|
FQDN string
|
||||||
case StatusDisconnected:
|
}
|
||||||
return "Disconnected"
|
|
||||||
default:
|
// SignalState contains the latest state of a signal connection
|
||||||
log.Errorf("unknown status: %d", s)
|
type SignalState struct {
|
||||||
return "INVALID_PEER_CONNECTION_STATUS"
|
URL string
|
||||||
|
Connected bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ManagementState contains the latest state of a management connection
|
||||||
|
type ManagementState struct {
|
||||||
|
URL string
|
||||||
|
Connected bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// FullStatus contains the full state held by the Status instance
|
||||||
|
type FullStatus struct {
|
||||||
|
Peers []State
|
||||||
|
ManagementState ManagementState
|
||||||
|
SignalState SignalState
|
||||||
|
LocalPeerState LocalPeerState
|
||||||
|
}
|
||||||
|
|
||||||
|
// Status holds a state of peers, signal and management connections
|
||||||
|
type Status struct {
|
||||||
|
mux sync.Mutex
|
||||||
|
peers map[string]State
|
||||||
|
changeNotify map[string]chan struct{}
|
||||||
|
signal SignalState
|
||||||
|
management ManagementState
|
||||||
|
localPeer LocalPeerState
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRecorder returns a new Status instance
|
||||||
|
func NewRecorder() *Status {
|
||||||
|
return &Status{
|
||||||
|
peers: make(map[string]State),
|
||||||
|
changeNotify: make(map[string]chan struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
// AddPeer adds peer to Daemon status map
|
||||||
StatusConnected ConnStatus = iota
|
func (d *Status) AddPeer(peerPubKey string) error {
|
||||||
StatusConnecting
|
d.mux.Lock()
|
||||||
StatusDisconnected
|
defer d.mux.Unlock()
|
||||||
)
|
|
||||||
|
_, ok := d.peers[peerPubKey]
|
||||||
|
if ok {
|
||||||
|
return errors.New("peer already exist")
|
||||||
|
}
|
||||||
|
d.peers[peerPubKey] = State{PubKey: peerPubKey}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeer adds peer to Daemon status map
|
||||||
|
func (d *Status) GetPeer(peerPubKey string) (State, error) {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
state, ok := d.peers[peerPubKey]
|
||||||
|
if !ok {
|
||||||
|
return State{}, errors.New("peer not found")
|
||||||
|
}
|
||||||
|
return state, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemovePeer removes peer from Daemon status map
|
||||||
|
func (d *Status) RemovePeer(peerPubKey string) error {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
_, ok := d.peers[peerPubKey]
|
||||||
|
if ok {
|
||||||
|
delete(d.peers, peerPubKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors.New("no peer with to remove")
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePeerState updates peer status
|
||||||
|
func (d *Status) UpdatePeerState(receivedState State) error {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
peerState, ok := d.peers[receivedState.PubKey]
|
||||||
|
if !ok {
|
||||||
|
return errors.New("peer doesn't exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
if receivedState.IP != "" {
|
||||||
|
peerState.IP = receivedState.IP
|
||||||
|
}
|
||||||
|
|
||||||
|
if receivedState.ConnStatus != peerState.ConnStatus {
|
||||||
|
peerState.ConnStatus = receivedState.ConnStatus
|
||||||
|
peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
|
||||||
|
peerState.Direct = receivedState.Direct
|
||||||
|
peerState.Relayed = receivedState.Relayed
|
||||||
|
peerState.LocalIceCandidateType = receivedState.LocalIceCandidateType
|
||||||
|
peerState.RemoteIceCandidateType = receivedState.RemoteIceCandidateType
|
||||||
|
}
|
||||||
|
|
||||||
|
d.peers[receivedState.PubKey] = peerState
|
||||||
|
|
||||||
|
ch, found := d.changeNotify[receivedState.PubKey]
|
||||||
|
if found && ch != nil {
|
||||||
|
close(ch)
|
||||||
|
d.changeNotify[receivedState.PubKey] = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePeerFQDN update peer's state fqdn only
|
||||||
|
func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
peerState, ok := d.peers[peerPubKey]
|
||||||
|
if !ok {
|
||||||
|
return errors.New("peer doesn't exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
peerState.FQDN = fqdn
|
||||||
|
d.peers[peerPubKey] = peerState
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeerStateChangeNotifier returns a change notifier channel for a peer
|
||||||
|
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
ch, found := d.changeNotify[peer]
|
||||||
|
if !found || ch == nil {
|
||||||
|
ch = make(chan struct{})
|
||||||
|
d.changeNotify[peer] = ch
|
||||||
|
}
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateLocalPeerState updates local peer status
|
||||||
|
func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
d.localPeer = localPeerState
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanLocalPeerState cleans local peer status
|
||||||
|
func (d *Status) CleanLocalPeerState() {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
d.localPeer = LocalPeerState{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkManagementDisconnected sets ManagementState to disconnected
|
||||||
|
func (d *Status) MarkManagementDisconnected(managementURL string) {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
d.management = ManagementState{
|
||||||
|
URL: managementURL,
|
||||||
|
Connected: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkManagementConnected sets ManagementState to connected
|
||||||
|
func (d *Status) MarkManagementConnected(managementURL string) {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
d.management = ManagementState{
|
||||||
|
URL: managementURL,
|
||||||
|
Connected: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkSignalDisconnected sets SignalState to disconnected
|
||||||
|
func (d *Status) MarkSignalDisconnected(signalURL string) {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
d.signal = SignalState{
|
||||||
|
signalURL,
|
||||||
|
false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkSignalConnected sets SignalState to connected
|
||||||
|
func (d *Status) MarkSignalConnected(signalURL string) {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
d.signal = SignalState{
|
||||||
|
signalURL,
|
||||||
|
true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFullStatus gets full status
|
||||||
|
func (d *Status) GetFullStatus() FullStatus {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
fullStatus := FullStatus{
|
||||||
|
ManagementState: d.management,
|
||||||
|
SignalState: d.signal,
|
||||||
|
LocalPeerState: d.localPeer,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, status := range d.peers {
|
||||||
|
fullStatus.Peers = append(fullStatus.Peers, status)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fullStatus
|
||||||
|
}
|
||||||
|
@ -1,27 +1,244 @@
|
|||||||
package peer
|
package peer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/magiconair/properties/assert"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConnStatus_String(t *testing.T) {
|
func TestAddPeer(t *testing.T) {
|
||||||
|
key := "abc"
|
||||||
|
status := NewRecorder()
|
||||||
|
err := status.AddPeer(key)
|
||||||
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
tables := []struct {
|
_, exists := status.peers[key]
|
||||||
name string
|
assert.True(t, exists, "value was found")
|
||||||
status ConnStatus
|
|
||||||
want string
|
err = status.AddPeer(key)
|
||||||
}{
|
|
||||||
{"StatusConnected", StatusConnected, "Connected"},
|
assert.Error(t, err, "should return error on duplicate")
|
||||||
{"StatusDisconnected", StatusDisconnected, "Disconnected"},
|
}
|
||||||
{"StatusConnecting", StatusConnecting, "Connecting"},
|
|
||||||
|
func TestGetPeer(t *testing.T) {
|
||||||
|
key := "abc"
|
||||||
|
status := NewRecorder()
|
||||||
|
err := status.AddPeer(key)
|
||||||
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
|
peerStatus, err := status.GetPeer(key)
|
||||||
|
assert.NoError(t, err, "shouldn't return error on getting peer")
|
||||||
|
|
||||||
|
assert.Equal(t, key, peerStatus.PubKey, "retrieved public key should match")
|
||||||
|
|
||||||
|
_, err = status.GetPeer("non_existing_key")
|
||||||
|
assert.Error(t, err, "should return error when peer doesn't exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdatePeerState(t *testing.T) {
|
||||||
|
key := "abc"
|
||||||
|
ip := "10.10.10.10"
|
||||||
|
status := NewRecorder()
|
||||||
|
peerState := State{
|
||||||
|
PubKey: key,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, table := range tables {
|
status.peers[key] = peerState
|
||||||
t.Run(table.name, func(t *testing.T) {
|
|
||||||
got := table.status.String()
|
peerState.IP = ip
|
||||||
assert.Equal(t, got, table.want, "they should be equal")
|
|
||||||
|
err := status.UpdatePeerState(peerState)
|
||||||
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
|
state, exists := status.peers[key]
|
||||||
|
assert.True(t, exists, "state should be found")
|
||||||
|
assert.Equal(t, ip, state.IP, "ip should be equal")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStatus_UpdatePeerFQDN(t *testing.T) {
|
||||||
|
key := "abc"
|
||||||
|
fqdn := "peer-a.netbird.local"
|
||||||
|
status := NewRecorder()
|
||||||
|
peerState := State{
|
||||||
|
PubKey: key,
|
||||||
|
}
|
||||||
|
|
||||||
|
status.peers[key] = peerState
|
||||||
|
|
||||||
|
err := status.UpdatePeerFQDN(key, fqdn)
|
||||||
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
|
state, exists := status.peers[key]
|
||||||
|
assert.True(t, exists, "state should be found")
|
||||||
|
assert.Equal(t, fqdn, state.FQDN, "fqdn should be equal")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
|
||||||
|
key := "abc"
|
||||||
|
ip := "10.10.10.10"
|
||||||
|
status := NewRecorder()
|
||||||
|
peerState := State{
|
||||||
|
PubKey: key,
|
||||||
|
}
|
||||||
|
|
||||||
|
status.peers[key] = peerState
|
||||||
|
|
||||||
|
ch := status.GetPeerStateChangeNotifier(key)
|
||||||
|
assert.NotNil(t, ch, "channel shouldn't be nil")
|
||||||
|
|
||||||
|
peerState.IP = ip
|
||||||
|
|
||||||
|
err := status.UpdatePeerState(peerState)
|
||||||
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ch:
|
||||||
|
default:
|
||||||
|
t.Errorf("channel wasn't closed after update")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRemovePeer(t *testing.T) {
|
||||||
|
key := "abc"
|
||||||
|
status := NewRecorder()
|
||||||
|
peerState := State{
|
||||||
|
PubKey: key,
|
||||||
|
}
|
||||||
|
|
||||||
|
status.peers[key] = peerState
|
||||||
|
|
||||||
|
err := status.RemovePeer(key)
|
||||||
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
|
_, exists := status.peers[key]
|
||||||
|
assert.False(t, exists, "state value shouldn't be found")
|
||||||
|
|
||||||
|
err = status.RemovePeer("not existing")
|
||||||
|
assert.Error(t, err, "should return error when peer doesn't exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateLocalPeerState(t *testing.T) {
|
||||||
|
localPeerState := LocalPeerState{
|
||||||
|
IP: "10.10.10.10",
|
||||||
|
PubKey: "abc",
|
||||||
|
KernelInterface: false,
|
||||||
|
}
|
||||||
|
status := NewRecorder()
|
||||||
|
|
||||||
|
status.UpdateLocalPeerState(localPeerState)
|
||||||
|
|
||||||
|
assert.Equal(t, localPeerState, status.localPeer, "local peer status should be equal")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCleanLocalPeerState(t *testing.T) {
|
||||||
|
emptyLocalPeerState := LocalPeerState{}
|
||||||
|
localPeerState := LocalPeerState{
|
||||||
|
IP: "10.10.10.10",
|
||||||
|
PubKey: "abc",
|
||||||
|
KernelInterface: false,
|
||||||
|
}
|
||||||
|
status := NewRecorder()
|
||||||
|
|
||||||
|
status.localPeer = localPeerState
|
||||||
|
|
||||||
|
status.CleanLocalPeerState()
|
||||||
|
|
||||||
|
assert.Equal(t, emptyLocalPeerState, status.localPeer, "local peer status should be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateSignalState(t *testing.T) {
|
||||||
|
url := "https://signal"
|
||||||
|
var tests = []struct {
|
||||||
|
name string
|
||||||
|
connected bool
|
||||||
|
want SignalState
|
||||||
|
}{
|
||||||
|
{"should mark as connected", true, SignalState{
|
||||||
|
|
||||||
|
URL: url,
|
||||||
|
Connected: true,
|
||||||
|
}},
|
||||||
|
{"should mark as disconnected", false, SignalState{
|
||||||
|
URL: url,
|
||||||
|
Connected: false,
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
status := NewRecorder()
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
if test.connected {
|
||||||
|
status.MarkSignalConnected(url)
|
||||||
|
} else {
|
||||||
|
status.MarkSignalDisconnected(url)
|
||||||
|
}
|
||||||
|
assert.Equal(t, test.want, status.signal, "signal status should be equal")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateManagementState(t *testing.T) {
|
||||||
|
url := "https://management"
|
||||||
|
var tests = []struct {
|
||||||
|
name string
|
||||||
|
connected bool
|
||||||
|
want ManagementState
|
||||||
|
}{
|
||||||
|
{"should mark as connected", true, ManagementState{
|
||||||
|
|
||||||
|
URL: url,
|
||||||
|
Connected: true,
|
||||||
|
}},
|
||||||
|
{"should mark as disconnected", false, ManagementState{
|
||||||
|
URL: url,
|
||||||
|
Connected: false,
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
status := NewRecorder()
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
if test.connected {
|
||||||
|
status.MarkManagementConnected(url)
|
||||||
|
} else {
|
||||||
|
status.MarkManagementDisconnected(url)
|
||||||
|
}
|
||||||
|
assert.Equal(t, test.want, status.management, "signal status should be equal")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFullStatus(t *testing.T) {
|
||||||
|
key1 := "abc"
|
||||||
|
key2 := "def"
|
||||||
|
managementState := ManagementState{
|
||||||
|
URL: "https://signal",
|
||||||
|
Connected: true,
|
||||||
|
}
|
||||||
|
signalState := SignalState{
|
||||||
|
URL: "https://signal",
|
||||||
|
Connected: true,
|
||||||
|
}
|
||||||
|
peerState1 := State{
|
||||||
|
PubKey: key1,
|
||||||
|
}
|
||||||
|
|
||||||
|
peerState2 := State{
|
||||||
|
PubKey: key2,
|
||||||
|
}
|
||||||
|
|
||||||
|
status := NewRecorder()
|
||||||
|
|
||||||
|
status.management = managementState
|
||||||
|
status.signal = signalState
|
||||||
|
status.peers[key1] = peerState1
|
||||||
|
status.peers[key2] = peerState2
|
||||||
|
|
||||||
|
fullStatus := status.GetFullStatus()
|
||||||
|
|
||||||
|
assert.Equal(t, managementState, fullStatus.ManagementState, "management status should be equal")
|
||||||
|
assert.Equal(t, signalState, fullStatus.SignalState, "signal status should be equal")
|
||||||
|
assert.ElementsMatch(t, []State{peerState1, peerState2}, fullStatus.Peers, "peers states should match")
|
||||||
}
|
}
|
||||||
|
@ -5,11 +5,11 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/status"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type routerPeerStatus struct {
|
type routerPeerStatus struct {
|
||||||
@ -26,7 +26,7 @@ type routesUpdate struct {
|
|||||||
type clientNetwork struct {
|
type clientNetwork struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
stop context.CancelFunc
|
stop context.CancelFunc
|
||||||
statusRecorder *status.Status
|
statusRecorder *peer.Status
|
||||||
wgInterface *iface.WGIface
|
wgInterface *iface.WGIface
|
||||||
routes map[string]*route.Route
|
routes map[string]*route.Route
|
||||||
routeUpdate chan routesUpdate
|
routeUpdate chan routesUpdate
|
||||||
@ -37,7 +37,7 @@ type clientNetwork struct {
|
|||||||
updateSerial uint64
|
updateSerial uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *status.Status, network netip.Prefix) *clientNetwork {
|
func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
client := &clientNetwork{
|
client := &clientNetwork{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
@ -62,7 +62,7 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
routePeerStatuses[r.ID] = routerPeerStatus{
|
routePeerStatuses[r.ID] = routerPeerStatus{
|
||||||
connected: peerStatus.ConnStatus == peer.StatusConnected.String(),
|
connected: peerStatus.ConnStatus == peer.StatusConnected,
|
||||||
relayed: peerStatus.Relayed,
|
relayed: peerStatus.Relayed,
|
||||||
direct: peerStatus.Direct,
|
direct: peerStatus.Direct,
|
||||||
}
|
}
|
||||||
@ -123,7 +123,7 @@ func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey stri
|
|||||||
return
|
return
|
||||||
case <-c.statusRecorder.GetPeerStateChangeNotifier(peerKey):
|
case <-c.statusRecorder.GetPeerStateChangeNotifier(peerKey):
|
||||||
state, err := c.statusRecorder.GetPeer(peerKey)
|
state, err := c.statusRecorder.GetPeer(peerKey)
|
||||||
if err != nil || state.ConnStatus == peer.StatusConnecting.String() {
|
if err != nil || state.ConnStatus == peer.StatusConnecting {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
peerStateUpdate <- struct{}{}
|
peerStateUpdate <- struct{}{}
|
||||||
@ -144,7 +144,7 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() {
|
|||||||
|
|
||||||
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
||||||
state, err := c.statusRecorder.GetPeer(peerKey)
|
state, err := c.statusRecorder.GetPeer(peerKey)
|
||||||
if err != nil || state.ConnStatus != peer.StatusConnected.String() {
|
if err != nil || state.ConnStatus != peer.StatusConnected {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,11 +6,12 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/status"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Manager is a route manager interface
|
// Manager is a route manager interface
|
||||||
@ -27,13 +28,13 @@ type DefaultManager struct {
|
|||||||
clientNetworks map[string]*clientNetwork
|
clientNetworks map[string]*clientNetwork
|
||||||
serverRoutes map[string]*route.Route
|
serverRoutes map[string]*route.Route
|
||||||
serverRouter *serverRouter
|
serverRouter *serverRouter
|
||||||
statusRecorder *status.Status
|
statusRecorder *peer.Status
|
||||||
wgInterface *iface.WGIface
|
wgInterface *iface.WGIface
|
||||||
pubKey string
|
pubKey string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewManager returns a new route manager
|
// NewManager returns a new route manager
|
||||||
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *status.Status) *DefaultManager {
|
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
|
||||||
mCTX, cancel := context.WithCancel(ctx)
|
mCTX, cancel := context.WithCancel(ctx)
|
||||||
return &DefaultManager{
|
return &DefaultManager{
|
||||||
ctx: mCTX,
|
ctx: mCTX,
|
||||||
|
@ -7,10 +7,11 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/status"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// send 5 routes, one for server and 4 for clients, one normal and 2 HA and one small
|
// send 5 routes, one for server and 4 for clients, one normal and 2 HA and one small
|
||||||
@ -397,7 +398,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
err = wgInterface.Create()
|
err = wgInterface.Create()
|
||||||
require.NoError(t, err, "should create testing wireguard interface")
|
require.NoError(t, err, "should create testing wireguard interface")
|
||||||
|
|
||||||
statusRecorder := status.NewRecorder()
|
statusRecorder := peer.NewRecorder()
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder)
|
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder)
|
||||||
defer routeManager.Stop()
|
defer routeManager.Stop()
|
||||||
|
@ -13,8 +13,8 @@ import (
|
|||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
nbStatus "github.com/netbirdio/netbird/client/status"
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -33,7 +33,7 @@ type Server struct {
|
|||||||
config *internal.Config
|
config *internal.Config
|
||||||
proto.UnimplementedDaemonServiceServer
|
proto.UnimplementedDaemonServiceServer
|
||||||
|
|
||||||
statusRecorder *nbStatus.Status
|
statusRecorder *peer.Status
|
||||||
}
|
}
|
||||||
|
|
||||||
type oauthAuthFlow struct {
|
type oauthAuthFlow struct {
|
||||||
@ -96,7 +96,7 @@ func (s *Server) Start() error {
|
|||||||
s.config = config
|
s.config = config
|
||||||
|
|
||||||
if s.statusRecorder == nil {
|
if s.statusRecorder == nil {
|
||||||
s.statusRecorder = nbStatus.NewRecorder()
|
s.statusRecorder = peer.NewRecorder()
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
@ -386,7 +386,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
|
|||||||
}
|
}
|
||||||
|
|
||||||
if s.statusRecorder == nil {
|
if s.statusRecorder == nil {
|
||||||
s.statusRecorder = nbStatus.NewRecorder()
|
s.statusRecorder = peer.NewRecorder()
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
@ -430,7 +430,7 @@ func (s *Server) Status(
|
|||||||
statusResponse := proto.StatusResponse{Status: string(status), DaemonVersion: system.NetbirdVersion()}
|
statusResponse := proto.StatusResponse{Status: string(status), DaemonVersion: system.NetbirdVersion()}
|
||||||
|
|
||||||
if s.statusRecorder == nil {
|
if s.statusRecorder == nil {
|
||||||
s.statusRecorder = nbStatus.NewRecorder()
|
s.statusRecorder = peer.NewRecorder()
|
||||||
}
|
}
|
||||||
|
|
||||||
if msg.GetFullPeerStatus {
|
if msg.GetFullPeerStatus {
|
||||||
@ -476,7 +476,7 @@ func (s *Server) GetConfig(_ context.Context, _ *proto.GetConfigRequest) (*proto
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func toProtoFullStatus(fullStatus nbStatus.FullStatus) *proto.FullStatus {
|
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
|
||||||
pbFullStatus := proto.FullStatus{
|
pbFullStatus := proto.FullStatus{
|
||||||
ManagementState: &proto.ManagementState{},
|
ManagementState: &proto.ManagementState{},
|
||||||
SignalState: &proto.SignalState{},
|
SignalState: &proto.SignalState{},
|
||||||
@ -499,7 +499,7 @@ func toProtoFullStatus(fullStatus nbStatus.FullStatus) *proto.FullStatus {
|
|||||||
pbPeerState := &proto.PeerState{
|
pbPeerState := &proto.PeerState{
|
||||||
IP: peerState.IP,
|
IP: peerState.IP,
|
||||||
PubKey: peerState.PubKey,
|
PubKey: peerState.PubKey,
|
||||||
ConnStatus: peerState.ConnStatus,
|
ConnStatus: peerState.ConnStatus.String(),
|
||||||
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
|
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
|
||||||
Relayed: peerState.Relayed,
|
Relayed: peerState.Relayed,
|
||||||
Direct: peerState.Direct,
|
Direct: peerState.Direct,
|
||||||
|
@ -1,241 +0,0 @@
|
|||||||
package status
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PeerState contains the latest state of a peer
|
|
||||||
type PeerState struct {
|
|
||||||
IP string
|
|
||||||
PubKey string
|
|
||||||
FQDN string
|
|
||||||
ConnStatus string
|
|
||||||
ConnStatusUpdate time.Time
|
|
||||||
Relayed bool
|
|
||||||
Direct bool
|
|
||||||
LocalIceCandidateType string
|
|
||||||
RemoteIceCandidateType string
|
|
||||||
}
|
|
||||||
|
|
||||||
// LocalPeerState contains the latest state of the local peer
|
|
||||||
type LocalPeerState struct {
|
|
||||||
IP string
|
|
||||||
PubKey string
|
|
||||||
KernelInterface bool
|
|
||||||
FQDN string
|
|
||||||
}
|
|
||||||
|
|
||||||
// SignalState contains the latest state of a signal connection
|
|
||||||
type SignalState struct {
|
|
||||||
URL string
|
|
||||||
Connected bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// ManagementState contains the latest state of a management connection
|
|
||||||
type ManagementState struct {
|
|
||||||
URL string
|
|
||||||
Connected bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// FullStatus contains the full state held by the Status instance
|
|
||||||
type FullStatus struct {
|
|
||||||
Peers []PeerState
|
|
||||||
ManagementState ManagementState
|
|
||||||
SignalState SignalState
|
|
||||||
LocalPeerState LocalPeerState
|
|
||||||
}
|
|
||||||
|
|
||||||
// Status holds a state of peers, signal and management connections
|
|
||||||
type Status struct {
|
|
||||||
mux sync.Mutex
|
|
||||||
peers map[string]PeerState
|
|
||||||
changeNotify map[string]chan struct{}
|
|
||||||
signal SignalState
|
|
||||||
management ManagementState
|
|
||||||
localPeer LocalPeerState
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewRecorder returns a new Status instance
|
|
||||||
func NewRecorder() *Status {
|
|
||||||
return &Status{
|
|
||||||
peers: make(map[string]PeerState),
|
|
||||||
changeNotify: make(map[string]chan struct{}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddPeer adds peer to Daemon status map
|
|
||||||
func (d *Status) AddPeer(peerPubKey string) error {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
|
|
||||||
_, ok := d.peers[peerPubKey]
|
|
||||||
if ok {
|
|
||||||
return errors.New("peer already exist")
|
|
||||||
}
|
|
||||||
d.peers[peerPubKey] = PeerState{PubKey: peerPubKey}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPeer adds peer to Daemon status map
|
|
||||||
func (d *Status) GetPeer(peerPubKey string) (PeerState, error) {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
|
|
||||||
state, ok := d.peers[peerPubKey]
|
|
||||||
if !ok {
|
|
||||||
return PeerState{}, errors.New("peer not found")
|
|
||||||
}
|
|
||||||
return state, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemovePeer removes peer from Daemon status map
|
|
||||||
func (d *Status) RemovePeer(peerPubKey string) error {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
|
|
||||||
_, ok := d.peers[peerPubKey]
|
|
||||||
if ok {
|
|
||||||
delete(d.peers, peerPubKey)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return errors.New("no peer with to remove")
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdatePeerState updates peer status
|
|
||||||
func (d *Status) UpdatePeerState(receivedState PeerState) error {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
|
|
||||||
peerState, ok := d.peers[receivedState.PubKey]
|
|
||||||
if !ok {
|
|
||||||
return errors.New("peer doesn't exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
if receivedState.IP != "" {
|
|
||||||
peerState.IP = receivedState.IP
|
|
||||||
}
|
|
||||||
|
|
||||||
if receivedState.ConnStatus != peerState.ConnStatus {
|
|
||||||
peerState.ConnStatus = receivedState.ConnStatus
|
|
||||||
peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
|
|
||||||
peerState.Direct = receivedState.Direct
|
|
||||||
peerState.Relayed = receivedState.Relayed
|
|
||||||
peerState.LocalIceCandidateType = receivedState.LocalIceCandidateType
|
|
||||||
peerState.RemoteIceCandidateType = receivedState.RemoteIceCandidateType
|
|
||||||
}
|
|
||||||
|
|
||||||
d.peers[receivedState.PubKey] = peerState
|
|
||||||
|
|
||||||
ch, found := d.changeNotify[receivedState.PubKey]
|
|
||||||
if found && ch != nil {
|
|
||||||
close(ch)
|
|
||||||
d.changeNotify[receivedState.PubKey] = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdatePeerFQDN update peer's state fqdn only
|
|
||||||
func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
|
|
||||||
peerState, ok := d.peers[peerPubKey]
|
|
||||||
if !ok {
|
|
||||||
return errors.New("peer doesn't exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
peerState.FQDN = fqdn
|
|
||||||
d.peers[peerPubKey] = peerState
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPeerStateChangeNotifier returns a change notifier channel for a peer
|
|
||||||
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
ch, found := d.changeNotify[peer]
|
|
||||||
if !found || ch == nil {
|
|
||||||
ch = make(chan struct{})
|
|
||||||
d.changeNotify[peer] = ch
|
|
||||||
}
|
|
||||||
return ch
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateLocalPeerState updates local peer status
|
|
||||||
func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
|
|
||||||
d.localPeer = localPeerState
|
|
||||||
}
|
|
||||||
|
|
||||||
// CleanLocalPeerState cleans local peer status
|
|
||||||
func (d *Status) CleanLocalPeerState() {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
|
|
||||||
d.localPeer = LocalPeerState{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarkManagementDisconnected sets ManagementState to disconnected
|
|
||||||
func (d *Status) MarkManagementDisconnected(managementURL string) {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
d.management = ManagementState{
|
|
||||||
URL: managementURL,
|
|
||||||
Connected: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarkManagementConnected sets ManagementState to connected
|
|
||||||
func (d *Status) MarkManagementConnected(managementURL string) {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
d.management = ManagementState{
|
|
||||||
URL: managementURL,
|
|
||||||
Connected: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarkSignalDisconnected sets SignalState to disconnected
|
|
||||||
func (d *Status) MarkSignalDisconnected(signalURL string) {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
d.signal = SignalState{
|
|
||||||
signalURL,
|
|
||||||
false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarkSignalConnected sets SignalState to connected
|
|
||||||
func (d *Status) MarkSignalConnected(signalURL string) {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
d.signal = SignalState{
|
|
||||||
signalURL,
|
|
||||||
true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetFullStatus gets full status
|
|
||||||
func (d *Status) GetFullStatus() FullStatus {
|
|
||||||
d.mux.Lock()
|
|
||||||
defer d.mux.Unlock()
|
|
||||||
|
|
||||||
fullStatus := FullStatus{
|
|
||||||
ManagementState: d.management,
|
|
||||||
SignalState: d.signal,
|
|
||||||
LocalPeerState: d.localPeer,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, status := range d.peers {
|
|
||||||
fullStatus.Peers = append(fullStatus.Peers, status)
|
|
||||||
}
|
|
||||||
|
|
||||||
return fullStatus
|
|
||||||
}
|
|
@ -1,243 +0,0 @@
|
|||||||
package status
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestAddPeer(t *testing.T) {
|
|
||||||
key := "abc"
|
|
||||||
status := NewRecorder()
|
|
||||||
err := status.AddPeer(key)
|
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
_, exists := status.peers[key]
|
|
||||||
assert.True(t, exists, "value was found")
|
|
||||||
|
|
||||||
err = status.AddPeer(key)
|
|
||||||
|
|
||||||
assert.Error(t, err, "should return error on duplicate")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetPeer(t *testing.T) {
|
|
||||||
key := "abc"
|
|
||||||
status := NewRecorder()
|
|
||||||
err := status.AddPeer(key)
|
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
peerStatus, err := status.GetPeer(key)
|
|
||||||
assert.NoError(t, err, "shouldn't return error on getting peer")
|
|
||||||
|
|
||||||
assert.Equal(t, key, peerStatus.PubKey, "retrieved public key should match")
|
|
||||||
|
|
||||||
_, err = status.GetPeer("non_existing_key")
|
|
||||||
assert.Error(t, err, "should return error when peer doesn't exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdatePeerState(t *testing.T) {
|
|
||||||
key := "abc"
|
|
||||||
ip := "10.10.10.10"
|
|
||||||
status := NewRecorder()
|
|
||||||
peerState := PeerState{
|
|
||||||
PubKey: key,
|
|
||||||
}
|
|
||||||
|
|
||||||
status.peers[key] = peerState
|
|
||||||
|
|
||||||
peerState.IP = ip
|
|
||||||
|
|
||||||
err := status.UpdatePeerState(peerState)
|
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
state, exists := status.peers[key]
|
|
||||||
assert.True(t, exists, "state should be found")
|
|
||||||
assert.Equal(t, ip, state.IP, "ip should be equal")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStatus_UpdatePeerFQDN(t *testing.T) {
|
|
||||||
key := "abc"
|
|
||||||
fqdn := "peer-a.netbird.local"
|
|
||||||
status := NewRecorder()
|
|
||||||
peerState := PeerState{
|
|
||||||
PubKey: key,
|
|
||||||
}
|
|
||||||
|
|
||||||
status.peers[key] = peerState
|
|
||||||
|
|
||||||
err := status.UpdatePeerFQDN(key, fqdn)
|
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
state, exists := status.peers[key]
|
|
||||||
assert.True(t, exists, "state should be found")
|
|
||||||
assert.Equal(t, fqdn, state.FQDN, "fqdn should be equal")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
|
|
||||||
key := "abc"
|
|
||||||
ip := "10.10.10.10"
|
|
||||||
status := NewRecorder()
|
|
||||||
peerState := PeerState{
|
|
||||||
PubKey: key,
|
|
||||||
}
|
|
||||||
|
|
||||||
status.peers[key] = peerState
|
|
||||||
|
|
||||||
ch := status.GetPeerStateChangeNotifier(key)
|
|
||||||
assert.NotNil(t, ch, "channel shouldn't be nil")
|
|
||||||
|
|
||||||
peerState.IP = ip
|
|
||||||
|
|
||||||
err := status.UpdatePeerState(peerState)
|
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ch:
|
|
||||||
default:
|
|
||||||
t.Errorf("channel wasn't closed after update")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRemovePeer(t *testing.T) {
|
|
||||||
key := "abc"
|
|
||||||
status := NewRecorder()
|
|
||||||
peerState := PeerState{
|
|
||||||
PubKey: key,
|
|
||||||
}
|
|
||||||
|
|
||||||
status.peers[key] = peerState
|
|
||||||
|
|
||||||
err := status.RemovePeer(key)
|
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
_, exists := status.peers[key]
|
|
||||||
assert.False(t, exists, "state value shouldn't be found")
|
|
||||||
|
|
||||||
err = status.RemovePeer("not existing")
|
|
||||||
assert.Error(t, err, "should return error when peer doesn't exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateLocalPeerState(t *testing.T) {
|
|
||||||
localPeerState := LocalPeerState{
|
|
||||||
IP: "10.10.10.10",
|
|
||||||
PubKey: "abc",
|
|
||||||
KernelInterface: false,
|
|
||||||
}
|
|
||||||
status := NewRecorder()
|
|
||||||
|
|
||||||
status.UpdateLocalPeerState(localPeerState)
|
|
||||||
|
|
||||||
assert.Equal(t, localPeerState, status.localPeer, "local peer status should be equal")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCleanLocalPeerState(t *testing.T) {
|
|
||||||
emptyLocalPeerState := LocalPeerState{}
|
|
||||||
localPeerState := LocalPeerState{
|
|
||||||
IP: "10.10.10.10",
|
|
||||||
PubKey: "abc",
|
|
||||||
KernelInterface: false,
|
|
||||||
}
|
|
||||||
status := NewRecorder()
|
|
||||||
|
|
||||||
status.localPeer = localPeerState
|
|
||||||
|
|
||||||
status.CleanLocalPeerState()
|
|
||||||
|
|
||||||
assert.Equal(t, emptyLocalPeerState, status.localPeer, "local peer status should be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateSignalState(t *testing.T) {
|
|
||||||
url := "https://signal"
|
|
||||||
var tests = []struct {
|
|
||||||
name string
|
|
||||||
connected bool
|
|
||||||
want SignalState
|
|
||||||
}{
|
|
||||||
{"should mark as connected", true, SignalState{
|
|
||||||
|
|
||||||
URL: url,
|
|
||||||
Connected: true,
|
|
||||||
}},
|
|
||||||
{"should mark as disconnected", false, SignalState{
|
|
||||||
URL: url,
|
|
||||||
Connected: false,
|
|
||||||
}},
|
|
||||||
}
|
|
||||||
|
|
||||||
status := NewRecorder()
|
|
||||||
|
|
||||||
for _, test := range tests {
|
|
||||||
t.Run(test.name, func(t *testing.T) {
|
|
||||||
if test.connected {
|
|
||||||
status.MarkSignalConnected(url)
|
|
||||||
} else {
|
|
||||||
status.MarkSignalDisconnected(url)
|
|
||||||
}
|
|
||||||
assert.Equal(t, test.want, status.signal, "signal status should be equal")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateManagementState(t *testing.T) {
|
|
||||||
url := "https://management"
|
|
||||||
var tests = []struct {
|
|
||||||
name string
|
|
||||||
connected bool
|
|
||||||
want ManagementState
|
|
||||||
}{
|
|
||||||
{"should mark as connected", true, ManagementState{
|
|
||||||
|
|
||||||
URL: url,
|
|
||||||
Connected: true,
|
|
||||||
}},
|
|
||||||
{"should mark as disconnected", false, ManagementState{
|
|
||||||
URL: url,
|
|
||||||
Connected: false,
|
|
||||||
}},
|
|
||||||
}
|
|
||||||
|
|
||||||
status := NewRecorder()
|
|
||||||
|
|
||||||
for _, test := range tests {
|
|
||||||
t.Run(test.name, func(t *testing.T) {
|
|
||||||
if test.connected {
|
|
||||||
status.MarkManagementConnected(url)
|
|
||||||
} else {
|
|
||||||
status.MarkManagementDisconnected(url)
|
|
||||||
}
|
|
||||||
assert.Equal(t, test.want, status.management, "signal status should be equal")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetFullStatus(t *testing.T) {
|
|
||||||
key1 := "abc"
|
|
||||||
key2 := "def"
|
|
||||||
managementState := ManagementState{
|
|
||||||
URL: "https://signal",
|
|
||||||
Connected: true,
|
|
||||||
}
|
|
||||||
signalState := SignalState{
|
|
||||||
URL: "https://signal",
|
|
||||||
Connected: true,
|
|
||||||
}
|
|
||||||
peerState1 := PeerState{
|
|
||||||
PubKey: key1,
|
|
||||||
}
|
|
||||||
|
|
||||||
peerState2 := PeerState{
|
|
||||||
PubKey: key2,
|
|
||||||
}
|
|
||||||
|
|
||||||
status := NewRecorder()
|
|
||||||
|
|
||||||
status.management = managementState
|
|
||||||
status.signal = signalState
|
|
||||||
status.peers[key1] = peerState1
|
|
||||||
status.peers[key2] = peerState2
|
|
||||||
|
|
||||||
fullStatus := status.GetFullStatus()
|
|
||||||
|
|
||||||
assert.Equal(t, managementState, fullStatus.ManagementState, "management status should be equal")
|
|
||||||
assert.Equal(t, signalState, fullStatus.SignalState, "signal status should be equal")
|
|
||||||
assert.ElementsMatch(t, []PeerState{peerState1, peerState2}, fullStatus.Peers, "peers states should match")
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user