Fix connstate indication (#732)

Fix the status indication in the client service. The status of the
management server and the signal server was incorrect if the network
connection was broken. Basically the status update was not used by
the management and signal library.
This commit is contained in:
Zoltan Papp 2023-03-16 17:22:36 +01:00 committed by GitHub
parent 731d3ae464
commit 747797271e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 219 additions and 116 deletions

View File

@ -94,7 +94,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel)
return internal.RunClient(ctx, config, peer.NewRecorder())
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
}
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {

View File

@ -58,8 +58,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status)
return err
}
managementURL := config.ManagementURL.String()
statusRecorder.MarkManagementDisconnected(managementURL)
statusRecorder.MarkManagementDisconnected()
operation := func() error {
// if context cancelled we not start new backoff cycle
@ -73,13 +72,16 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status)
engineCtx, cancel := context.WithCancel(ctx)
defer func() {
statusRecorder.MarkManagementDisconnected(managementURL)
statusRecorder.MarkManagementDisconnected()
statusRecorder.CleanLocalPeerState()
cancel()
}()
log.Debugf("conecting to the Management service %s", config.ManagementURL.Host)
mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
mgmNotifier := statusRecorderToMgmConnStateNotifier(statusRecorder)
mgmClient.SetConnStateListener(mgmNotifier)
if err != nil {
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
}
@ -101,7 +103,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status)
}
return wrapErr(err)
}
statusRecorder.MarkManagementConnected(managementURL)
statusRecorder.MarkManagementConnected()
localPeerState := peer.LocalPeerState{
IP: loginResp.GetPeerConfig().GetAddress(),
@ -117,8 +119,10 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status)
loginResp.GetWiretrusteeConfig().GetSignal().GetUri(),
)
statusRecorder.MarkSignalDisconnected(signalURL)
defer statusRecorder.MarkSignalDisconnected(signalURL)
statusRecorder.UpdateSignalAddress(signalURL)
statusRecorder.MarkSignalDisconnected()
defer statusRecorder.MarkSignalDisconnected()
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
signalClient, err := connectToSignal(engineCtx, loginResp.GetWiretrusteeConfig(), myPrivateKey)
@ -133,7 +137,10 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status)
}
}()
statusRecorder.MarkSignalConnected(signalURL)
signalNotifier := statusRecorderToSignalConnStateNotifier(statusRecorder)
signalClient.SetConnStateListener(signalNotifier)
statusRecorder.MarkSignalConnected()
peerConfig := loginResp.GetPeerConfig()
@ -320,3 +327,15 @@ func UpdateOldManagementPort(ctx context.Context, config *Config, configPath str
return config, nil
}
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {
var sri interface{} = statusRecorder
mgmNotifier, _ := sri.(mgm.ConnStateNotifier)
return mgmNotifier
}
func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal.ConnStateNotifier {
var sri interface{} = statusRecorder
notifier, _ := sri.(signal.ConnStateNotifier)
return notifier
}

View File

