Merge branch 'main' into refactor-get-account-by-token

This commit is contained in:
bcmmbaga
2024-09-20 14:08:09 +03:00
7 changed files with 276 additions and 33 deletions

View File

@@ -518,6 +518,9 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
 	defer conn.mu.Unlock()
 	if conn.ctx.Err() != nil {
+		if err := rci.relayedConn.Close(); err != nil {
+			log.Warnf("failed to close unnecessary relayed connection: %v", err)
+		}
 		return
 	}
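Context for the hunk above: relayConnectionIsReady can fire after the peer connection's context has already been canceled, and before this change the freshly established relayed net.Conn was simply abandoned. A minimal sketch of the guard pattern (all names hypothetical, not the NetBird API):

package main

import (
	"context"
	"log"
	"net"
	"sync"
)

// onReady wires a late-arriving connection into the proxy, unless the
// owner was shut down in the meantime; then the conn is closed, not leaked.
func onReady(ctx context.Context, mu *sync.Mutex, conn net.Conn) {
	mu.Lock()
	defer mu.Unlock()

	if ctx.Err() != nil { // owner canceled while we were connecting
		if err := conn.Close(); err != nil {
			log.Printf("failed to close unnecessary connection: %v", err)
		}
		return
	}
	// ... proceed to use conn ...
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	cancel() // simulate an already-closed peer connection

	c, s := net.Pipe()
	defer s.Close()

	var mu sync.Mutex
	onReady(ctx, &mu, c) // closes c instead of leaking it
}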
@@ -530,6 +533,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
 		conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
 		return
 	}
+	conn.log.Infof("created new wgProxy for relay connection: %s", endpoint)
 	endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
 	conn.endpointRelay = endpointUdpAddr
@@ -775,9 +779,8 @@ func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr,
 	ep, err := wgProxy.AddTurnConn(iceConnInfo.RemoteConn)
 	if err != nil {
 		conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
-		err = wgProxy.CloseConn()
-		if err != nil {
-			conn.log.Warnf("failed to close turn proxy connection: %v", err)
+		if errClose := wgProxy.CloseConn(); errClose != nil {
+			conn.log.Warnf("failed to close turn proxy connection: %v", errClose)
 		}
 		return nil, nil, err
 	}
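The rewritten cleanup above fixes a subtle bug: previously `err = wgProxy.CloseConn()` overwrote the AddTurnConn error, so a *successful* close made the function return a nil error despite the failure. Scoping the close error to `errClose` keeps the primary error intact. A standalone sketch of the idiom (hypothetical names):

package main

import (
	"errors"
	"fmt"
)

type proxy struct{}

func (p *proxy) Add() error   { return errors.New("add failed") }
func (p *proxy) Close() error { return nil }

// setup returns the primary error; a cleanup failure is only logged,
// and a cleanup *success* can no longer mask the original failure.
func setup(p *proxy) error {
	if err := p.Add(); err != nil {
		if errClose := p.Close(); errClose != nil {
			fmt.Printf("failed to close proxy: %v\n", errClose)
		}
		return err // still "add failed", even though Close succeeded
	}
	return nil
}

func main() {
	fmt.Println(setup(&proxy{})) // add failed
}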

View File

@@ -32,8 +32,8 @@ func NewWGUserSpaceProxy(ctx context.Context, wgPort int) *WGUserSpaceProxy {
 }
 // AddTurnConn start the proxy with the given remote conn
-func (p *WGUserSpaceProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
-	p.remoteConn = turnConn
+func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) {
+	p.remoteConn = remoteConn
 	var err error
 	p.localConn, err = nbnet.NewDialer().DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
@@ -54,6 +54,14 @@ func (p *WGUserSpaceProxy) CloseConn() error {
 	if p.localConn == nil {
 		return nil
 	}
+
+	if p.remoteConn == nil {
+		return nil
+	}
+
+	if err := p.remoteConn.Close(); err != nil {
+		log.Warnf("failed to close remote conn: %s", err)
+	}
 	return p.localConn.Close()
 }
@@ -65,6 +73,8 @@ func (p *WGUserSpaceProxy) Free() error {
 // proxyToRemote proxies everything from Wireguard to the RemoteKey peer
 // blocks
 func (p *WGUserSpaceProxy) proxyToRemote() {
+	defer log.Infof("exit from proxyToRemote: %s", p.localConn.LocalAddr())
+
 	buf := make([]byte, 1500)
 	for {
 		select {
@@ -93,7 +103,8 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
 // proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
 // blocks
 func (p *WGUserSpaceProxy) proxyToLocal() {
+	defer p.cancel()
+	defer log.Infof("exit from proxyToLocal: %s", p.localConn.LocalAddr())
 	buf := make([]byte, 1500)
 	for {
 		select {
@@ -103,7 +114,6 @@ func (p *WGUserSpaceProxy) proxyToLocal() {
 		n, err := p.remoteConn.Read(buf)
 		if err != nil {
 			if err == io.EOF {
-				p.cancel()
 				return
 			}
 			log.Errorf("failed to read from remote conn: %s", err)
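Together, these two hunks move the cancellation out of the EOF branch and into a defer, so the paired proxyToRemote goroutine is torn down on every exit path from proxyToLocal rather than only on a clean EOF. A minimal sketch of the shape (hypothetical names):

package main

import (
	"context"
	"fmt"
	"io"
	"strings"
)

// pump reads until the source is exhausted or the context is canceled.
// The deferred cancel stops the sibling goroutine no matter how we exit.
func pump(ctx context.Context, cancel context.CancelFunc, src io.Reader) {
	defer cancel()

	buf := make([]byte, 1500)
	for {
		select {
		case <-ctx.Done():
			return
		default:
		}
		if _, err := src.Read(buf); err != nil {
			return // EOF or a hard error: either way, cancel runs
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	pump(ctx, cancel, strings.NewReader("payload"))
	fmt.Println(ctx.Err()) // context canceled
}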

View File

@@ -7,8 +7,6 @@ import (
 	"net/http"
 	"github.com/gorilla/mux"
-	log "github.com/sirupsen/logrus"
 	"github.com/netbirdio/netbird/management/server"
 	nbgroup "github.com/netbirdio/netbird/management/server/group"
 	"github.com/netbirdio/netbird/management/server/http/api"
@@ -16,6 +14,7 @@ import (
 	"github.com/netbirdio/netbird/management/server/jwtclaims"
 	nbpeer "github.com/netbirdio/netbird/management/server/peer"
 	"github.com/netbirdio/netbird/management/server/status"
+	log "github.com/sirupsen/logrus"
 )
// PeersHandler is a handler that returns peers of the account // PeersHandler is a handler that returns peers of the account
@@ -215,7 +214,7 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv
 // GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
 func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
 	claims := h.claimsExtractor.FromRequestContext(r)
-	account, _, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+	account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
 	if err != nil {
 		util.WriteError(r.Context(), err, w)
 		return
@@ -228,6 +227,21 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request
 		return
 	}
+
+	// If the user is a regular user and does not own the peer
+	// with the given peerID, return an empty list.
+	if !user.HasAdminPower() && !user.IsServiceUser {
+		peer, ok := account.Peers[peerID]
+		if !ok {
+			util.WriteError(r.Context(), status.Errorf(status.NotFound, "peer not found"), w)
+			return
+		}
+
+		if peer.UserID != user.Id {
+			util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{})
+			return
+		}
+	}
 	dnsDomain := h.accountManager.GetDNSDomain()
 	validPeers, err := h.accountManager.GetValidatedPeers(account)
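The added block is the core of this PR's permission tightening: admins and service users keep full visibility, while a regular user may only list accessible peers for a peer they own; an unknown peerID yields 404, and a peer owned by someone else yields an empty list rather than an error. Reduced to its essence (simplified local types, not the server package):

package main

import "fmt"

type user struct {
	ID            string
	IsAdmin       bool
	IsServiceUser bool
}

type peer struct{ UserID string }

// canListAccessiblePeers mirrors the handler's check: elevated users see
// everything; a regular user only peers they own.
func canListAccessiblePeers(u user, p peer) bool {
	if u.IsAdmin || u.IsServiceUser {
		return true
	}
	return p.UserID == u.ID
}

func main() {
	alice := user{ID: "alice"}
	fmt.Println(canListAccessiblePeers(alice, peer{UserID: "alice"})) // true
	fmt.Println(canListAccessiblePeers(alice, peer{UserID: "bob"}))   // false

	svc := user{ID: "svc", IsServiceUser: true}
	fmt.Println(canListAccessiblePeers(svc, peer{UserID: "bob"})) // true
}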

View File

@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
+	"fmt"
 	"io"
 	"net"
 	"net/http"
@@ -12,20 +13,30 @@ import (
 	"time"
 	"github.com/gorilla/mux"
+	nbgroup "github.com/netbirdio/netbird/management/server/group"
 	"github.com/netbirdio/netbird/management/server/http/api"
 	nbpeer "github.com/netbirdio/netbird/management/server/peer"
+	"golang.org/x/exp/maps"
 	"github.com/netbirdio/netbird/management/server/jwtclaims"
-	"github.com/magiconair/properties/assert"
+	"github.com/stretchr/testify/assert"
 	"github.com/netbirdio/netbird/management/server"
 	"github.com/netbirdio/netbird/management/server/mock_server"
 )
-const testPeerID = "test_peer"
-const noUpdateChannelTestPeerID = "no-update-channel"
+type ctxKey string
+
+const (
+	testPeerID                = "test_peer"
+	noUpdateChannelTestPeerID = "no-update-channel"
+	adminUser                 = "admin_user"
+	regularUser               = "regular_user"
+	serviceUser               = "service_user"
+	userIDKey          ctxKey = "user_id"
+)
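The dedicated `ctxKey` string type is the standard way to key context values: `context.WithValue` compares keys by type as well as value, so an unexported key type cannot collide with a plain-string key set by another package (string keys are also what go vet flags). A runnable demonstration:

package main

import (
	"context"
	"fmt"
)

type ctxKey string

const userIDKey ctxKey = "user_id"

func main() {
	ctx := context.WithValue(context.Background(), userIDKey, "admin_user")

	// A bare string does not match: ctxKey("user_id") != "user_id".
	fmt.Println(ctx.Value("user_id")) // <nil>
	fmt.Println(ctx.Value(userIDKey)) // admin_user
}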
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
	return &PeersHandler{
@@ -60,21 +71,57 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
 			return "netbird.selfhosted"
 		},
 		GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
-			user := server.NewAdminUser("test_user")
-			return &server.Account{
+			peersMap := make(map[string]*nbpeer.Peer)
+			for _, peer := range peers {
+				peersMap[peer.ID] = peer.Copy()
+			}
+
+			policy := &server.Policy{
+				ID:        "policy",
+				AccountID: claims.AccountId,
+				Name:      "policy",
+				Enabled:   true,
+				Rules: []*server.PolicyRule{
+					{
+						ID:            "rule",
+						Name:          "rule",
+						Enabled:       true,
+						Action:        "accept",
+						Destinations:  []string{"group1"},
+						Sources:       []string{"group1"},
+						Bidirectional: true,
+						Protocol:      "all",
+						Ports:         []string{"80"},
+					},
+				},
+			}
+
+			srvUser := server.NewRegularUser(serviceUser)
+			srvUser.IsServiceUser = true
+
+			account := &server.Account{
 				Id:     claims.AccountId,
 				Domain: "hotmail.com",
-				Peers: map[string]*nbpeer.Peer{
-					peers[0].ID: peers[0],
-					peers[1].ID: peers[1],
-				},
+				Peers:  peersMap,
 				Users: map[string]*server.User{
-					"test_user": user,
+					adminUser:   server.NewAdminUser(adminUser),
+					regularUser: server.NewRegularUser(regularUser),
+					serviceUser: srvUser,
+				},
+				Groups: map[string]*nbgroup.Group{
+					"group1": {
+						ID:        "group1",
+						AccountID: claims.AccountId,
+						Name:      "group1",
+						Issued:    "api",
+						Peers:     maps.Keys(peersMap),
+					},
 				},
 				Settings: &server.Settings{
 					PeerLoginExpirationEnabled: true,
 					PeerLoginExpiration:        time.Hour,
 				},
+				Policies: []*server.Policy{policy},
 				Network: &server.Network{
 					Identifier: "ciclqisab2ss43jdn8q0",
 					Net: net.IPNet{
@@ -83,7 +130,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
 					},
 					Serial: 51,
 				},
-			}, user, nil
+			}
+
+			return account, account.Users[claims.UserId], nil
 		},
		HasConnectedChannelFunc: func(peerID string) bool {
			statuses := make(map[string]struct{})
@@ -99,8 +148,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
 		},
 		claimsExtractor: jwtclaims.NewClaimsExtractor(
 			jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
+				userID := r.Context().Value(userIDKey).(string)
 				return jwtclaims.AuthorizationClaims{
-					UserId:    "test_user",
+					UserId:    userID,
 					Domain:    "hotmail.com",
 					AccountId: "test_id",
 				}
@@ -197,6 +247,8 @@ func TestGetPeers(t *testing.T) {
 			recorder := httptest.NewRecorder()
 			req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
+			ctx := context.WithValue(context.Background(), userIDKey, "admin_user")
+			req = req.WithContext(ctx)
 			router := mux.NewRouter()
 			router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
@@ -251,3 +303,119 @@ func TestGetPeers(t *testing.T) {
 		})
 	}
 }
+
+func TestGetAccessiblePeers(t *testing.T) {
+	peer1 := &nbpeer.Peer{
+		ID:                     "peer1",
+		Key:                    "key1",
+		IP:                     net.ParseIP("100.64.0.1"),
+		Status:                 &nbpeer.PeerStatus{Connected: true},
+		Name:                   "peer1",
+		LoginExpirationEnabled: false,
+		UserID:                 regularUser,
+	}
+	peer2 := &nbpeer.Peer{
+		ID:                     "peer2",
+		Key:                    "key2",
+		IP:                     net.ParseIP("100.64.0.2"),
+		Status:                 &nbpeer.PeerStatus{Connected: true},
+		Name:                   "peer2",
+		LoginExpirationEnabled: false,
+		UserID:                 adminUser,
+	}
+	peer3 := &nbpeer.Peer{
+		ID:                     "peer3",
+		Key:                    "key3",
+		IP:                     net.ParseIP("100.64.0.3"),
+		Status:                 &nbpeer.PeerStatus{Connected: true},
+		Name:                   "peer3",
+		LoginExpirationEnabled: false,
+		UserID:                 regularUser,
+	}
+
+	p := initTestMetaData(peer1, peer2, peer3)
+
+	tt := []struct {
+		name           string
+		peerID         string
+		callerUserID   string
+		expectedStatus int
+		expectedPeers  []string
+	}{
+		{
+			name:           "non admin user can access owned peer",
+			peerID:         "peer1",
+			callerUserID:   regularUser,
+			expectedStatus: http.StatusOK,
+			expectedPeers:  []string{"peer2", "peer3"},
+		},
+		{
+			name:           "non admin user can't access unowned peer",
+			peerID:         "peer2",
+			callerUserID:   regularUser,
+			expectedStatus: http.StatusOK,
+			expectedPeers:  []string{},
+		},
+		{
+			name:           "admin user can access owned peer",
+			peerID:         "peer2",
+			callerUserID:   adminUser,
+			expectedStatus: http.StatusOK,
+			expectedPeers:  []string{"peer1", "peer3"},
+		},
+		{
+			name:           "admin user can access unowned peer",
+			peerID:         "peer3",
+			callerUserID:   adminUser,
+			expectedStatus: http.StatusOK,
+			expectedPeers:  []string{"peer1", "peer2"},
+		},
+		{
+			name:           "service user can access unowned peer",
+			peerID:         "peer3",
+			callerUserID:   serviceUser,
+			expectedStatus: http.StatusOK,
+			expectedPeers:  []string{"peer1", "peer2"},
+		},
+	}
+
+	for _, tc := range tt {
+		t.Run(tc.name, func(t *testing.T) {
+			recorder := httptest.NewRecorder()
+			req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)
+			ctx := context.WithValue(context.Background(), userIDKey, tc.callerUserID)
+			req = req.WithContext(ctx)
+
+			router := mux.NewRouter()
+			router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET")
+			router.ServeHTTP(recorder, req)
+
+			res := recorder.Result()
+			if res.StatusCode != tc.expectedStatus {
+				t.Fatalf("handler returned wrong status code: got %v want %v", res.StatusCode, tc.expectedStatus)
+			}
+
+			body, err := io.ReadAll(res.Body)
+			if err != nil {
+				t.Fatalf("failed to read response body: %v", err)
+			}
+			defer res.Body.Close()
+
+			var accessiblePeers []api.AccessiblePeer
+			err = json.Unmarshal(body, &accessiblePeers)
+			if err != nil {
+				t.Fatalf("failed to unmarshal response: %v", err)
+			}
+
+			peerIDs := make([]string, len(accessiblePeers))
+			for i, peer := range accessiblePeers {
+				peerIDs[i] = peer.Id
+			}
+
+			assert.ElementsMatch(t, peerIDs, tc.expectedPeers)
+		})
+	}
+}

View File

@@ -58,7 +58,10 @@ func (m *Msg) Free() {
 	m.bufPool.Put(m.bufPtr)
 }
+// connContainer is a container for the connection to the peer. It is responsible for managing the messages from the
+// server and forwarding them to the upper layer content reader.
 type connContainer struct {
+	log         *log.Entry
 	conn        *Conn
 	messages    chan Msg
 	msgChanLock sync.Mutex
@@ -67,10 +70,10 @@ type connContainer struct {
 	cancel   context.CancelFunc
 }
-func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
+func newConnContainer(log *log.Entry, conn *Conn, messages chan Msg) *connContainer {
 	ctx, cancel := context.WithCancel(context.Background())
 	return &connContainer{
+		log:      log,
 		conn:     conn,
 		messages: messages,
 		ctx:      ctx,
@@ -91,6 +94,10 @@ func (cc *connContainer) writeMsg(msg Msg) {
 	case cc.messages <- msg:
 	case <-cc.ctx.Done():
 		msg.Free()
+	default:
+		msg.Free()
+		cc.log.Infof("message queue is full")
+		// todo consider to close the connection
 	}
 }
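The new `default` arm turns writeMsg into a non-blocking send: when the buffered messages channel is full, the message is freed and dropped, so a slow consumer can no longer stall the server read loop (the per-connection buffer also grows from 2 to 100 in OpenConn below). The pattern in isolation (hypothetical names):

package main

import "fmt"

// tryEnqueue attempts a non-blocking send; on a full queue the message
// is dropped (and would be returned to its pool) instead of blocking.
func tryEnqueue(queue chan<- string, done <-chan struct{}, msg string) bool {
	select {
	case queue <- msg:
		return true
	case <-done:
		return false // consumer is gone
	default:
		fmt.Println("message queue is full, dropping")
		return false
	}
}

func main() {
	done := make(chan struct{})
	q := make(chan string, 1)
	fmt.Println(tryEnqueue(q, done, "a")) // true: buffered
	fmt.Println(tryEnqueue(q, done, "b")) // false: queue full
}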
@@ -141,8 +148,8 @@ type Client struct {
 // NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
 func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
 	hashedID, hashedStringId := messages.HashID(peerID)
-	return &Client{
-		log:            log.WithFields(log.Fields{"client_id": hashedStringId, "relay": serverURL}),
+	c := &Client{
+		log:            log.WithFields(log.Fields{"relay": serverURL}),
 		parentCtx:      ctx,
 		connectionURL:  serverURL,
 		authTokenStore: authTokenStore,
@@ -155,6 +162,8 @@ func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.Token
 		},
 		conns: make(map[string]*connContainer),
 	}
+	c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedStringId)
+	return c
 }
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs. // Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
@@ -203,10 +212,10 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
 	}
 	c.log.Infof("open connection to peer: %s", hashedStringID)
-	msgChannel := make(chan Msg, 2)
+	msgChannel := make(chan Msg, 100)
 	conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
-	c.conns[hashedStringID] = newConnContainer(conn, msgChannel)
+	c.conns[hashedStringID] = newConnContainer(c.log, conn, msgChannel)
 	return conn, nil
 }
@@ -455,7 +464,10 @@ func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, in
 		}
 		c.log.Errorf("health check timeout")
 		internalStopFlag.set()
-		_ = conn.Close() // ignore the err because the readLoop will handle it
+		if err := conn.Close(); err != nil {
+			// ignore the err handling because the readLoop will handle it
+			c.log.Warnf("failed to close connection: %s", err)
+		}
 		return
 	case <-c.parentCtx.Done():
 		err := c.close(true)
@@ -486,6 +498,7 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
 	if container.conn != connReference {
 		return fmt.Errorf("conn reference mismatch")
 	}
+	c.log.Infof("free up connection to peer: %s", id)
 	delete(c.conns, id)
 	container.close()

View File

@@ -35,12 +35,15 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*C
 	connResultChan := make(chan connResult, totalServers)
 	successChan := make(chan connResult, 1)
 	concurrentLimiter := make(chan struct{}, maxConcurrentServers)
 	for _, url := range urls {
+		// todo check if we have a successful connection so we do not need to connect to other servers
 		concurrentLimiter <- struct{}{}
 		go func(url string) {
-			defer func() { <-concurrentLimiter }()
+			defer func() {
+				<-concurrentLimiter
+			}()
 			sp.startConnection(parentCtx, connResultChan, url)
 		}(url)
 	}
@@ -72,7 +75,8 @@ func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan con
 func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult) {
 	var hasSuccess bool
-	for cr := range resultChan {
+	for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ {
+		cr := <-resultChan
 		if cr.Err != nil {
 			log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
 			continue
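Rationale for the loop change: `for cr := range resultChan` exits only when the channel is closed, and nothing ever closes resultChan, so processConnResults would block forever once all results were in. Each startConnection call sends exactly one result, so counting to cap(resultChan) (== totalServers) drains the channel and lets the goroutine return. A runnable sketch:

package main

import "fmt"

func main() {
	results := make(chan error, 3) // one slot per server attempt
	for i := 0; i < cap(results); i++ {
		i := i
		go func() { results <- fmt.Errorf("server %d unreachable", i) }()
	}

	// `for r := range results` would block forever after three receives,
	// because the channel is never closed; counting to cap() does not.
	for n := 0; n < cap(results); n++ {
		fmt.Println(<-results)
	}
}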

View File

@@ -0,0 +1,31 @@
+package client
+
+import (
+	"context"
+	"errors"
+	"testing"
+	"time"
+)
+
+func TestServerPicker_UnavailableServers(t *testing.T) {
+	sp := ServerPicker{
+		TokenStore: nil,
+		PeerID:     "test",
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+	defer cancel()
+
+	go func() {
+		_, err := sp.PickServer(ctx, []string{"rel://dummy1", "rel://dummy2"})
+		if err == nil {
+			t.Error(err)
+		}
+		cancel()
+	}()
+
+	<-ctx.Done()
+	if errors.Is(ctx.Err(), context.DeadlineExceeded) {
+		t.Errorf("PickServer() took too long to complete")
+	}
+}
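The new test exercises exactly that drain behavior: with two unreachable rel:// URLs, PickServer should return an error and cancel the context well before the 3-second deadline; if the deadline fires instead, the goroutine is presumed stuck and the test fails. A generic, runnable form of this watchdog pattern (hypothetical names, not the relay API):

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// failFast stands in for PickServer talking to unreachable servers.
func failFast(ctx context.Context) error { return errors.New("no server reachable") }

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	go func() {
		if err := failFast(ctx); err == nil {
			fmt.Println("expected an error for unreachable servers")
		}
		cancel() // finished before the deadline: releases the watchdog below
	}()

	<-ctx.Done()
	if errors.Is(ctx.Err(), context.DeadlineExceeded) {
		fmt.Println("operation took too long")
	} else {
		fmt.Println("operation completed in time")
	}
}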