Mirror of https://github.com/netbirdio/netbird.git, synced 2024-11-07 08:44:07 +01:00
Fix connstate indication (#732)
Fix the status indication in the client service. The reported status of the Management and Signal servers was incorrect when the network connection was broken, because the Management and Signal client libraries never notified the status recorder of connection-state changes.
parent 731d3ae464
commit 747797271e
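The fix wires the status recorder into the Management and Signal gRPC clients as a connection-state listener, so the recorder is told about connects and disconnects instead of being set once and never updated. Below is a minimal, self-contained sketch of that listener pattern; the interface and method names mirror the diff that follows, but the recorder and client types are simplified stand-ins (and the address in main is hypothetical), not the real NetBird implementations.

package main

import (
	"fmt"
	"sync"
)

// ConnStateNotifier is the callback interface the client invokes on state changes
// (same shape as the interface added to the management client in this commit).
type ConnStateNotifier interface {
	MarkManagementConnected()
	MarkManagementDisconnected()
}

// Status is a stripped-down stand-in for peer.Status.
type Status struct {
	mux             sync.Mutex
	mgmAddress      string
	managementState bool
}

// NewRecorder returns a recorder bound to a management address,
// matching the new NewRecorder(mgmAddress string) signature.
func NewRecorder(mgmAddress string) *Status {
	return &Status{mgmAddress: mgmAddress}
}

func (s *Status) MarkManagementConnected() {
	s.mux.Lock()
	defer s.mux.Unlock()
	s.managementState = true
}

func (s *Status) MarkManagementDisconnected() {
	s.mux.Lock()
	defer s.mux.Unlock()
	s.managementState = false
}

// ManagementConnected reads the flag under the same lock the setters use.
func (s *Status) ManagementConnected() bool {
	s.mux.Lock()
	defer s.mux.Unlock()
	return s.managementState
}

// Client is a stand-in for the Management GrpcClient: it holds an optional
// listener behind an RWMutex and notifies it on connect/disconnect.
type Client struct {
	connStateCallback     ConnStateNotifier
	connStateCallbackLock sync.RWMutex
}

// SetConnStateListener registers the notifier; leaving it unset simply disables callbacks.
func (c *Client) SetConnStateListener(n ConnStateNotifier) {
	c.connStateCallbackLock.Lock()
	defer c.connStateCallbackLock.Unlock()
	c.connStateCallback = n
}

func (c *Client) notifyConnected() {
	c.connStateCallbackLock.RLock()
	defer c.connStateCallbackLock.RUnlock()
	if c.connStateCallback == nil {
		return
	}
	c.connStateCallback.MarkManagementConnected()
}

func (c *Client) notifyDisconnected() {
	c.connStateCallbackLock.RLock()
	defer c.connStateCallbackLock.RUnlock()
	if c.connStateCallback == nil {
		return
	}
	c.connStateCallback.MarkManagementDisconnected()
}

func main() {
	recorder := NewRecorder("https://mgm.example.com:443") // hypothetical address
	client := &Client{}
	client.SetConnStateListener(recorder)

	client.notifyConnected() // e.g. after the management stream is established
	fmt.Println("management connected:", recorder.ManagementConnected())

	client.notifyDisconnected() // e.g. after the stream breaks
	fmt.Println("management connected:", recorder.ManagementConnected())
}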
@@ -94,7 +94,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
     var cancel context.CancelFunc
     ctx, cancel = context.WithCancel(ctx)
     SetupCloseHandler(ctx, cancel)
-    return internal.RunClient(ctx, config, peer.NewRecorder())
+    return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
 }
 
 func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
@@ -58,8 +58,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status)
         return err
     }
 
-    managementURL := config.ManagementURL.String()
-    statusRecorder.MarkManagementDisconnected(managementURL)
+    statusRecorder.MarkManagementDisconnected()
 
     operation := func() error {
         // if context cancelled we not start new backoff cycle
@@ -73,13 +72,16 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status)
 
     engineCtx, cancel := context.WithCancel(ctx)
     defer func() {
-        statusRecorder.MarkManagementDisconnected(managementURL)
+        statusRecorder.MarkManagementDisconnected()
         statusRecorder.CleanLocalPeerState()
         cancel()
     }()
 
     log.Debugf("conecting to the Management service %s", config.ManagementURL.Host)
     mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
+    mgmNotifier := statusRecorderToMgmConnStateNotifier(statusRecorder)
+    mgmClient.SetConnStateListener(mgmNotifier)
+
     if err != nil {
         return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
     }
@@ -101,7 +103,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status)
         }
         return wrapErr(err)
     }
-    statusRecorder.MarkManagementConnected(managementURL)
+    statusRecorder.MarkManagementConnected()
 
     localPeerState := peer.LocalPeerState{
         IP: loginResp.GetPeerConfig().GetAddress(),
@@ -117,8 +119,10 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status)
         loginResp.GetWiretrusteeConfig().GetSignal().GetUri(),
     )
 
-    statusRecorder.MarkSignalDisconnected(signalURL)
-    defer statusRecorder.MarkSignalDisconnected(signalURL)
+    statusRecorder.UpdateSignalAddress(signalURL)
+
+    statusRecorder.MarkSignalDisconnected()
+    defer statusRecorder.MarkSignalDisconnected()
 
     // with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
     signalClient, err := connectToSignal(engineCtx, loginResp.GetWiretrusteeConfig(), myPrivateKey)
@@ -133,7 +137,10 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status)
         }
     }()
 
-    statusRecorder.MarkSignalConnected(signalURL)
+    signalNotifier := statusRecorderToSignalConnStateNotifier(statusRecorder)
+    signalClient.SetConnStateListener(signalNotifier)
+
+    statusRecorder.MarkSignalConnected()
 
     peerConfig := loginResp.GetPeerConfig()
 
@@ -320,3 +327,15 @@ func UpdateOldManagementPort(ctx context.Context, config *Config, configPath str
 
     return config, nil
 }
+
+func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {
+    var sri interface{} = statusRecorder
+    mgmNotifier, _ := sri.(mgm.ConnStateNotifier)
+    return mgmNotifier
+}
+
+func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal.ConnStateNotifier {
+    var sri interface{} = statusRecorder
+    notifier, _ := sri.(signal.ConnStateNotifier)
+    return notifier
+}
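A note on the two converter helpers above: they go through an empty interface and a type assertion whose second return value is discarded, so a recorder that happened not to implement the notifier interface would silently become nil (the clients guard against nil before invoking callbacks, so this degrades to "no status updates" rather than a crash). The following is a hedged, self-contained sketch of the same conversion with the assertion checked explicitly; the types are illustrative stand-ins, not the actual peer or mgm packages.

package main

import "fmt"

// connStateNotifier mirrors the shape of the mgm.ConnStateNotifier interface in the diff.
type connStateNotifier interface {
	MarkManagementDisconnected()
	MarkManagementConnected()
}

// recorder is a stand-in for *peer.Status; it satisfies connStateNotifier implicitly.
type recorder struct{}

func (r *recorder) MarkManagementDisconnected() {}
func (r *recorder) MarkManagementConnected()    {}

// toNotifier performs the same conversion as statusRecorderToMgmConnStateNotifier,
// but checks the assertion and reports a failure instead of silently returning nil.
func toNotifier(v interface{}) connStateNotifier {
	n, ok := v.(connStateNotifier)
	if !ok {
		fmt.Println("value does not implement connStateNotifier; connection state will not be reported")
		return nil
	}
	return n
}

func main() {
	fmt.Println(toNotifier(&recorder{}) != nil) // true: the recorder implements the interface
	fmt.Println(toNotifier(42) != nil)          // false: the checked assertion catches the mismatch
}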
@@ -72,7 +72,7 @@ func TestEngine_SSH(t *testing.T) {
         WgAddr: "100.64.0.1/24",
         WgPrivateKey: key,
         WgPort: 33100,
-    }, peer.NewRecorder())
+    }, peer.NewRecorder("https://mgm"))
 
     engine.dnsServer = &dns.MockServer{
         UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -206,7 +206,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
         WgAddr: "100.64.0.1/24",
         WgPrivateKey: key,
         WgPort: 33100,
-    }, peer.NewRecorder())
+    }, peer.NewRecorder("https://mgm"))
     engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU)
     engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder)
     engine.dnsServer = &dns.MockServer{
@@ -390,7 +390,7 @@ func TestEngine_Sync(t *testing.T) {
         WgAddr: "100.64.0.1/24",
         WgPrivateKey: key,
         WgPort: 33100,
-    }, peer.NewRecorder())
+    }, peer.NewRecorder("https://mgm"))
 
     engine.dnsServer = &dns.MockServer{
         UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -548,7 +548,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
         WgAddr: wgAddr,
         WgPrivateKey: key,
         WgPort: 33100,
-    }, peer.NewRecorder())
+    }, peer.NewRecorder("https://mgm"))
     engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
     assert.NoError(t, err, "shouldn't return error")
     input := struct {
@@ -713,7 +713,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
         WgAddr: wgAddr,
         WgPrivateKey: key,
         WgPort: 33100,
-    }, peer.NewRecorder())
+    }, peer.NewRecorder("https://mgm"))
     engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
     assert.NoError(t, err, "shouldn't return error")
 
@@ -978,7 +978,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
         WgPort: wgPort,
     }
 
-    return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, peer.NewRecorder()), nil
+    return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, peer.NewRecorder("https://mgm")), nil
 }
 
 func startSignal() (*grpc.Server, string, error) {
@@ -49,7 +49,7 @@ func TestConn_GetKey(t *testing.T) {
 
 func TestConn_OnRemoteOffer(t *testing.T) {
 
-    conn, err := NewConn(connConf, NewRecorder())
+    conn, err := NewConn(connConf, NewRecorder("https://mgm"))
     if err != nil {
         return
     }
@@ -83,7 +83,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
 
 func TestConn_OnRemoteAnswer(t *testing.T) {
 
-    conn, err := NewConn(connConf, NewRecorder())
+    conn, err := NewConn(connConf, NewRecorder("https://mgm"))
     if err != nil {
         return
     }
@@ -116,7 +116,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
 }
 func TestConn_Status(t *testing.T) {
 
-    conn, err := NewConn(connConf, NewRecorder())
+    conn, err := NewConn(connConf, NewRecorder("https://mgm"))
     if err != nil {
         return
     }
@@ -143,7 +143,7 @@ func TestConn_Status(t *testing.T) {
 
 func TestConn_Close(t *testing.T) {
 
-    conn, err := NewConn(connConf, NewRecorder())
+    conn, err := NewConn(connConf, NewRecorder("https://mgm"))
     if err != nil {
         return
     }
@@ -49,21 +49,24 @@ type FullStatus struct {
 
 // Status holds a state of peers, signal and management connections
 type Status struct {
-    mux          sync.Mutex
-    peers        map[string]State
-    changeNotify map[string]chan struct{}
-    signal       SignalState
-    management   ManagementState
-    localPeer    LocalPeerState
-    offlinePeers []State
+    mux             sync.Mutex
+    peers           map[string]State
+    changeNotify    map[string]chan struct{}
+    signalState     bool
+    managementState bool
+    localPeer       LocalPeerState
+    offlinePeers    []State
+    mgmAddress      string
+    signalAddress   string
 }
 
 // NewRecorder returns a new Status instance
-func NewRecorder() *Status {
+func NewRecorder(mgmAddress string) *Status {
     return &Status{
         peers:        make(map[string]State),
         changeNotify: make(map[string]chan struct{}),
         offlinePeers: make([]State, 0),
+        mgmAddress:   mgmAddress,
     }
 }
 
@@ -193,43 +196,45 @@ func (d *Status) CleanLocalPeerState() {
 }
 
 // MarkManagementDisconnected sets ManagementState to disconnected
-func (d *Status) MarkManagementDisconnected(managementURL string) {
+func (d *Status) MarkManagementDisconnected() {
     d.mux.Lock()
     defer d.mux.Unlock()
-    d.management = ManagementState{
-        URL:       managementURL,
-        Connected: false,
-    }
+    d.managementState = false
 }
 
 // MarkManagementConnected sets ManagementState to connected
-func (d *Status) MarkManagementConnected(managementURL string) {
+func (d *Status) MarkManagementConnected() {
     d.mux.Lock()
     defer d.mux.Unlock()
-    d.management = ManagementState{
-        URL:       managementURL,
-        Connected: true,
-    }
+    d.managementState = true
 }
 
+// UpdateSignalAddress update the address of the signal server
+func (d *Status) UpdateSignalAddress(signalURL string) {
+    d.mux.Lock()
+    defer d.mux.Unlock()
+    d.signalAddress = signalURL
+}
+
+// UpdateManagementAddress update the address of the management server
+func (d *Status) UpdateManagementAddress(mgmAddress string) {
+    d.mux.Lock()
+    defer d.mux.Unlock()
+    d.mgmAddress = mgmAddress
+}
+
 // MarkSignalDisconnected sets SignalState to disconnected
-func (d *Status) MarkSignalDisconnected(signalURL string) {
+func (d *Status) MarkSignalDisconnected() {
     d.mux.Lock()
     defer d.mux.Unlock()
-    d.signal = SignalState{
-        signalURL,
-        false,
-    }
+    d.signalState = false
 }
 
 // MarkSignalConnected sets SignalState to connected
-func (d *Status) MarkSignalConnected(signalURL string) {
+func (d *Status) MarkSignalConnected() {
     d.mux.Lock()
     defer d.mux.Unlock()
-    d.signal = SignalState{
-        signalURL,
-        true,
-    }
+    d.signalState = true
 }
 
 // GetFullStatus gets full status
@@ -238,9 +243,15 @@ func (d *Status) GetFullStatus() FullStatus {
     defer d.mux.Unlock()
 
     fullStatus := FullStatus{
-        ManagementState: d.management,
-        SignalState:     d.signal,
-        LocalPeerState:  d.localPeer,
+        ManagementState: ManagementState{
+            d.mgmAddress,
+            d.managementState,
+        },
+        SignalState: SignalState{
+            d.signalAddress,
+            d.signalState,
+        },
+        LocalPeerState: d.localPeer,
     }
 
     for _, status := range d.peers {
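With this change the recorder keeps the management and signal addresses plus two plain booleans instead of full ManagementState/SignalState structs, and GetFullStatus assembles those structs on demand. A rough, self-contained sketch of that assembly, with simplified stand-in types and hypothetical addresses rather than the actual peer package:

package main

import (
	"fmt"
	"sync"
)

// ManagementState and SignalState are simplified stand-ins for the peer package types.
type ManagementState struct {
	URL       string
	Connected bool
}

type SignalState struct {
	URL       string
	Connected bool
}

type FullStatus struct {
	ManagementState ManagementState
	SignalState     SignalState
}

// Status keeps only addresses and booleans, as in the diff above.
type Status struct {
	mux             sync.Mutex
	mgmAddress      string
	signalAddress   string
	managementState bool
	signalState     bool
}

// GetFullStatus rebuilds the state structs from the flattened fields.
func (d *Status) GetFullStatus() FullStatus {
	d.mux.Lock()
	defer d.mux.Unlock()
	return FullStatus{
		ManagementState: ManagementState{URL: d.mgmAddress, Connected: d.managementState},
		SignalState:     SignalState{URL: d.signalAddress, Connected: d.signalState},
	}
}

func main() {
	s := &Status{mgmAddress: "https://mgm.example.com:443"} // hypothetical addresses
	s.signalAddress = "https://signal.example.com:443"
	s.managementState = true

	full := s.GetFullStatus()
	fmt.Printf("management %s connected=%v, signal %s connected=%v\n",
		full.ManagementState.URL, full.ManagementState.Connected,
		full.SignalState.URL, full.SignalState.Connected)
}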
@@ -8,7 +8,7 @@ import (
 
 func TestAddPeer(t *testing.T) {
     key := "abc"
-    status := NewRecorder()
+    status := NewRecorder("https://mgm")
     err := status.AddPeer(key)
     assert.NoError(t, err, "shouldn't return error")
 
@@ -22,7 +22,7 @@ func TestAddPeer(t *testing.T) {
 
 func TestGetPeer(t *testing.T) {
     key := "abc"
-    status := NewRecorder()
+    status := NewRecorder("https://mgm")
     err := status.AddPeer(key)
     assert.NoError(t, err, "shouldn't return error")
 
@@ -38,7 +38,7 @@ func TestGetPeer(t *testing.T) {
 func TestUpdatePeerState(t *testing.T) {
     key := "abc"
     ip := "10.10.10.10"
-    status := NewRecorder()
+    status := NewRecorder("https://mgm")
     peerState := State{
         PubKey: key,
     }
@@ -58,7 +58,7 @@ func TestUpdatePeerState(t *testing.T) {
 func TestStatus_UpdatePeerFQDN(t *testing.T) {
     key := "abc"
     fqdn := "peer-a.netbird.local"
-    status := NewRecorder()
+    status := NewRecorder("https://mgm")
     peerState := State{
         PubKey: key,
     }
@@ -76,7 +76,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) {
 func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
     key := "abc"
     ip := "10.10.10.10"
-    status := NewRecorder()
+    status := NewRecorder("https://mgm")
     peerState := State{
         PubKey: key,
     }
@@ -100,7 +100,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
 
 func TestRemovePeer(t *testing.T) {
     key := "abc"
-    status := NewRecorder()
+    status := NewRecorder("https://mgm")
     peerState := State{
         PubKey: key,
     }
@@ -123,7 +123,7 @@ func TestUpdateLocalPeerState(t *testing.T) {
         PubKey: "abc",
         KernelInterface: false,
     }
-    status := NewRecorder()
+    status := NewRecorder("https://mgm")
 
     status.UpdateLocalPeerState(localPeerState)
 
@@ -137,7 +137,7 @@ func TestCleanLocalPeerState(t *testing.T) {
         PubKey: "abc",
         KernelInterface: false,
     }
-    status := NewRecorder()
+    status := NewRecorder("https://mgm")
 
     status.localPeer = localPeerState
 
@@ -151,29 +151,23 @@ func TestUpdateSignalState(t *testing.T) {
     var tests = []struct {
         name      string
         connected bool
-        want      SignalState
+        want      bool
     }{
-        {"should mark as connected", true, SignalState{
-
-            URL:       url,
-            Connected: true,
-        }},
-        {"should mark as disconnected", false, SignalState{
-            URL:       url,
-            Connected: false,
-        }},
+        {"should mark as connected", true, true},
+        {"should mark as disconnected", false, false},
     }
 
-    status := NewRecorder()
+    status := NewRecorder("https://mgm")
+    status.UpdateSignalAddress(url)
 
     for _, test := range tests {
         t.Run(test.name, func(t *testing.T) {
             if test.connected {
-                status.MarkSignalConnected(url)
+                status.MarkSignalConnected()
             } else {
-                status.MarkSignalDisconnected(url)
+                status.MarkSignalDisconnected()
             }
-            assert.Equal(t, test.want, status.signal, "signal status should be equal")
+            assert.Equal(t, test.want, status.signalState, "signal status should be equal")
         })
     }
 }
@@ -183,29 +177,22 @@ func TestUpdateManagementState(t *testing.T) {
     var tests = []struct {
         name      string
         connected bool
-        want      ManagementState
+        want      bool
     }{
-        {"should mark as connected", true, ManagementState{
-
-            URL:       url,
-            Connected: true,
-        }},
-        {"should mark as disconnected", false, ManagementState{
-            URL:       url,
-            Connected: false,
-        }},
+        {"should mark as connected", true, true},
+        {"should mark as disconnected", false, false},
     }
 
-    status := NewRecorder()
+    status := NewRecorder(url)
 
     for _, test := range tests {
         t.Run(test.name, func(t *testing.T) {
             if test.connected {
-                status.MarkManagementConnected(url)
+                status.MarkManagementConnected()
             } else {
-                status.MarkManagementDisconnected(url)
+                status.MarkManagementDisconnected()
             }
-            assert.Equal(t, test.want, status.management, "signal status should be equal")
+            assert.Equal(t, test.want, status.managementState, "signalState status should be equal")
         })
     }
 }
@@ -213,12 +200,13 @@
 func TestGetFullStatus(t *testing.T) {
     key1 := "abc"
     key2 := "def"
+    signalAddr := "https://signal"
     managementState := ManagementState{
-        URL:       "https://signal",
+        URL:       "https://mgm",
         Connected: true,
     }
     signalState := SignalState{
-        URL:       "https://signal",
+        URL:       signalAddr,
         Connected: true,
     }
     peerState1 := State{
@@ -229,10 +217,11 @@ func TestGetFullStatus(t *testing.T) {
         PubKey: key2,
     }
 
-    status := NewRecorder()
+    status := NewRecorder("https://mgm")
+    status.UpdateSignalAddress(signalAddr)
 
-    status.management = managementState
-    status.signal = signalState
+    status.managementState = managementState.Connected
+    status.signalState = signalState.Connected
     status.peers[key1] = peerState1
     status.peers[key2] = peerState2
 
@@ -398,7 +398,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
     err = wgInterface.Create()
     require.NoError(t, err, "should create testing wireguard interface")
 
-    statusRecorder := peer.NewRecorder()
+    statusRecorder := peer.NewRecorder("https://mgm")
     ctx := context.TODO()
     routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder)
     defer routeManager.Stop()
@@ -96,7 +96,9 @@ func (s *Server) Start() error {
     s.config = config
 
     if s.statusRecorder == nil {
-        s.statusRecorder = peer.NewRecorder()
+        s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
+    } else {
+        s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
     }
 
     go func() {
@@ -386,7 +388,9 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
     }
 
     if s.statusRecorder == nil {
-        s.statusRecorder = peer.NewRecorder()
+        s.statusRecorder = peer.NewRecorder(s.config.ManagementURL.String())
+    } else {
+        s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
     }
 
     go func() {
@@ -430,7 +434,9 @@ func (s *Server) Status(
     statusResponse := proto.StatusResponse{Status: string(status), DaemonVersion: version.NetbirdVersion()}
 
     if s.statusRecorder == nil {
-        s.statusRecorder = peer.NewRecorder()
+        s.statusRecorder = peer.NewRecorder(s.config.ManagementURL.String())
+    } else {
+        s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
     }
 
     if msg.GetFullPeerStatus {
@@ -4,15 +4,13 @@ import (
     "context"
     "crypto/tls"
     "fmt"
-    "google.golang.org/grpc/codes"
-    gstatus "google.golang.org/grpc/status"
     "io"
+    "sync"
     "time"
 
-    "github.com/cenkalti/backoff/v4"
-    "github.com/netbirdio/netbird/client/system"
-    "github.com/netbirdio/netbird/encryption"
-    "github.com/netbirdio/netbird/management/proto"
+    "google.golang.org/grpc/codes"
+    gstatus "google.golang.org/grpc/status"
+
     log "github.com/sirupsen/logrus"
     "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
     "google.golang.org/grpc"
@@ -20,13 +18,26 @@ import (
     "google.golang.org/grpc/credentials"
     "google.golang.org/grpc/credentials/insecure"
     "google.golang.org/grpc/keepalive"
+
+    "github.com/cenkalti/backoff/v4"
+    "github.com/netbirdio/netbird/client/system"
+    "github.com/netbirdio/netbird/encryption"
+    "github.com/netbirdio/netbird/management/proto"
 )
 
+// ConnStateNotifier is a wrapper interface of the status recorders
+type ConnStateNotifier interface {
+    MarkManagementDisconnected()
+    MarkManagementConnected()
+}
+
 type GrpcClient struct {
-    key        wgtypes.Key
-    realClient proto.ManagementServiceClient
-    ctx        context.Context
-    conn       *grpc.ClientConn
+    key                   wgtypes.Key
+    realClient            proto.ManagementServiceClient
+    ctx                   context.Context
+    conn                  *grpc.ClientConn
+    connStateCallback     ConnStateNotifier
+    connStateCallbackLock sync.RWMutex
 }
 
 // NewClient creates a new client to Management service
@@ -56,10 +67,11 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
     realClient := proto.NewManagementServiceClient(conn)
 
     return &GrpcClient{
-        key:        ourPrivateKey,
-        realClient: realClient,
-        ctx:        ctx,
-        conn:       conn,
+        key:                   ourPrivateKey,
+        realClient:            realClient,
+        ctx:                   ctx,
+        conn:                  conn,
+        connStateCallbackLock: sync.RWMutex{},
     }, nil
 }
 
@@ -68,6 +80,13 @@ func (c *GrpcClient) Close() error {
     return c.conn.Close()
 }
 
+// SetConnStateListener set the ConnStateNotifier
+func (c *GrpcClient) SetConnStateListener(notifier ConnStateNotifier) {
+    c.connStateCallbackLock.Lock()
+    defer c.connStateCallbackLock.Unlock()
+    c.connStateCallback = notifier
+}
+
 // defaultBackoff is a basic backoff mechanism for general issues
 func defaultBackoff(ctx context.Context) backoff.BackOff {
     return backoff.WithContext(&backoff.ExponentialBackOff{
@@ -121,7 +140,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
     }
 
     log.Infof("connected to the Management Service stream")
-
+    c.notifyConnected()
     // blocking until error
     err = c.receiveEvents(stream, *serverPubKey, msgHandler)
     if err != nil {
@@ -131,6 +150,7 @@
         // we need this reset because after a successful connection and a consequent error, backoff lib doesn't
         // reset times and next try will start with a long delay
         backOff.Reset()
+        c.notifyDisconnected()
         log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
         return err
     }
@@ -298,6 +318,26 @@ func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.D
     return flowInfoResp, nil
 }
 
+func (c *GrpcClient) notifyDisconnected() {
+    c.connStateCallbackLock.RLock()
+    defer c.connStateCallbackLock.RUnlock()
+
+    if c.connStateCallback == nil {
+        return
+    }
+    c.connStateCallback.MarkManagementDisconnected()
+}
+
+func (c *GrpcClient) notifyConnected() {
+    c.connStateCallbackLock.RLock()
+    defer c.connStateCallbackLock.RUnlock()
+
+    if c.connStateCallback == nil {
+        return
+    }
+    c.connStateCallback.MarkManagementConnected()
+}
+
 func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
     if info == nil {
         return nil
@@ -22,6 +22,12 @@ import (
     "time"
 )
 
+// ConnStateNotifier is a wrapper interface of the status recorder
+type ConnStateNotifier interface {
+    MarkSignalDisconnected()
+    MarkSignalConnected()
+}
+
 // GrpcClient Wraps the Signal Exchange Service gRpc client
 type GrpcClient struct {
     key wgtypes.Key
@@ -34,6 +40,9 @@ type GrpcClient struct {
     mux sync.Mutex
     // StreamConnected indicates whether this client is StreamConnected to the Signal stream
     status Status
+
+    connStateCallback     ConnStateNotifier
+    connStateCallbackLock sync.RWMutex
 }
 
 func (c *GrpcClient) StreamConnected() bool {
@@ -78,15 +87,23 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
     log.Debugf("connected to Signal Service: %v", conn.Target())
 
     return &GrpcClient{
-        realClient: proto.NewSignalExchangeClient(conn),
-        ctx:        ctx,
-        signalConn: conn,
-        key:        key,
-        mux:        sync.Mutex{},
-        status:     StreamDisconnected,
+        realClient:            proto.NewSignalExchangeClient(conn),
+        ctx:                   ctx,
+        signalConn:            conn,
+        key:                   key,
+        mux:                   sync.Mutex{},
+        status:                StreamDisconnected,
+        connStateCallbackLock: sync.RWMutex{},
     }, nil
 }
 
+// SetConnStateListener set the ConnStateNotifier
+func (c *GrpcClient) SetConnStateListener(notifier ConnStateNotifier) {
+    c.connStateCallbackLock.Lock()
+    defer c.connStateCallbackLock.Unlock()
+    c.connStateCallback = notifier
+}
+
 // defaultBackoff is a basic backoff mechanism for general issues
 func defaultBackoff(ctx context.Context) backoff.BackOff {
     return backoff.WithContext(&backoff.ExponentialBackOff{
@@ -134,13 +151,14 @@ func (c *GrpcClient) Receive(msgHandler func(msg *proto.Message) error) error {
         c.notifyStreamConnected()
 
         log.Infof("connected to the Signal Service stream")
-
+        c.notifyConnected()
         // start receiving messages from the Signal stream (from other peers through signal)
         err = c.receive(stream, msgHandler)
         if err != nil {
             // we need this reset because after a successful connection and a consequent error, backoff lib doesn't
             // reset times and next try will start with a long delay
             backOff.Reset()
+            c.notifyDisconnected()
             log.Warnf("disconnected from the Signal service but will retry silently. Reason: %v", err)
             return err
         }
@@ -341,3 +359,23 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
         }
     }
 }
+
+func (c *GrpcClient) notifyDisconnected() {
+    c.connStateCallbackLock.RLock()
+    defer c.connStateCallbackLock.RUnlock()
+
+    if c.connStateCallback == nil {
+        return
+    }
+    c.connStateCallback.MarkSignalDisconnected()
+}
+
+func (c *GrpcClient) notifyConnected() {
+    c.connStateCallbackLock.RLock()
+    defer c.connStateCallbackLock.RUnlock()
+
+    if c.connStateCallback == nil {
+        return
+    }
+    c.connStateCallback.MarkSignalConnected()
+}