Add relay server tracking

This commit is contained in:
Zoltán Papp 2024-06-01 11:48:15 +02:00
parent fd4ad15c83
commit 3430b81622
2 changed files with 107 additions and 21 deletions

View File

@ -5,8 +5,19 @@ import (
"fmt" "fmt"
"net" "net"
"sync" "sync"
log "github.com/sirupsen/logrus"
) )
type RelayTrack struct {
sync.RWMutex
relayClient *Client
}
func NewRelayTrack() *RelayTrack {
return &RelayTrack{}
}
type Manager struct { type Manager struct {
ctx context.Context ctx context.Context
srvAddress string srvAddress string
@ -15,8 +26,8 @@ type Manager struct {
relayClient *Client relayClient *Client
reconnectGuard *Guard reconnectGuard *Guard
relayClients map[string]*Client relayClients map[string]*RelayTrack
relayClientsMutex sync.Mutex relayClientsMutex sync.RWMutex
} }
func NewManager(ctx context.Context, serverAddress string, peerID string) *Manager { func NewManager(ctx context.Context, serverAddress string, peerID string) *Manager {
@ -24,7 +35,7 @@ func NewManager(ctx context.Context, serverAddress string, peerID string) *Manag
ctx: ctx, ctx: ctx,
srvAddress: serverAddress, srvAddress: serverAddress,
peerID: peerID, peerID: peerID,
relayClients: make(map[string]*Client), relayClients: make(map[string]*RelayTrack),
} }
} }
@ -50,10 +61,10 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
return nil, err return nil, err
} }
if foreign { if !foreign {
return m.openConnVia(serverAddress, peerKey)
} else {
return m.relayClient.OpenConn(peerKey) return m.relayClient.OpenConn(peerKey)
} else {
return m.openConnVia(serverAddress, peerKey)
} }
} }
@ -65,30 +76,50 @@ func (m *Manager) RelayAddress() (net.Addr, error) {
} }
func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
relayClient, ok := m.relayClients[serverAddress] m.relayClientsMutex.RLock()
relayTrack, ok := m.relayClients[serverAddress]
if ok { if ok {
return relayClient.OpenConn(peerKey) relayTrack.RLock()
m.relayClientsMutex.RUnlock()
defer relayTrack.RUnlock()
return relayTrack.relayClient.OpenConn(peerKey)
} }
m.relayClientsMutex.RUnlock()
relayClient = NewClient(m.ctx, serverAddress, m.peerID) rt := NewRelayTrack()
rt.Lock()
m.relayClientsMutex.Lock()
m.relayClients[serverAddress] = rt
m.relayClientsMutex.Unlock()
relayClient := NewClient(m.ctx, serverAddress, m.peerID)
err := relayClient.Connect() err := relayClient.Connect()
if err != nil { if err != nil {
rt.Unlock()
m.relayClientsMutex.Lock()
delete(m.relayClients, serverAddress)
m.relayClientsMutex.Unlock()
return nil, err return nil, err
} }
relayClient.SetOnDisconnectListener(func() { relayClient.SetOnDisconnectListener(func() {
m.deleteRelayConn(serverAddress) m.deleteRelayConn(serverAddress)
}) })
rt.Unlock()
conn, err := relayClient.OpenConn(peerKey) conn, err := relayClient.OpenConn(peerKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
m.relayClients[serverAddress] = relayClient
return conn, nil return conn, nil
} }
func (m *Manager) deleteRelayConn(address string) { func (m *Manager) deleteRelayConn(address string) {
log.Infof("deleting relay client for %s", address)
m.relayClientsMutex.Lock()
defer m.relayClientsMutex.Unlock()
delete(m.relayClients, address) delete(m.relayClients, address)
} }

View File

@ -4,13 +4,14 @@ import (
"context" "context"
"testing" "testing"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/relay/server"
) )
func TestNewManager(t *testing.T) { func TestForeignConn(t *testing.T) {
ctx := context.Background() ctx := context.Background()
idAlice := "alice"
idBob := "bob"
addr1 := "localhost:1234" addr1 := "localhost:1234"
srv1 := server.NewServer() srv1 := server.NewServer()
go func() { go func() {
@ -43,22 +44,21 @@ func TestNewManager(t *testing.T) {
} }
}() }()
idAlice := "alice"
log.Debugf("connect by alice")
clientAlice := NewManager(ctx, addr1, idAlice) clientAlice := NewManager(ctx, addr1, idAlice)
err := clientAlice.Serve() err := clientAlice.Serve()
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
aliceSrvAddr, err := clientAlice.RelayAddress()
if err != nil {
t.Fatalf("failed to get relay address: %s", err)
}
idBob := "bob"
log.Debugf("connect by bob")
clientBob := NewManager(ctx, addr2, idBob) clientBob := NewManager(ctx, addr2, idBob)
err = clientBob.Serve() err = clientBob.Serve()
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
bobsSrvAddr, err := clientBob.RelayAddress() bobsSrvAddr, err := clientBob.RelayAddress()
if err != nil { if err != nil {
t.Fatalf("failed to get relay address: %s", err) t.Fatalf("failed to get relay address: %s", err)
@ -67,8 +67,7 @@ func TestNewManager(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr.String(), idAlice)
connBobToAlice, err := clientBob.OpenConn(aliceSrvAddr.String(), idAlice)
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@ -99,3 +98,59 @@ func TestNewManager(t *testing.T) {
t.Fatalf("expected %s, got %s", payload, string(buf[:n])) t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
} }
} }
func TestForeginConnClose(t *testing.T) {
ctx := context.Background()
addr1 := "localhost:1234"
srv1 := server.NewServer()
go func() {
err := srv1.Listen(addr1)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
err := srv1.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
addr2 := "localhost:2234"
srv2 := server.NewServer()
go func() {
err := srv2.Listen(addr2)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
err := srv2.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
idAlice := "alice"
log.Debugf("connect by alice")
clientAlice := NewManager(ctx, addr1, idAlice)
err := clientAlice.Serve()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
conn, err := clientAlice.OpenConn(addr2, "anotherpeer")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
err = conn.Close()
if err != nil {
t.Fatalf("failed to close connection: %s", err)
}
select {}
}