mirror of
https://github.com/netbirdio/netbird.git
synced 2025-04-10 18:58:27 +02:00
Rework peer connection status based on the update channel existence (#1213)
With this change, we don't need to update all peers on startup. We will check the existence of an update channel when returning a list or single peer on API. Then after restarting of server consumers of API will see peer not connected status till the creation of an updated channel which indicates peer successful connection.
This commit is contained in:
parent
4ad14cb46b
commit
659110f0d5
@ -103,6 +103,7 @@ type AccountManager interface {
|
|||||||
UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error)
|
UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error)
|
||||||
LoginPeer(login PeerLogin) (*Peer, *NetworkMap, error) // used by peer gRPC API
|
LoginPeer(login PeerLogin) (*Peer, *NetworkMap, error) // used by peer gRPC API
|
||||||
SyncPeer(sync PeerSync) (*Peer, *NetworkMap, error) // used by peer gRPC API
|
SyncPeer(sync PeerSync) (*Peer, *NetworkMap, error) // used by peer gRPC API
|
||||||
|
GetAllConnectedPeers() (map[string]struct{}, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type DefaultAccountManager struct {
|
type DefaultAccountManager struct {
|
||||||
@ -1558,6 +1559,11 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers()
|
||||||
|
func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) {
|
||||||
|
return am.peersUpdateManager.GetAllConnectedPeers(), nil
|
||||||
|
}
|
||||||
|
|
||||||
func isDomainValid(domain string) bool {
|
func isDomainValid(domain string) bool {
|
||||||
re := regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)
|
re := regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)
|
||||||
return re.Match([]byte(domain))
|
return re.Match([]byte(domain))
|
||||||
|
@ -111,10 +111,6 @@ func restore(file string) (*FileStore, error) {
|
|||||||
for _, peer := range account.Peers {
|
for _, peer := range account.Peers {
|
||||||
store.PeerKeyID2AccountID[peer.Key] = accountID
|
store.PeerKeyID2AccountID[peer.Key] = accountID
|
||||||
store.PeerID2AccountID[peer.ID] = accountID
|
store.PeerID2AccountID[peer.ID] = accountID
|
||||||
// reset all peers to status = Disconnected
|
|
||||||
if peer.Status != nil && peer.Status.Connected {
|
|
||||||
peer.Status.Connected = false
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
for _, user := range account.Users {
|
for _, user := range account.Users {
|
||||||
store.UserID2AccountID[user.Id] = accountID
|
store.UserID2AccountID[user.Id] = accountID
|
||||||
|
@ -31,6 +31,24 @@ func NewPeersHandler(accountManager server.AccountManager, authCfg AuthCfg) *Pee
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *PeersHandler) checkPeerStatus(peer *server.Peer) (*server.Peer, error) {
|
||||||
|
peerToReturn := peer.Copy()
|
||||||
|
if peer.Status.Connected {
|
||||||
|
statuses, err := h.accountManager.GetAllConnectedPeers()
|
||||||
|
if err != nil {
|
||||||
|
return peerToReturn, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Although we have online status in store we do not yet have an updated channel so have to show it as disconnected
|
||||||
|
// This may happen after server restart when not all peers are yet connected
|
||||||
|
if _, connected := statuses[peerToReturn.ID]; !connected {
|
||||||
|
peerToReturn.Status.Connected = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return peerToReturn, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w http.ResponseWriter) {
|
func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w http.ResponseWriter) {
|
||||||
peer, err := h.accountManager.GetPeer(account.Id, peerID, userID)
|
peer, err := h.accountManager.GetPeer(account.Id, peerID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -38,7 +56,13 @@ func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
util.WriteJSONObject(w, toPeerResponse(peer, account, h.accountManager.GetDNSDomain()))
|
peerToReturn, err := h.checkPeerStatus(peer)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(w, toPeerResponse(peerToReturn, account, h.accountManager.GetDNSDomain()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
|
func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||||
@ -120,7 +144,12 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
respBody := []*api.Peer{}
|
respBody := []*api.Peer{}
|
||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
respBody = append(respBody, toPeerResponse(peer, account, dnsDomain))
|
peerToReturn, err := h.checkPeerStatus(peer)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
respBody = append(respBody, toPeerResponse(peerToReturn, account, dnsDomain))
|
||||||
}
|
}
|
||||||
util.WriteJSONObject(w, respBody)
|
util.WriteJSONObject(w, respBody)
|
||||||
return
|
return
|
||||||
|
@ -3,6 +3,7 @@ package http
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -23,19 +24,33 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const testPeerID = "test_peer"
|
const testPeerID = "test_peer"
|
||||||
|
const noUpdateChannelTestPeerID = "no-update-channel"
|
||||||
|
|
||||||
func initTestMetaData(peers ...*server.Peer) *PeersHandler {
|
func initTestMetaData(peers ...*server.Peer) *PeersHandler {
|
||||||
return &PeersHandler{
|
return &PeersHandler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
UpdatePeerFunc: func(accountID, userID string, update *server.Peer) (*server.Peer, error) {
|
UpdatePeerFunc: func(accountID, userID string, update *server.Peer) (*server.Peer, error) {
|
||||||
p := peers[0].Copy()
|
var p *server.Peer
|
||||||
|
for _, peer := range peers {
|
||||||
|
if update.ID == peer.ID {
|
||||||
|
p = peer.Copy()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
p.SSHEnabled = update.SSHEnabled
|
p.SSHEnabled = update.SSHEnabled
|
||||||
p.LoginExpirationEnabled = update.LoginExpirationEnabled
|
p.LoginExpirationEnabled = update.LoginExpirationEnabled
|
||||||
p.Name = update.Name
|
p.Name = update.Name
|
||||||
return p, nil
|
return p, nil
|
||||||
},
|
},
|
||||||
GetPeerFunc: func(accountID, peerID, userID string) (*server.Peer, error) {
|
GetPeerFunc: func(accountID, peerID, userID string) (*server.Peer, error) {
|
||||||
return peers[0], nil
|
var p *server.Peer
|
||||||
|
for _, peer := range peers {
|
||||||
|
if peerID == peer.ID {
|
||||||
|
p = peer.Copy()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return p, nil
|
||||||
},
|
},
|
||||||
GetPeersFunc: func(accountID, userID string) ([]*server.Peer, error) {
|
GetPeersFunc: func(accountID, userID string) ([]*server.Peer, error) {
|
||||||
return peers, nil
|
return peers, nil
|
||||||
@ -57,6 +72,16 @@ func initTestMetaData(peers ...*server.Peer) *PeersHandler {
|
|||||||
},
|
},
|
||||||
}, user, nil
|
}, user, nil
|
||||||
},
|
},
|
||||||
|
GetAllConnectedPeersFunc: func() (map[string]struct{}, error) {
|
||||||
|
statuses := make(map[string]struct{})
|
||||||
|
for _, peer := range peers {
|
||||||
|
if peer.ID == noUpdateChannelTestPeerID {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
statuses[peer.ID] = struct{}{}
|
||||||
|
}
|
||||||
|
return statuses, nil
|
||||||
|
},
|
||||||
},
|
},
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||||
@ -79,7 +104,7 @@ func TestGetPeers(t *testing.T) {
|
|||||||
Key: "key",
|
Key: "key",
|
||||||
SetupKey: "setupkey",
|
SetupKey: "setupkey",
|
||||||
IP: net.ParseIP("100.64.0.1"),
|
IP: net.ParseIP("100.64.0.1"),
|
||||||
Status: &server.PeerStatus{},
|
Status: &server.PeerStatus{Connected: true},
|
||||||
Name: "PeerName",
|
Name: "PeerName",
|
||||||
LoginExpirationEnabled: false,
|
LoginExpirationEnabled: false,
|
||||||
Meta: server.PeerSystemMeta{
|
Meta: server.PeerSystemMeta{
|
||||||
@ -93,11 +118,17 @@ func TestGetPeers(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
peer1 := peer.Copy()
|
||||||
|
peer1.ID = noUpdateChannelTestPeerID
|
||||||
|
|
||||||
expectedUpdatedPeer := peer.Copy()
|
expectedUpdatedPeer := peer.Copy()
|
||||||
expectedUpdatedPeer.LoginExpirationEnabled = true
|
expectedUpdatedPeer.LoginExpirationEnabled = true
|
||||||
expectedUpdatedPeer.SSHEnabled = true
|
expectedUpdatedPeer.SSHEnabled = true
|
||||||
expectedUpdatedPeer.Name = "New Name"
|
expectedUpdatedPeer.Name = "New Name"
|
||||||
|
|
||||||
|
expectedPeer1 := peer1.Copy()
|
||||||
|
expectedPeer1.Status.Connected = false
|
||||||
|
|
||||||
tt := []struct {
|
tt := []struct {
|
||||||
name string
|
name string
|
||||||
expectedStatus int
|
expectedStatus int
|
||||||
@ -116,13 +147,21 @@ func TestGetPeers(t *testing.T) {
|
|||||||
expectedPeer: peer,
|
expectedPeer: peer,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "GetPeer",
|
name: "GetPeer with update channel",
|
||||||
requestType: http.MethodGet,
|
requestType: http.MethodGet,
|
||||||
requestPath: "/api/peers/" + testPeerID,
|
requestPath: "/api/peers/" + testPeerID,
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
expectedArray: false,
|
expectedArray: false,
|
||||||
expectedPeer: peer,
|
expectedPeer: peer,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "GetPeer with no update channel",
|
||||||
|
requestType: http.MethodGet,
|
||||||
|
requestPath: "/api/peers/" + peer1.ID,
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedArray: false,
|
||||||
|
expectedPeer: expectedPeer1,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "PutPeer",
|
name: "PutPeer",
|
||||||
requestType: http.MethodPut,
|
requestType: http.MethodPut,
|
||||||
@ -136,7 +175,7 @@ func TestGetPeers(t *testing.T) {
|
|||||||
|
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
p := initTestMetaData(peer)
|
p := initTestMetaData(peer, peer1)
|
||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
@ -171,6 +210,10 @@ func TestGetPeers(t *testing.T) {
|
|||||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// hardcode this check for now as we only have two peers in this suite
|
||||||
|
assert.Equal(t, len(respBody), 2)
|
||||||
|
assert.Equal(t, respBody[1].Connected, false)
|
||||||
|
|
||||||
got = respBody[0]
|
got = respBody[0]
|
||||||
} else {
|
} else {
|
||||||
got = &api.Peer{}
|
got = &api.Peer{}
|
||||||
@ -180,12 +223,15 @@ func TestGetPeers(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Println(got)
|
||||||
|
|
||||||
assert.Equal(t, got.Name, tc.expectedPeer.Name)
|
assert.Equal(t, got.Name, tc.expectedPeer.Name)
|
||||||
assert.Equal(t, got.Version, tc.expectedPeer.Meta.WtVersion)
|
assert.Equal(t, got.Version, tc.expectedPeer.Meta.WtVersion)
|
||||||
assert.Equal(t, got.Ip, tc.expectedPeer.IP.String())
|
assert.Equal(t, got.Ip, tc.expectedPeer.IP.String())
|
||||||
assert.Equal(t, got.Os, "OS core")
|
assert.Equal(t, got.Os, "OS core")
|
||||||
assert.Equal(t, got.LoginExpirationEnabled, tc.expectedPeer.LoginExpirationEnabled)
|
assert.Equal(t, got.LoginExpirationEnabled, tc.expectedPeer.LoginExpirationEnabled)
|
||||||
assert.Equal(t, got.SshEnabled, tc.expectedPeer.SSHEnabled)
|
assert.Equal(t, got.SshEnabled, tc.expectedPeer.SSHEnabled)
|
||||||
|
assert.Equal(t, got.Connected, tc.expectedPeer.Status.Connected)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -75,6 +75,7 @@ type MockAccountManager struct {
|
|||||||
LoginPeerFunc func(login server.PeerLogin) (*server.Peer, *server.NetworkMap, error)
|
LoginPeerFunc func(login server.PeerLogin) (*server.Peer, *server.NetworkMap, error)
|
||||||
SyncPeerFunc func(sync server.PeerSync) (*server.Peer, *server.NetworkMap, error)
|
SyncPeerFunc func(sync server.PeerSync) (*server.Peer, *server.NetworkMap, error)
|
||||||
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
|
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
|
||||||
|
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
|
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
|
||||||
@ -583,3 +584,11 @@ func (am *MockAccountManager) SyncPeer(sync server.PeerSync) (*server.Peer, *ser
|
|||||||
}
|
}
|
||||||
return nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
|
return nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAllConnectedPeers mocks GetAllConnectedPeers of the AccountManager interface
|
||||||
|
func (am *MockAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) {
|
||||||
|
if am.GetAllConnectedPeersFunc != nil {
|
||||||
|
return am.GetAllConnectedPeersFunc()
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method GetAllConnectedPeers is not implemented")
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user