@ -72,7 +72,7 @@ func TestEngine_SSH(t *testing.T) {
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
}, peer.NewRecorder())
}, peer.NewRecorder("https://mgm"))
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@ -206,7 +206,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
}, peer.NewRecorder())
}, peer.NewRecorder("https://mgm"))
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU)
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder)
engine.dnsServer = &dns.MockServer{
@ -390,7 +390,7 @@ func TestEngine_Sync(t *testing.T) {
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
}, peer.NewRecorder())
}, peer.NewRecorder("https://mgm"))
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@ -548,7 +548,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgAddr: wgAddr,
WgPrivateKey: key,
WgPort: 33100,
}, peer.NewRecorder())
}, peer.NewRecorder("https://mgm"))
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
assert.NoError(t, err, "shouldn't return error")
input := struct {
@ -713,7 +713,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgAddr: wgAddr,
WgPrivateKey: key,
WgPort: 33100,
}, peer.NewRecorder())
}, peer.NewRecorder("https://mgm"))
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
assert.NoError(t, err, "shouldn't return error")
@ -978,7 +978,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
WgPort: wgPort,
}
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, peer.NewRecorder()), nil
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, peer.NewRecorder("https://mgm")), nil
}
func startSignal() (*grpc.Server, string, error) {

View File

@ -49,7 +49,7 @@ func TestConn_GetKey(t *testing.T) {
func TestConn_OnRemoteOffer(t *testing.T) {
conn, err := NewConn(connConf, NewRecorder())
conn, err := NewConn(connConf, NewRecorder("https://mgm"))
if err != nil {
return
}
@ -83,7 +83,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
func TestConn_OnRemoteAnswer(t *testing.T) {
conn, err := NewConn(connConf, NewRecorder())
conn, err := NewConn(connConf, NewRecorder("https://mgm"))
if err != nil {
return
}
@ -116,7 +116,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
}
func TestConn_Status(t *testing.T) {
conn, err := NewConn(connConf, NewRecorder())
conn, err := NewConn(connConf, NewRecorder("https://mgm"))
if err != nil {
return
}
@ -143,7 +143,7 @@ func TestConn_Status(t *testing.T) {
func TestConn_Close(t *testing.T) {
conn, err := NewConn(connConf, NewRecorder())
conn, err := NewConn(connConf, NewRecorder("https://mgm"))
if err != nil {
return
}

View File

@ -52,18 +52,21 @@ type Status struct {
mux sync.Mutex
peers map[string]State
changeNotify map[string]chan struct{}
signal SignalState
management ManagementState
signalState bool
managementState bool
localPeer LocalPeerState
offlinePeers []State
mgmAddress string
signalAddress string
}
// NewRecorder returns a new Status instance
func NewRecorder() *Status {
func NewRecorder(mgmAddress string) *Status {
return &Status{
peers: make(map[string]State),
changeNotify: make(map[string]chan struct{}),
offlinePeers: make([]State, 0),
mgmAddress: mgmAddress,
}
}
@ -193,43 +196,45 @@ func (d *Status) CleanLocalPeerState() {
}
// MarkManagementDisconnected sets ManagementState to disconnected
func (d *Status) MarkManagementDisconnected(managementURL string) {
func (d *Status) MarkManagementDisconnected() {
d.mux.Lock()
defer d.mux.Unlock()
d.management = ManagementState{
URL: managementURL,
Connected: false,
}
d.managementState = false
}
// MarkManagementConnected sets ManagementState to connected
func (d *Status) MarkManagementConnected(managementURL string) {
func (d *Status) MarkManagementConnected() {
d.mux.Lock()
defer d.mux.Unlock()
d.management = ManagementState{
URL: managementURL,
Connected: true,
d.managementState = true
}
// UpdateSignalAddress update the address of the signal server
func (d *Status) UpdateSignalAddress(signalURL string) {
d.mux.Lock()
defer d.mux.Unlock()
d.signalAddress = signalURL
}
// UpdateManagementAddress update the address of the management server
func (d *Status) UpdateManagementAddress(mgmAddress string) {
d.mux.Lock()
defer d.mux.Unlock()
d.mgmAddress = mgmAddress
}
// MarkSignalDisconnected sets SignalState to disconnected
func (d *Status) MarkSignalDisconnected(signalURL string) {
func (d *Status) MarkSignalDisconnected() {
d.mux.Lock()
defer d.mux.Unlock()
d.signal = SignalState{
signalURL,
false,
}
d.signalState = false
}
// MarkSignalConnected sets SignalState to connected
func (d *Status) MarkSignalConnected(signalURL string) {
func (d *Status) MarkSignalConnected() {
d.mux.Lock()
defer d.mux.Unlock()
d.signal = SignalState{
signalURL,
true,
}
d.signalState = true
}
// GetFullStatus gets full status
@ -238,8 +243,14 @@ func (d *Status) GetFullStatus() FullStatus {
defer d.mux.Unlock()
fullStatus := FullStatus{
ManagementState: d.management,
SignalState: d.signal,
ManagementState: ManagementState{
d.mgmAddress,
d.managementState,
},
SignalState: SignalState{
d.signalAddress,
d.signalState,
},
LocalPeerState: d.localPeer,
}

View File

@ -8,7 +8,7 @@ import (
func TestAddPeer(t *testing.T) {
key := "abc"
status := NewRecorder()
status := NewRecorder("https://mgm")
err := status.AddPeer(key)
assert.NoError(t, err, "shouldn't return error")
@ -22,7 +22,7 @@ func TestAddPeer(t *testing.T) {
func TestGetPeer(t *testing.T) {
key := "abc"
status := NewRecorder()
status := NewRecorder("https://mgm")
err := status.AddPeer(key)
assert.NoError(t, err, "shouldn't return error")
@ -38,7 +38,7 @@ func TestGetPeer(t *testing.T) {
func TestUpdatePeerState(t *testing.T) {
key := "abc"
ip := "10.10.10.10"
status := NewRecorder()
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
}
@ -58,7 +58,7 @@ func TestUpdatePeerState(t *testing.T) {
func TestStatus_UpdatePeerFQDN(t *testing.T) {
key := "abc"
fqdn := "peer-a.netbird.local"
status := NewRecorder()
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
}
@ -76,7 +76,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) {
func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
key := "abc"
ip := "10.10.10.10"
status := NewRecorder()
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
}
@ -100,7 +100,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
func TestRemovePeer(t *testing.T) {
key := "abc"
status := NewRecorder()
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
}
@ -123,7 +123,7 @@ func TestUpdateLocalPeerState(t *testing.T) {
PubKey: "abc",
KernelInterface: false,
}
status := NewRecorder()
status := NewRecorder("https://mgm")
status.UpdateLocalPeerState(localPeerState)
@ -137,7 +137,7 @@ func TestCleanLocalPeerState(t *testing.T) {
PubKey: "abc",
KernelInterface: false,
}
status := NewRecorder()
status := NewRecorder("https://mgm")
status.localPeer = localPeerState
@ -151,29 +151,23 @@ func TestUpdateSignalState(t *testing.T) {
var tests = []struct {
name string
connected bool
want SignalState
want bool
}{
{"should mark as connected", true, SignalState{
URL: url,
Connected: true,
}},
{"should mark as disconnected", false, SignalState{
URL: url,
Connected: false,
}},
{"should mark as connected", true, true},
{"should mark as disconnected", false, false},
}
status := NewRecorder()
status := NewRecorder("https://mgm")
status.UpdateSignalAddress(url)
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if test.connected {
status.MarkSignalConnected(url)
status.MarkSignalConnected()
} else {
status.MarkSignalDisconnected(url)
status.MarkSignalDisconnected()
}
assert.Equal(t, test.want, status.signal, "signal status should be equal")
assert.Equal(t, test.want, status.signalState, "signal status should be equal")
})
}
}
@ -183,29 +177,22 @@ func TestUpdateManagementState(t *testing.T) {
var tests = []struct {
name string
connected bool
want ManagementState
want bool
}{
{"should mark as connected", true, ManagementState{
URL: url,
Connected: true,
}},
{"should mark as disconnected", false, ManagementState{
URL: url,
Connected: false,
}},
{"should mark as connected", true, true},
{"should mark as disconnected", false, false},
}
status := NewRecorder()
status := NewRecorder(url)
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if test.connected {
status.MarkManagementConnected(url)
status.MarkManagementConnected()
} else {
status.MarkManagementDisconnected(url)
status.MarkManagementDisconnected()
}
assert.Equal(t, test.want, status.management, "signal status should be equal")
assert.Equal(t, test.want, status.managementState, "signalState status should be equal")
})
}
}
@ -213,12 +200,13 @@ func TestUpdateManagementState(t *testing.T) {
func TestGetFullStatus(t *testing.T) {
key1 := "abc"
key2 := "def"
signalAddr := "https://signal"
managementState := ManagementState{
URL: "https://signal",
URL: "https://mgm",
Connected: true,
}
signalState := SignalState{
URL: "https://signal",
URL: signalAddr,
Connected: true,
}
peerState1 := State{
@ -229,10 +217,11 @@ func TestGetFullStatus(t *testing.T) {
PubKey: key2,
}
status := NewRecorder()
status := NewRecorder("https://mgm")
status.UpdateSignalAddress(signalAddr)
status.management = managementState
status.signal = signalState
status.managementState = managementState.Connected
status.signalState = signalState.Connected
status.peers[key1] = peerState1
status.peers[key2] = peerState2

View File

@ -398,7 +398,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
statusRecorder := peer.NewRecorder()
statusRecorder := peer.NewRecorder("https://mgm")
ctx := context.TODO()
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder)
defer routeManager.Stop()

View File

@ -96,7 +96,9 @@ func (s *Server) Start() error {
s.config = config
if s.statusRecorder == nil {
s.statusRecorder = peer.NewRecorder()
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
} else {
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
}
go func() {
@ -386,7 +388,9 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
}
if s.statusRecorder == nil {
s.statusRecorder = peer.NewRecorder()
s.statusRecorder = peer.NewRecorder(s.config.ManagementURL.String())
} else {
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
}
go func() {
@ -430,7 +434,9 @@ func (s *Server) Status(
statusResponse := proto.StatusResponse{Status: string(status), DaemonVersion: version.NetbirdVersion()}
if s.statusRecorder == nil {
s.statusRecorder = peer.NewRecorder()
s.statusRecorder = peer.NewRecorder(s.config.ManagementURL.String())
} else {
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
}
if msg.GetFullPeerStatus {

View File

@ -4,15 +4,13 @@ import (
"context"
"crypto/tls"
"fmt"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"io"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
@ -20,13 +18,26 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"github.com/cenkalti/backoff/v4"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto"
)
// ConnStateNotifier is a wrapper interface of the status recorders
type ConnStateNotifier interface {
MarkManagementDisconnected()
MarkManagementConnected()
}
type GrpcClient struct {
key wgtypes.Key
realClient proto.ManagementServiceClient
ctx context.Context
conn *grpc.ClientConn
connStateCallback ConnStateNotifier
connStateCallbackLock sync.RWMutex
}
// NewClient creates a new client to Management service
@ -60,6 +71,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
realClient: realClient,
ctx: ctx,
conn: conn,
connStateCallbackLock: sync.RWMutex{},
}, nil
}
@ -68,6 +80,13 @@ func (c *GrpcClient) Close() error {
return c.conn.Close()
}
// SetConnStateListener set the ConnStateNotifier
func (c *GrpcClient) SetConnStateListener(notifier ConnStateNotifier) {
c.connStateCallbackLock.Lock()
defer c.connStateCallbackLock.Unlock()
c.connStateCallback = notifier
}
// defaultBackoff is a basic backoff mechanism for general issues
func defaultBackoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(&backoff.ExponentialBackOff{
@ -121,7 +140,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
}
log.Infof("connected to the Management Service stream")
c.notifyConnected()
// blocking until error
err = c.receiveEvents(stream, *serverPubKey, msgHandler)
if err != nil {
@ -131,6 +150,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
// we need this reset because after a successful connection and a consequent error, backoff lib doesn't
// reset times and next try will start with a long delay
backOff.Reset()
c.notifyDisconnected()
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
return err
}
@ -298,6 +318,26 @@ func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.D
return flowInfoResp, nil
}
func (c *GrpcClient) notifyDisconnected() {
c.connStateCallbackLock.RLock()
defer c.connStateCallbackLock.RUnlock()
if c.connStateCallback == nil {
return
}
c.connStateCallback.MarkManagementDisconnected()
}
func (c *GrpcClient) notifyConnected() {
c.connStateCallbackLock.RLock()
defer c.connStateCallbackLock.RUnlock()
if c.connStateCallback == nil {
return
}
c.connStateCallback.MarkManagementConnected()
}
func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
if info == nil {
return nil

View File

@ -22,6 +22,12 @@ import (
"time"
)
// ConnStateNotifier is a wrapper interface of the status recorder
type ConnStateNotifier interface {
MarkSignalDisconnected()
MarkSignalConnected()
}
// GrpcClient Wraps the Signal Exchange Service gRpc client
type GrpcClient struct {
key wgtypes.Key
@ -34,6 +40,9 @@ type GrpcClient struct {
mux sync.Mutex
// StreamConnected indicates whether this client is StreamConnected to the Signal stream
status Status
connStateCallback ConnStateNotifier
connStateCallbackLock sync.RWMutex
}
func (c *GrpcClient) StreamConnected() bool {
@ -84,9 +93,17 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
key: key,
mux: sync.Mutex{},
status: StreamDisconnected,
connStateCallbackLock: sync.RWMutex{},
}, nil
}
// SetConnStateListener set the ConnStateNotifier
func (c *GrpcClient) SetConnStateListener(notifier ConnStateNotifier) {
c.connStateCallbackLock.Lock()
defer c.connStateCallbackLock.Unlock()
c.connStateCallback = notifier
}
// defaultBackoff is a basic backoff mechanism for general issues
func defaultBackoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(&backoff.ExponentialBackOff{
@ -134,13 +151,14 @@ func (c *GrpcClient) Receive(msgHandler func(msg *proto.Message) error) error {
c.notifyStreamConnected()
log.Infof("connected to the Signal Service stream")
c.notifyConnected()
// start receiving messages from the Signal stream (from other peers through signal)
err = c.receive(stream, msgHandler)
if err != nil {
// we need this reset because after a successful connection and a consequent error, backoff lib doesn't
// reset times and next try will start with a long delay
backOff.Reset()
c.notifyDisconnected()
log.Warnf("disconnected from the Signal service but will retry silently. Reason: %v", err)
return err
}
@ -341,3 +359,23 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
}
}
}
func (c *GrpcClient) notifyDisconnected() {
c.connStateCallbackLock.RLock()
defer c.connStateCallbackLock.RUnlock()
if c.connStateCallback == nil {
return
}
c.connStateCallback.MarkSignalDisconnected()
}
func (c *GrpcClient) notifyConnected() {
c.connStateCallbackLock.RLock()
defer c.connStateCallbackLock.RUnlock()
if c.connStateCallback == nil {
return
}
c.connStateCallback.MarkSignalConnected()
}