diff --git a/management/server/account.go b/management/server/account.go index d35ad2566..8e453a1fe 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -103,6 +103,7 @@ type AccountManager interface { UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error) LoginPeer(login PeerLogin) (*Peer, *NetworkMap, error) // used by peer gRPC API SyncPeer(sync PeerSync) (*Peer, *NetworkMap, error) // used by peer gRPC API + GetAllConnectedPeers() (map[string]struct{}, error) } type DefaultAccountManager struct { @@ -1558,6 +1559,11 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla } } +// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers() +func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) { + return am.peersUpdateManager.GetAllConnectedPeers(), nil +} + func isDomainValid(domain string) bool { re := regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`) return re.Match([]byte(domain)) diff --git a/management/server/file_store.go b/management/server/file_store.go index ecd02ba99..b90b1d607 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -111,10 +111,6 @@ func restore(file string) (*FileStore, error) { for _, peer := range account.Peers { store.PeerKeyID2AccountID[peer.Key] = accountID store.PeerID2AccountID[peer.ID] = accountID - // reset all peers to status = Disconnected - if peer.Status != nil && peer.Status.Connected { - peer.Status.Connected = false - } } for _, user := range account.Users { store.UserID2AccountID[user.Id] = accountID diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index adf4a9721..a485d6ccf 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -31,6 +31,24 @@ func NewPeersHandler(accountManager server.AccountManager, authCfg AuthCfg) *Pee } } +func (h *PeersHandler) checkPeerStatus(peer *server.Peer) (*server.Peer, error) { + peerToReturn := peer.Copy() + if peer.Status.Connected { + statuses, err := h.accountManager.GetAllConnectedPeers() + if err != nil { + return peerToReturn, err + } + + // Although we have online status in store we do not yet have an updated channel so have to show it as disconnected + // This may happen after server restart when not all peers are yet connected + if _, connected := statuses[peerToReturn.ID]; !connected { + peerToReturn.Status.Connected = false + } + } + + return peerToReturn, nil +} + func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w http.ResponseWriter) { peer, err := h.accountManager.GetPeer(account.Id, peerID, userID) if err != nil { @@ -38,7 +56,13 @@ func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w return } - util.WriteJSONObject(w, toPeerResponse(peer, account, h.accountManager.GetDNSDomain())) + peerToReturn, err := h.checkPeerStatus(peer) + if err != nil { + util.WriteError(err, w) + return + } + + util.WriteJSONObject(w, toPeerResponse(peerToReturn, account, h.accountManager.GetDNSDomain())) } func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) { @@ -120,7 +144,12 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { respBody := []*api.Peer{} for _, peer := range peers { - respBody = append(respBody, toPeerResponse(peer, account, dnsDomain)) + peerToReturn, err := h.checkPeerStatus(peer) + if err != nil { + util.WriteError(err, w) + return + } + respBody = append(respBody, toPeerResponse(peerToReturn, account, dnsDomain)) } util.WriteJSONObject(w, respBody) return diff --git a/management/server/http/peers_handler_test.go b/management/server/http/peers_handler_test.go index 7fe732f2f..1856861d5 100644 --- a/management/server/http/peers_handler_test.go +++ b/management/server/http/peers_handler_test.go @@ -3,6 +3,7 @@ package http import ( "bytes" "encoding/json" + "fmt" "io" "net" "net/http" @@ -23,19 +24,33 @@ import ( ) const testPeerID = "test_peer" +const noUpdateChannelTestPeerID = "no-update-channel" func initTestMetaData(peers ...*server.Peer) *PeersHandler { return &PeersHandler{ accountManager: &mock_server.MockAccountManager{ UpdatePeerFunc: func(accountID, userID string, update *server.Peer) (*server.Peer, error) { - p := peers[0].Copy() + var p *server.Peer + for _, peer := range peers { + if update.ID == peer.ID { + p = peer.Copy() + break + } + } p.SSHEnabled = update.SSHEnabled p.LoginExpirationEnabled = update.LoginExpirationEnabled p.Name = update.Name return p, nil }, GetPeerFunc: func(accountID, peerID, userID string) (*server.Peer, error) { - return peers[0], nil + var p *server.Peer + for _, peer := range peers { + if peerID == peer.ID { + p = peer.Copy() + break + } + } + return p, nil }, GetPeersFunc: func(accountID, userID string) ([]*server.Peer, error) { return peers, nil @@ -57,6 +72,16 @@ func initTestMetaData(peers ...*server.Peer) *PeersHandler { }, }, user, nil }, + GetAllConnectedPeersFunc: func() (map[string]struct{}, error) { + statuses := make(map[string]struct{}) + for _, peer := range peers { + if peer.ID == noUpdateChannelTestPeerID { + break + } + statuses[peer.ID] = struct{}{} + } + return statuses, nil + }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { @@ -79,7 +104,7 @@ func TestGetPeers(t *testing.T) { Key: "key", SetupKey: "setupkey", IP: net.ParseIP("100.64.0.1"), - Status: &server.PeerStatus{}, + Status: &server.PeerStatus{Connected: true}, Name: "PeerName", LoginExpirationEnabled: false, Meta: server.PeerSystemMeta{ @@ -93,11 +118,17 @@ func TestGetPeers(t *testing.T) { }, } + peer1 := peer.Copy() + peer1.ID = noUpdateChannelTestPeerID + expectedUpdatedPeer := peer.Copy() expectedUpdatedPeer.LoginExpirationEnabled = true expectedUpdatedPeer.SSHEnabled = true expectedUpdatedPeer.Name = "New Name" + expectedPeer1 := peer1.Copy() + expectedPeer1.Status.Connected = false + tt := []struct { name string expectedStatus int @@ -116,13 +147,21 @@ func TestGetPeers(t *testing.T) { expectedPeer: peer, }, { - name: "GetPeer", + name: "GetPeer with update channel", requestType: http.MethodGet, requestPath: "/api/peers/" + testPeerID, expectedStatus: http.StatusOK, expectedArray: false, expectedPeer: peer, }, + { + name: "GetPeer with no update channel", + requestType: http.MethodGet, + requestPath: "/api/peers/" + peer1.ID, + expectedStatus: http.StatusOK, + expectedArray: false, + expectedPeer: expectedPeer1, + }, { name: "PutPeer", requestType: http.MethodPut, @@ -136,7 +175,7 @@ func TestGetPeers(t *testing.T) { rr := httptest.NewRecorder() - p := initTestMetaData(peer) + p := initTestMetaData(peer, peer1) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { @@ -171,6 +210,10 @@ func TestGetPeers(t *testing.T) { t.Fatalf("Sent content is not in correct json format; %v", err) } + // hardcode this check for now as we only have two peers in this suite + assert.Equal(t, len(respBody), 2) + assert.Equal(t, respBody[1].Connected, false) + got = respBody[0] } else { got = &api.Peer{} @@ -180,12 +223,15 @@ func TestGetPeers(t *testing.T) { } } + fmt.Println(got) + assert.Equal(t, got.Name, tc.expectedPeer.Name) assert.Equal(t, got.Version, tc.expectedPeer.Meta.WtVersion) assert.Equal(t, got.Ip, tc.expectedPeer.IP.String()) assert.Equal(t, got.Os, "OS core") assert.Equal(t, got.LoginExpirationEnabled, tc.expectedPeer.LoginExpirationEnabled) assert.Equal(t, got.SshEnabled, tc.expectedPeer.SSHEnabled) + assert.Equal(t, got.Connected, tc.expectedPeer.Status.Connected) }) } } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 5432b201b..ab3748c01 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -75,6 +75,7 @@ type MockAccountManager struct { LoginPeerFunc func(login server.PeerLogin) (*server.Peer, *server.NetworkMap, error) SyncPeerFunc func(sync server.PeerSync) (*server.Peer, *server.NetworkMap, error) InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error + GetAllConnectedPeersFunc func() (map[string]struct{}, error) } // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface @@ -583,3 +584,11 @@ func (am *MockAccountManager) SyncPeer(sync server.PeerSync) (*server.Peer, *ser } return nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") } + +// GetAllConnectedPeers mocks GetAllConnectedPeers of the AccountManager interface +func (am *MockAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) { + if am.GetAllConnectedPeersFunc != nil { + return am.GetAllConnectedPeersFunc() + } + return nil, status.Errorf(codes.Unimplemented, "method GetAllConnectedPeers is not implemented") +}