mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-18 19:09:09 +02:00
Merge branch 'main' into refactor-get-account-by-token
This commit is contained in:
@@ -518,6 +518,9 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
if conn.ctx.Err() != nil {
|
if conn.ctx.Err() != nil {
|
||||||
|
if err := rci.relayedConn.Close(); err != nil {
|
||||||
|
log.Warnf("failed to close unnecessary relayed connection: %v", err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -530,6 +533,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
conn.log.Infof("created new wgProxy for relay connection: %s", endpoint)
|
||||||
|
|
||||||
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
||||||
conn.endpointRelay = endpointUdpAddr
|
conn.endpointRelay = endpointUdpAddr
|
||||||
@@ -775,9 +779,8 @@ func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr,
|
|||||||
ep, err := wgProxy.AddTurnConn(iceConnInfo.RemoteConn)
|
ep, err := wgProxy.AddTurnConn(iceConnInfo.RemoteConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||||
err = wgProxy.CloseConn()
|
if errClose := wgProxy.CloseConn(); errClose != nil {
|
||||||
if err != nil {
|
conn.log.Warnf("failed to close turn proxy connection: %v", errClose)
|
||||||
conn.log.Warnf("failed to close turn proxy connection: %v", err)
|
|
||||||
}
|
}
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@@ -32,8 +32,8 @@ func NewWGUserSpaceProxy(ctx context.Context, wgPort int) *WGUserSpaceProxy {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddTurnConn start the proxy with the given remote conn
|
// AddTurnConn start the proxy with the given remote conn
|
||||||
func (p *WGUserSpaceProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
|
func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) {
|
||||||
p.remoteConn = turnConn
|
p.remoteConn = remoteConn
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
p.localConn, err = nbnet.NewDialer().DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
p.localConn, err = nbnet.NewDialer().DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||||
@@ -54,6 +54,14 @@ func (p *WGUserSpaceProxy) CloseConn() error {
|
|||||||
if p.localConn == nil {
|
if p.localConn == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if p.remoteConn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.remoteConn.Close(); err != nil {
|
||||||
|
log.Warnf("failed to close remote conn: %s", err)
|
||||||
|
}
|
||||||
return p.localConn.Close()
|
return p.localConn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,6 +73,8 @@ func (p *WGUserSpaceProxy) Free() error {
|
|||||||
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
|
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
|
||||||
// blocks
|
// blocks
|
||||||
func (p *WGUserSpaceProxy) proxyToRemote() {
|
func (p *WGUserSpaceProxy) proxyToRemote() {
|
||||||
|
defer log.Infof("exit from proxyToRemote: %s", p.localConn.LocalAddr())
|
||||||
|
|
||||||
buf := make([]byte, 1500)
|
buf := make([]byte, 1500)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -93,7 +103,8 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
|
|||||||
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
|
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
|
||||||
// blocks
|
// blocks
|
||||||
func (p *WGUserSpaceProxy) proxyToLocal() {
|
func (p *WGUserSpaceProxy) proxyToLocal() {
|
||||||
|
defer p.cancel()
|
||||||
|
defer log.Infof("exit from proxyToLocal: %s", p.localConn.LocalAddr())
|
||||||
buf := make([]byte, 1500)
|
buf := make([]byte, 1500)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -103,7 +114,6 @@ func (p *WGUserSpaceProxy) proxyToLocal() {
|
|||||||
n, err := p.remoteConn.Read(buf)
|
n, err := p.remoteConn.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
p.cancel()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Errorf("failed to read from remote conn: %s", err)
|
log.Errorf("failed to read from remote conn: %s", err)
|
||||||
|
@@ -7,8 +7,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
@@ -16,6 +14,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PeersHandler is a handler that returns peers of the account
|
// PeersHandler is a handler that returns peers of the account
|
||||||
@@ -215,7 +214,7 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv
|
|||||||
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
|
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
|
||||||
func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, _, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -228,6 +227,21 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If the user is regular user and does not own the peer
|
||||||
|
// with the given peerID return an empty list
|
||||||
|
if !user.HasAdminPower() && !user.IsServiceUser {
|
||||||
|
peer, ok := account.Peers[peerID]
|
||||||
|
if !ok {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.NotFound, "peer not found"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer.UserID != user.Id {
|
||||||
|
util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
dnsDomain := h.accountManager.GetDNSDomain()
|
dnsDomain := h.accountManager.GetDNSDomain()
|
||||||
|
|
||||||
validPeers, err := h.accountManager.GetValidatedPeers(account)
|
validPeers, err := h.accountManager.GetValidatedPeers(account)
|
||||||
|
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -12,20 +13,30 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
|
|
||||||
"github.com/magiconair/properties/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
)
|
)
|
||||||
|
|
||||||
const testPeerID = "test_peer"
|
type ctxKey string
|
||||||
const noUpdateChannelTestPeerID = "no-update-channel"
|
|
||||||
|
const (
|
||||||
|
testPeerID = "test_peer"
|
||||||
|
noUpdateChannelTestPeerID = "no-update-channel"
|
||||||
|
|
||||||
|
adminUser = "admin_user"
|
||||||
|
regularUser = "regular_user"
|
||||||
|
serviceUser = "service_user"
|
||||||
|
userIDKey ctxKey = "user_id"
|
||||||
|
)
|
||||||
|
|
||||||
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||||
return &PeersHandler{
|
return &PeersHandler{
|
||||||
@@ -60,21 +71,57 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
|||||||
return "netbird.selfhosted"
|
return "netbird.selfhosted"
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||||
user := server.NewAdminUser("test_user")
|
peersMap := make(map[string]*nbpeer.Peer)
|
||||||
return &server.Account{
|
for _, peer := range peers {
|
||||||
|
peersMap[peer.ID] = peer.Copy()
|
||||||
|
}
|
||||||
|
|
||||||
|
policy := &server.Policy{
|
||||||
|
ID: "policy",
|
||||||
|
AccountID: claims.AccountId,
|
||||||
|
Name: "policy",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*server.PolicyRule{
|
||||||
|
{
|
||||||
|
ID: "rule",
|
||||||
|
Name: "rule",
|
||||||
|
Enabled: true,
|
||||||
|
Action: "accept",
|
||||||
|
Destinations: []string{"group1"},
|
||||||
|
Sources: []string{"group1"},
|
||||||
|
Bidirectional: true,
|
||||||
|
Protocol: "all",
|
||||||
|
Ports: []string{"80"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
srvUser := server.NewRegularUser(serviceUser)
|
||||||
|
srvUser.IsServiceUser = true
|
||||||
|
|
||||||
|
account := &server.Account{
|
||||||
Id: claims.AccountId,
|
Id: claims.AccountId,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
Peers: map[string]*nbpeer.Peer{
|
Peers: peersMap,
|
||||||
peers[0].ID: peers[0],
|
|
||||||
peers[1].ID: peers[1],
|
|
||||||
},
|
|
||||||
Users: map[string]*server.User{
|
Users: map[string]*server.User{
|
||||||
"test_user": user,
|
adminUser: server.NewAdminUser(adminUser),
|
||||||
|
regularUser: server.NewRegularUser(regularUser),
|
||||||
|
serviceUser: srvUser,
|
||||||
|
},
|
||||||
|
Groups: map[string]*nbgroup.Group{
|
||||||
|
"group1": {
|
||||||
|
ID: "group1",
|
||||||
|
AccountID: claims.AccountId,
|
||||||
|
Name: "group1",
|
||||||
|
Issued: "api",
|
||||||
|
Peers: maps.Keys(peersMap),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Settings: &server.Settings{
|
Settings: &server.Settings{
|
||||||
PeerLoginExpirationEnabled: true,
|
PeerLoginExpirationEnabled: true,
|
||||||
PeerLoginExpiration: time.Hour,
|
PeerLoginExpiration: time.Hour,
|
||||||
},
|
},
|
||||||
|
Policies: []*server.Policy{policy},
|
||||||
Network: &server.Network{
|
Network: &server.Network{
|
||||||
Identifier: "ciclqisab2ss43jdn8q0",
|
Identifier: "ciclqisab2ss43jdn8q0",
|
||||||
Net: net.IPNet{
|
Net: net.IPNet{
|
||||||
@@ -83,7 +130,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
|||||||
},
|
},
|
||||||
Serial: 51,
|
Serial: 51,
|
||||||
},
|
},
|
||||||
}, user, nil
|
}
|
||||||
|
|
||||||
|
return account, account.Users[claims.UserId], nil
|
||||||
},
|
},
|
||||||
HasConnectedChannelFunc: func(peerID string) bool {
|
HasConnectedChannelFunc: func(peerID string) bool {
|
||||||
statuses := make(map[string]struct{})
|
statuses := make(map[string]struct{})
|
||||||
@@ -99,8 +148,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
|||||||
},
|
},
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||||
|
userID := r.Context().Value(userIDKey).(string)
|
||||||
return jwtclaims.AuthorizationClaims{
|
return jwtclaims.AuthorizationClaims{
|
||||||
UserId: "test_user",
|
UserId: userID,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
AccountId: "test_id",
|
AccountId: "test_id",
|
||||||
}
|
}
|
||||||
@@ -197,6 +247,8 @@ func TestGetPeers(t *testing.T) {
|
|||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
|
ctx := context.WithValue(context.Background(), userIDKey, "admin_user")
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
|
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
|
||||||
@@ -251,3 +303,119 @@ func TestGetPeers(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetAccessiblePeers(t *testing.T) {
|
||||||
|
peer1 := &nbpeer.Peer{
|
||||||
|
ID: "peer1",
|
||||||
|
Key: "key1",
|
||||||
|
IP: net.ParseIP("100.64.0.1"),
|
||||||
|
Status: &nbpeer.PeerStatus{Connected: true},
|
||||||
|
Name: "peer1",
|
||||||
|
LoginExpirationEnabled: false,
|
||||||
|
UserID: regularUser,
|
||||||
|
}
|
||||||
|
|
||||||
|
peer2 := &nbpeer.Peer{
|
||||||
|
ID: "peer2",
|
||||||
|
Key: "key2",
|
||||||
|
IP: net.ParseIP("100.64.0.2"),
|
||||||
|
Status: &nbpeer.PeerStatus{Connected: true},
|
||||||
|
Name: "peer2",
|
||||||
|
LoginExpirationEnabled: false,
|
||||||
|
UserID: adminUser,
|
||||||
|
}
|
||||||
|
|
||||||
|
peer3 := &nbpeer.Peer{
|
||||||
|
ID: "peer3",
|
||||||
|
Key: "key3",
|
||||||
|
IP: net.ParseIP("100.64.0.3"),
|
||||||
|
Status: &nbpeer.PeerStatus{Connected: true},
|
||||||
|
Name: "peer3",
|
||||||
|
LoginExpirationEnabled: false,
|
||||||
|
UserID: regularUser,
|
||||||
|
}
|
||||||
|
|
||||||
|
p := initTestMetaData(peer1, peer2, peer3)
|
||||||
|
|
||||||
|
tt := []struct {
|
||||||
|
name string
|
||||||
|
peerID string
|
||||||
|
callerUserID string
|
||||||
|
expectedStatus int
|
||||||
|
expectedPeers []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "non admin user can access owned peer",
|
||||||
|
peerID: "peer1",
|
||||||
|
callerUserID: regularUser,
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedPeers: []string{"peer2", "peer3"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non admin user can't access unowned peer",
|
||||||
|
peerID: "peer2",
|
||||||
|
callerUserID: regularUser,
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedPeers: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "admin user can access owned peer",
|
||||||
|
peerID: "peer2",
|
||||||
|
callerUserID: adminUser,
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedPeers: []string{"peer1", "peer3"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "admin user can access unowned peer",
|
||||||
|
peerID: "peer3",
|
||||||
|
callerUserID: adminUser,
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedPeers: []string{"peer1", "peer2"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "service user can access unowned peer",
|
||||||
|
peerID: "peer3",
|
||||||
|
callerUserID: serviceUser,
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedPeers: []string{"peer1", "peer2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tt {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)
|
||||||
|
ctx := context.WithValue(context.Background(), userIDKey, tc.callerUserID)
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
|
router := mux.NewRouter()
|
||||||
|
router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
res := recorder.Result()
|
||||||
|
if res.StatusCode != tc.expectedStatus {
|
||||||
|
t.Fatalf("handler returned wrong status code: got %v want %v", res.StatusCode, tc.expectedStatus)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read response body: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
var accessiblePeers []api.AccessiblePeer
|
||||||
|
err = json.Unmarshal(body, &accessiblePeers)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
peerIDs := make([]string, len(accessiblePeers))
|
||||||
|
for i, peer := range accessiblePeers {
|
||||||
|
peerIDs[i] = peer.Id
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.ElementsMatch(t, peerIDs, tc.expectedPeers)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@@ -58,7 +58,10 @@ func (m *Msg) Free() {
|
|||||||
m.bufPool.Put(m.bufPtr)
|
m.bufPool.Put(m.bufPtr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// connContainer is a container for the connection to the peer. It is responsible for managing the messages from the
|
||||||
|
// server and forwarding them to the upper layer content reader.
|
||||||
type connContainer struct {
|
type connContainer struct {
|
||||||
|
log *log.Entry
|
||||||
conn *Conn
|
conn *Conn
|
||||||
messages chan Msg
|
messages chan Msg
|
||||||
msgChanLock sync.Mutex
|
msgChanLock sync.Mutex
|
||||||
@@ -67,10 +70,10 @@ type connContainer struct {
|
|||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
|
func newConnContainer(log *log.Entry, conn *Conn, messages chan Msg) *connContainer {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
return &connContainer{
|
return &connContainer{
|
||||||
|
log: log,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
messages: messages,
|
messages: messages,
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
@@ -91,6 +94,10 @@ func (cc *connContainer) writeMsg(msg Msg) {
|
|||||||
case cc.messages <- msg:
|
case cc.messages <- msg:
|
||||||
case <-cc.ctx.Done():
|
case <-cc.ctx.Done():
|
||||||
msg.Free()
|
msg.Free()
|
||||||
|
default:
|
||||||
|
msg.Free()
|
||||||
|
cc.log.Infof("message queue is full")
|
||||||
|
// todo consider to close the connection
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -141,8 +148,8 @@ type Client struct {
|
|||||||
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
||||||
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
|
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
|
||||||
hashedID, hashedStringId := messages.HashID(peerID)
|
hashedID, hashedStringId := messages.HashID(peerID)
|
||||||
return &Client{
|
c := &Client{
|
||||||
log: log.WithFields(log.Fields{"client_id": hashedStringId, "relay": serverURL}),
|
log: log.WithFields(log.Fields{"relay": serverURL}),
|
||||||
parentCtx: ctx,
|
parentCtx: ctx,
|
||||||
connectionURL: serverURL,
|
connectionURL: serverURL,
|
||||||
authTokenStore: authTokenStore,
|
authTokenStore: authTokenStore,
|
||||||
@@ -155,6 +162,8 @@ func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.Token
|
|||||||
},
|
},
|
||||||
conns: make(map[string]*connContainer),
|
conns: make(map[string]*connContainer),
|
||||||
}
|
}
|
||||||
|
c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedStringId)
|
||||||
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
|
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
|
||||||
@@ -203,10 +212,10 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.log.Infof("open connection to peer: %s", hashedStringID)
|
c.log.Infof("open connection to peer: %s", hashedStringID)
|
||||||
msgChannel := make(chan Msg, 2)
|
msgChannel := make(chan Msg, 100)
|
||||||
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
|
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
|
||||||
|
|
||||||
c.conns[hashedStringID] = newConnContainer(conn, msgChannel)
|
c.conns[hashedStringID] = newConnContainer(c.log, conn, msgChannel)
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -455,7 +464,10 @@ func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, in
|
|||||||
}
|
}
|
||||||
c.log.Errorf("health check timeout")
|
c.log.Errorf("health check timeout")
|
||||||
internalStopFlag.set()
|
internalStopFlag.set()
|
||||||
_ = conn.Close() // ignore the err because the readLoop will handle it
|
if err := conn.Close(); err != nil {
|
||||||
|
// ignore the err handling because the readLoop will handle it
|
||||||
|
c.log.Warnf("failed to close connection: %s", err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
case <-c.parentCtx.Done():
|
case <-c.parentCtx.Done():
|
||||||
err := c.close(true)
|
err := c.close(true)
|
||||||
@@ -486,6 +498,7 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
|
|||||||
if container.conn != connReference {
|
if container.conn != connReference {
|
||||||
return fmt.Errorf("conn reference mismatch")
|
return fmt.Errorf("conn reference mismatch")
|
||||||
}
|
}
|
||||||
|
c.log.Infof("free up connection to peer: %s", id)
|
||||||
delete(c.conns, id)
|
delete(c.conns, id)
|
||||||
container.close()
|
container.close()
|
||||||
|
|
||||||
|
@@ -35,12 +35,15 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*C
|
|||||||
|
|
||||||
connResultChan := make(chan connResult, totalServers)
|
connResultChan := make(chan connResult, totalServers)
|
||||||
successChan := make(chan connResult, 1)
|
successChan := make(chan connResult, 1)
|
||||||
|
|
||||||
concurrentLimiter := make(chan struct{}, maxConcurrentServers)
|
concurrentLimiter := make(chan struct{}, maxConcurrentServers)
|
||||||
|
|
||||||
for _, url := range urls {
|
for _, url := range urls {
|
||||||
|
// todo check if we have a successful connection so we do not need to connect to other servers
|
||||||
concurrentLimiter <- struct{}{}
|
concurrentLimiter <- struct{}{}
|
||||||
go func(url string) {
|
go func(url string) {
|
||||||
defer func() { <-concurrentLimiter }()
|
defer func() {
|
||||||
|
<-concurrentLimiter
|
||||||
|
}()
|
||||||
sp.startConnection(parentCtx, connResultChan, url)
|
sp.startConnection(parentCtx, connResultChan, url)
|
||||||
}(url)
|
}(url)
|
||||||
}
|
}
|
||||||
@@ -72,7 +75,8 @@ func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan con
|
|||||||
|
|
||||||
func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult) {
|
func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult) {
|
||||||
var hasSuccess bool
|
var hasSuccess bool
|
||||||
for cr := range resultChan {
|
for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ {
|
||||||
|
cr := <-resultChan
|
||||||
if cr.Err != nil {
|
if cr.Err != nil {
|
||||||
log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
|
log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
|
||||||
continue
|
continue
|
||||||
|
31
relay/client/picker_test.go
Normal file
31
relay/client/picker_test.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServerPicker_UnavailableServers(t *testing.T) {
|
||||||
|
sp := ServerPicker{
|
||||||
|
TokenStore: nil,
|
||||||
|
PeerID: "test",
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, err := sp.PickServer(ctx, []string{"rel://dummy1", "rel://dummy2"})
|
||||||
|
if err == nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-ctx.Done()
|
||||||
|
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
||||||
|
t.Errorf("PickServer() took too long to complete")
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user