Fix connstate indication (#732)

Fix the status indication in the client service. The status of the
management server and the signal server was incorrect if the network
connection was broken. Basically the status update was not used by
the management and signal library.
This commit is contained in:
Zoltan Papp 2023-03-16 17:22:36 +01:00 committed by GitHub
parent 731d3ae464
commit 747797271e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 219 additions and 116 deletions

View File

@ -94,7 +94,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
var cancel context.CancelFunc var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx) ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel) SetupCloseHandler(ctx, cancel)
return internal.RunClient(ctx, config, peer.NewRecorder()) return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
} }
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {

View File

@ -58,8 +58,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status)
return err return err
} }
managementURL := config.ManagementURL.String() statusRecorder.MarkManagementDisconnected()
statusRecorder.MarkManagementDisconnected(managementURL)
operation := func() error { operation := func() error {
// if context cancelled we not start new backoff cycle // if context cancelled we not start new backoff cycle
@ -73,13 +72,16 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status)
engineCtx, cancel := context.WithCancel(ctx) engineCtx, cancel := context.WithCancel(ctx)
defer func() { defer func() {
statusRecorder.MarkManagementDisconnected(managementURL) statusRecorder.MarkManagementDisconnected()
statusRecorder.CleanLocalPeerState() statusRecorder.CleanLocalPeerState()
cancel() cancel()
}() }()
log.Debugf("conecting to the Management service %s", config.ManagementURL.Host) log.Debugf("conecting to the Management service %s", config.ManagementURL.Host)
mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled) mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
mgmNotifier := statusRecorderToMgmConnStateNotifier(statusRecorder)
mgmClient.SetConnStateListener(mgmNotifier)
if err != nil { if err != nil {
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err)) return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
} }
@ -101,7 +103,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status)
} }
return wrapErr(err) return wrapErr(err)
} }
statusRecorder.MarkManagementConnected(managementURL) statusRecorder.MarkManagementConnected()
localPeerState := peer.LocalPeerState{ localPeerState := peer.LocalPeerState{
IP: loginResp.GetPeerConfig().GetAddress(), IP: loginResp.GetPeerConfig().GetAddress(),
@ -117,8 +119,10 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status)
loginResp.GetWiretrusteeConfig().GetSignal().GetUri(), loginResp.GetWiretrusteeConfig().GetSignal().GetUri(),
) )
statusRecorder.MarkSignalDisconnected(signalURL) statusRecorder.UpdateSignalAddress(signalURL)
defer statusRecorder.MarkSignalDisconnected(signalURL)
statusRecorder.MarkSignalDisconnected()
defer statusRecorder.MarkSignalDisconnected()
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal // with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
signalClient, err := connectToSignal(engineCtx, loginResp.GetWiretrusteeConfig(), myPrivateKey) signalClient, err := connectToSignal(engineCtx, loginResp.GetWiretrusteeConfig(), myPrivateKey)
@ -133,7 +137,10 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status)
} }
}() }()
statusRecorder.MarkSignalConnected(signalURL) signalNotifier := statusRecorderToSignalConnStateNotifier(statusRecorder)
signalClient.SetConnStateListener(signalNotifier)
statusRecorder.MarkSignalConnected()
peerConfig := loginResp.GetPeerConfig() peerConfig := loginResp.GetPeerConfig()
@ -320,3 +327,15 @@ func UpdateOldManagementPort(ctx context.Context, config *Config, configPath str
return config, nil return config, nil
} }
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {
var sri interface{} = statusRecorder
mgmNotifier, _ := sri.(mgm.ConnStateNotifier)
return mgmNotifier
}
func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal.ConnStateNotifier {
var sri interface{} = statusRecorder
notifier, _ := sri.(signal.ConnStateNotifier)
return notifier
}

View File

@ -72,7 +72,7 @@ func TestEngine_SSH(t *testing.T) {
WgAddr: "100.64.0.1/24", WgAddr: "100.64.0.1/24",
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, peer.NewRecorder()) }, peer.NewRecorder("https://mgm"))
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@ -206,7 +206,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
WgAddr: "100.64.0.1/24", WgAddr: "100.64.0.1/24",
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, peer.NewRecorder()) }, peer.NewRecorder("https://mgm"))
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU) engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU)
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder) engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder)
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
@ -390,7 +390,7 @@ func TestEngine_Sync(t *testing.T) {
WgAddr: "100.64.0.1/24", WgAddr: "100.64.0.1/24",
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, peer.NewRecorder()) }, peer.NewRecorder("https://mgm"))
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@ -548,7 +548,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgAddr: wgAddr, WgAddr: wgAddr,
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, peer.NewRecorder()) }, peer.NewRecorder("https://mgm"))
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU) engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
input := struct { input := struct {
@ -713,7 +713,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgAddr: wgAddr, WgAddr: wgAddr,
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, peer.NewRecorder()) }, peer.NewRecorder("https://mgm"))
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU) engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
@ -978,7 +978,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
WgPort: wgPort, WgPort: wgPort,
} }
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, peer.NewRecorder()), nil return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, peer.NewRecorder("https://mgm")), nil
} }
func startSignal() (*grpc.Server, string, error) { func startSignal() (*grpc.Server, string, error) {

View File

@ -49,7 +49,7 @@ func TestConn_GetKey(t *testing.T) {
func TestConn_OnRemoteOffer(t *testing.T) { func TestConn_OnRemoteOffer(t *testing.T) {
conn, err := NewConn(connConf, NewRecorder()) conn, err := NewConn(connConf, NewRecorder("https://mgm"))
if err != nil { if err != nil {
return return
} }
@ -83,7 +83,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
func TestConn_OnRemoteAnswer(t *testing.T) { func TestConn_OnRemoteAnswer(t *testing.T) {
conn, err := NewConn(connConf, NewRecorder()) conn, err := NewConn(connConf, NewRecorder("https://mgm"))
if err != nil { if err != nil {
return return
} }
@ -116,7 +116,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
} }
func TestConn_Status(t *testing.T) { func TestConn_Status(t *testing.T) {
conn, err := NewConn(connConf, NewRecorder()) conn, err := NewConn(connConf, NewRecorder("https://mgm"))
if err != nil { if err != nil {
return return
} }
@ -143,7 +143,7 @@ func TestConn_Status(t *testing.T) {
func TestConn_Close(t *testing.T) { func TestConn_Close(t *testing.T) {
conn, err := NewConn(connConf, NewRecorder()) conn, err := NewConn(connConf, NewRecorder("https://mgm"))
if err != nil { if err != nil {
return return
} }

View File

@ -49,21 +49,24 @@ type FullStatus struct {
// Status holds a state of peers, signal and management connections // Status holds a state of peers, signal and management connections
type Status struct { type Status struct {
mux sync.Mutex mux sync.Mutex
peers map[string]State peers map[string]State
changeNotify map[string]chan struct{} changeNotify map[string]chan struct{}
signal SignalState signalState bool
management ManagementState managementState bool
localPeer LocalPeerState localPeer LocalPeerState
offlinePeers []State offlinePeers []State
mgmAddress string
signalAddress string
} }
// NewRecorder returns a new Status instance // NewRecorder returns a new Status instance
func NewRecorder() *Status { func NewRecorder(mgmAddress string) *Status {
return &Status{ return &Status{
peers: make(map[string]State), peers: make(map[string]State),
changeNotify: make(map[string]chan struct{}), changeNotify: make(map[string]chan struct{}),
offlinePeers: make([]State, 0), offlinePeers: make([]State, 0),
mgmAddress: mgmAddress,
} }
} }
@ -193,43 +196,45 @@ func (d *Status) CleanLocalPeerState() {
} }
// MarkManagementDisconnected sets ManagementState to disconnected // MarkManagementDisconnected sets ManagementState to disconnected
func (d *Status) MarkManagementDisconnected(managementURL string) { func (d *Status) MarkManagementDisconnected() {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock() defer d.mux.Unlock()
d.management = ManagementState{ d.managementState = false
URL: managementURL,
Connected: false,
}
} }
// MarkManagementConnected sets ManagementState to connected // MarkManagementConnected sets ManagementState to connected
func (d *Status) MarkManagementConnected(managementURL string) { func (d *Status) MarkManagementConnected() {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock() defer d.mux.Unlock()
d.management = ManagementState{ d.managementState = true
URL: managementURL, }
Connected: true,
} // UpdateSignalAddress update the address of the signal server
func (d *Status) UpdateSignalAddress(signalURL string) {
d.mux.Lock()
defer d.mux.Unlock()
d.signalAddress = signalURL
}
// UpdateManagementAddress update the address of the management server
func (d *Status) UpdateManagementAddress(mgmAddress string) {
d.mux.Lock()
defer d.mux.Unlock()
d.mgmAddress = mgmAddress
} }
// MarkSignalDisconnected sets SignalState to disconnected // MarkSignalDisconnected sets SignalState to disconnected
func (d *Status) MarkSignalDisconnected(signalURL string) { func (d *Status) MarkSignalDisconnected() {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock() defer d.mux.Unlock()
d.signal = SignalState{ d.signalState = false
signalURL,
false,
}
} }
// MarkSignalConnected sets SignalState to connected // MarkSignalConnected sets SignalState to connected
func (d *Status) MarkSignalConnected(signalURL string) { func (d *Status) MarkSignalConnected() {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock() defer d.mux.Unlock()
d.signal = SignalState{ d.signalState = true
signalURL,
true,
}
} }
// GetFullStatus gets full status // GetFullStatus gets full status
@ -238,9 +243,15 @@ func (d *Status) GetFullStatus() FullStatus {
defer d.mux.Unlock() defer d.mux.Unlock()
fullStatus := FullStatus{ fullStatus := FullStatus{
ManagementState: d.management, ManagementState: ManagementState{
SignalState: d.signal, d.mgmAddress,
LocalPeerState: d.localPeer, d.managementState,
},
SignalState: SignalState{
d.signalAddress,
d.signalState,
},
LocalPeerState: d.localPeer,
} }
for _, status := range d.peers { for _, status := range d.peers {

View File

@ -8,7 +8,7 @@ import (
func TestAddPeer(t *testing.T) { func TestAddPeer(t *testing.T) {
key := "abc" key := "abc"
status := NewRecorder() status := NewRecorder("https://mgm")
err := status.AddPeer(key) err := status.AddPeer(key)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
@ -22,7 +22,7 @@ func TestAddPeer(t *testing.T) {
func TestGetPeer(t *testing.T) { func TestGetPeer(t *testing.T) {
key := "abc" key := "abc"
status := NewRecorder() status := NewRecorder("https://mgm")
err := status.AddPeer(key) err := status.AddPeer(key)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
@ -38,7 +38,7 @@ func TestGetPeer(t *testing.T) {
func TestUpdatePeerState(t *testing.T) { func TestUpdatePeerState(t *testing.T) {
key := "abc" key := "abc"
ip := "10.10.10.10" ip := "10.10.10.10"
status := NewRecorder() status := NewRecorder("https://mgm")
peerState := State{ peerState := State{
PubKey: key, PubKey: key,
} }
@ -58,7 +58,7 @@ func TestUpdatePeerState(t *testing.T) {
func TestStatus_UpdatePeerFQDN(t *testing.T) { func TestStatus_UpdatePeerFQDN(t *testing.T) {
key := "abc" key := "abc"
fqdn := "peer-a.netbird.local" fqdn := "peer-a.netbird.local"
status := NewRecorder() status := NewRecorder("https://mgm")
peerState := State{ peerState := State{
PubKey: key, PubKey: key,
} }
@ -76,7 +76,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) {
func TestGetPeerStateChangeNotifierLogic(t *testing.T) { func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
key := "abc" key := "abc"
ip := "10.10.10.10" ip := "10.10.10.10"
status := NewRecorder() status := NewRecorder("https://mgm")
peerState := State{ peerState := State{
PubKey: key, PubKey: key,
} }
@ -100,7 +100,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
func TestRemovePeer(t *testing.T) { func TestRemovePeer(t *testing.T) {
key := "abc" key := "abc"
status := NewRecorder() status := NewRecorder("https://mgm")
peerState := State{ peerState := State{
PubKey: key, PubKey: key,
} }
@ -123,7 +123,7 @@ func TestUpdateLocalPeerState(t *testing.T) {
PubKey: "abc", PubKey: "abc",
KernelInterface: false, KernelInterface: false,
} }
status := NewRecorder() status := NewRecorder("https://mgm")
status.UpdateLocalPeerState(localPeerState) status.UpdateLocalPeerState(localPeerState)
@ -137,7 +137,7 @@ func TestCleanLocalPeerState(t *testing.T) {
PubKey: "abc", PubKey: "abc",
KernelInterface: false, KernelInterface: false,
} }
status := NewRecorder() status := NewRecorder("https://mgm")
status.localPeer = localPeerState status.localPeer = localPeerState
@ -151,29 +151,23 @@ func TestUpdateSignalState(t *testing.T) {
var tests = []struct { var tests = []struct {
name string name string
connected bool connected bool
want SignalState want bool
}{ }{
{"should mark as connected", true, SignalState{ {"should mark as connected", true, true},
{"should mark as disconnected", false, false},
URL: url,
Connected: true,
}},
{"should mark as disconnected", false, SignalState{
URL: url,
Connected: false,
}},
} }
status := NewRecorder() status := NewRecorder("https://mgm")
status.UpdateSignalAddress(url)
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
if test.connected { if test.connected {
status.MarkSignalConnected(url) status.MarkSignalConnected()
} else { } else {
status.MarkSignalDisconnected(url) status.MarkSignalDisconnected()
} }
assert.Equal(t, test.want, status.signal, "signal status should be equal") assert.Equal(t, test.want, status.signalState, "signal status should be equal")
}) })
} }
} }
@ -183,29 +177,22 @@ func TestUpdateManagementState(t *testing.T) {
var tests = []struct { var tests = []struct {
name string name string
connected bool connected bool
want ManagementState want bool
}{ }{
{"should mark as connected", true, ManagementState{ {"should mark as connected", true, true},
{"should mark as disconnected", false, false},
URL: url,
Connected: true,
}},
{"should mark as disconnected", false, ManagementState{
URL: url,
Connected: false,
}},
} }
status := NewRecorder() status := NewRecorder(url)
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
if test.connected { if test.connected {
status.MarkManagementConnected(url) status.MarkManagementConnected()
} else { } else {
status.MarkManagementDisconnected(url) status.MarkManagementDisconnected()
} }
assert.Equal(t, test.want, status.management, "signal status should be equal") assert.Equal(t, test.want, status.managementState, "signalState status should be equal")
}) })
} }
} }
@ -213,12 +200,13 @@ func TestUpdateManagementState(t *testing.T) {
func TestGetFullStatus(t *testing.T) { func TestGetFullStatus(t *testing.T) {
key1 := "abc" key1 := "abc"
key2 := "def" key2 := "def"
signalAddr := "https://signal"
managementState := ManagementState{ managementState := ManagementState{
URL: "https://signal", URL: "https://mgm",
Connected: true, Connected: true,
} }
signalState := SignalState{ signalState := SignalState{
URL: "https://signal", URL: signalAddr,
Connected: true, Connected: true,
} }
peerState1 := State{ peerState1 := State{
@ -229,10 +217,11 @@ func TestGetFullStatus(t *testing.T) {
PubKey: key2, PubKey: key2,
} }
status := NewRecorder() status := NewRecorder("https://mgm")
status.UpdateSignalAddress(signalAddr)
status.management = managementState status.managementState = managementState.Connected
status.signal = signalState status.signalState = signalState.Connected
status.peers[key1] = peerState1 status.peers[key1] = peerState1
status.peers[key2] = peerState2 status.peers[key2] = peerState2

View File

@ -398,7 +398,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
err = wgInterface.Create() err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface") require.NoError(t, err, "should create testing wireguard interface")
statusRecorder := peer.NewRecorder() statusRecorder := peer.NewRecorder("https://mgm")
ctx := context.TODO() ctx := context.TODO()
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder) routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder)
defer routeManager.Stop() defer routeManager.Stop()

View File

@ -96,7 +96,9 @@ func (s *Server) Start() error {
s.config = config s.config = config
if s.statusRecorder == nil { if s.statusRecorder == nil {
s.statusRecorder = peer.NewRecorder() s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
} else {
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
} }
go func() { go func() {
@ -386,7 +388,9 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
} }
if s.statusRecorder == nil { if s.statusRecorder == nil {
s.statusRecorder = peer.NewRecorder() s.statusRecorder = peer.NewRecorder(s.config.ManagementURL.String())
} else {
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
} }
go func() { go func() {
@ -430,7 +434,9 @@ func (s *Server) Status(
statusResponse := proto.StatusResponse{Status: string(status), DaemonVersion: version.NetbirdVersion()} statusResponse := proto.StatusResponse{Status: string(status), DaemonVersion: version.NetbirdVersion()}
if s.statusRecorder == nil { if s.statusRecorder == nil {
s.statusRecorder = peer.NewRecorder() s.statusRecorder = peer.NewRecorder(s.config.ManagementURL.String())
} else {
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
} }
if msg.GetFullPeerStatus { if msg.GetFullPeerStatus {

View File

@ -4,15 +4,13 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"io" "io"
"sync"
"time" "time"
"github.com/cenkalti/backoff/v4" "google.golang.org/grpc/codes"
"github.com/netbirdio/netbird/client/system" gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -20,13 +18,26 @@ import (
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"github.com/cenkalti/backoff/v4"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto"
) )
// ConnStateNotifier is a wrapper interface of the status recorders
type ConnStateNotifier interface {
MarkManagementDisconnected()
MarkManagementConnected()
}
type GrpcClient struct { type GrpcClient struct {
key wgtypes.Key key wgtypes.Key
realClient proto.ManagementServiceClient realClient proto.ManagementServiceClient
ctx context.Context ctx context.Context
conn *grpc.ClientConn conn *grpc.ClientConn
connStateCallback ConnStateNotifier
connStateCallbackLock sync.RWMutex
} }
// NewClient creates a new client to Management service // NewClient creates a new client to Management service
@ -56,10 +67,11 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
realClient := proto.NewManagementServiceClient(conn) realClient := proto.NewManagementServiceClient(conn)
return &GrpcClient{ return &GrpcClient{
key: ourPrivateKey, key: ourPrivateKey,
realClient: realClient, realClient: realClient,
ctx: ctx, ctx: ctx,
conn: conn, conn: conn,
connStateCallbackLock: sync.RWMutex{},
}, nil }, nil
} }
@ -68,6 +80,13 @@ func (c *GrpcClient) Close() error {
return c.conn.Close() return c.conn.Close()
} }
// SetConnStateListener set the ConnStateNotifier
func (c *GrpcClient) SetConnStateListener(notifier ConnStateNotifier) {
c.connStateCallbackLock.Lock()
defer c.connStateCallbackLock.Unlock()
c.connStateCallback = notifier
}
// defaultBackoff is a basic backoff mechanism for general issues // defaultBackoff is a basic backoff mechanism for general issues
func defaultBackoff(ctx context.Context) backoff.BackOff { func defaultBackoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(&backoff.ExponentialBackOff{ return backoff.WithContext(&backoff.ExponentialBackOff{
@ -121,7 +140,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
} }
log.Infof("connected to the Management Service stream") log.Infof("connected to the Management Service stream")
c.notifyConnected()
// blocking until error // blocking until error
err = c.receiveEvents(stream, *serverPubKey, msgHandler) err = c.receiveEvents(stream, *serverPubKey, msgHandler)
if err != nil { if err != nil {
@ -131,6 +150,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
// we need this reset because after a successful connection and a consequent error, backoff lib doesn't // we need this reset because after a successful connection and a consequent error, backoff lib doesn't
// reset times and next try will start with a long delay // reset times and next try will start with a long delay
backOff.Reset() backOff.Reset()
c.notifyDisconnected()
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err) log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
return err return err
} }
@ -298,6 +318,26 @@ func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.D
return flowInfoResp, nil return flowInfoResp, nil
} }
func (c *GrpcClient) notifyDisconnected() {
c.connStateCallbackLock.RLock()
defer c.connStateCallbackLock.RUnlock()
if c.connStateCallback == nil {
return
}
c.connStateCallback.MarkManagementDisconnected()
}
func (c *GrpcClient) notifyConnected() {
c.connStateCallbackLock.RLock()
defer c.connStateCallbackLock.RUnlock()
if c.connStateCallback == nil {
return
}
c.connStateCallback.MarkManagementConnected()
}
func infoToMetaData(info *system.Info) *proto.PeerSystemMeta { func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
if info == nil { if info == nil {
return nil return nil

View File

@ -22,6 +22,12 @@ import (
"time" "time"
) )
// ConnStateNotifier is a wrapper interface of the status recorder
type ConnStateNotifier interface {
MarkSignalDisconnected()
MarkSignalConnected()
}
// GrpcClient Wraps the Signal Exchange Service gRpc client // GrpcClient Wraps the Signal Exchange Service gRpc client
type GrpcClient struct { type GrpcClient struct {
key wgtypes.Key key wgtypes.Key
@ -34,6 +40,9 @@ type GrpcClient struct {
mux sync.Mutex mux sync.Mutex
// StreamConnected indicates whether this client is StreamConnected to the Signal stream // StreamConnected indicates whether this client is StreamConnected to the Signal stream
status Status status Status
connStateCallback ConnStateNotifier
connStateCallbackLock sync.RWMutex
} }
func (c *GrpcClient) StreamConnected() bool { func (c *GrpcClient) StreamConnected() bool {
@ -78,15 +87,23 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
log.Debugf("connected to Signal Service: %v", conn.Target()) log.Debugf("connected to Signal Service: %v", conn.Target())
return &GrpcClient{ return &GrpcClient{
realClient: proto.NewSignalExchangeClient(conn), realClient: proto.NewSignalExchangeClient(conn),
ctx: ctx, ctx: ctx,
signalConn: conn, signalConn: conn,
key: key, key: key,
mux: sync.Mutex{}, mux: sync.Mutex{},
status: StreamDisconnected, status: StreamDisconnected,
connStateCallbackLock: sync.RWMutex{},
}, nil }, nil
} }
// SetConnStateListener set the ConnStateNotifier
func (c *GrpcClient) SetConnStateListener(notifier ConnStateNotifier) {
c.connStateCallbackLock.Lock()
defer c.connStateCallbackLock.Unlock()
c.connStateCallback = notifier
}
// defaultBackoff is a basic backoff mechanism for general issues // defaultBackoff is a basic backoff mechanism for general issues
func defaultBackoff(ctx context.Context) backoff.BackOff { func defaultBackoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(&backoff.ExponentialBackOff{ return backoff.WithContext(&backoff.ExponentialBackOff{
@ -134,13 +151,14 @@ func (c *GrpcClient) Receive(msgHandler func(msg *proto.Message) error) error {
c.notifyStreamConnected() c.notifyStreamConnected()
log.Infof("connected to the Signal Service stream") log.Infof("connected to the Signal Service stream")
c.notifyConnected()
// start receiving messages from the Signal stream (from other peers through signal) // start receiving messages from the Signal stream (from other peers through signal)
err = c.receive(stream, msgHandler) err = c.receive(stream, msgHandler)
if err != nil { if err != nil {
// we need this reset because after a successful connection and a consequent error, backoff lib doesn't // we need this reset because after a successful connection and a consequent error, backoff lib doesn't
// reset times and next try will start with a long delay // reset times and next try will start with a long delay
backOff.Reset() backOff.Reset()
c.notifyDisconnected()
log.Warnf("disconnected from the Signal service but will retry silently. Reason: %v", err) log.Warnf("disconnected from the Signal service but will retry silently. Reason: %v", err)
return err return err
} }
@ -341,3 +359,23 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
} }
} }
} }
func (c *GrpcClient) notifyDisconnected() {
c.connStateCallbackLock.RLock()
defer c.connStateCallbackLock.RUnlock()
if c.connStateCallback == nil {
return
}
c.connStateCallback.MarkSignalDisconnected()
}
func (c *GrpcClient) notifyConnected() {
c.connStateCallbackLock.RLock()
defer c.connStateCallbackLock.RUnlock()
if c.connStateCallback == nil {
return
}
c.connStateCallback.MarkSignalConnected()
}