From 15a7b7629b447b915b0b6982b8589d5ca65a92f9 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 2 Jul 2024 11:57:17 +0200 Subject: [PATCH] Add exposed address --- client/internal/peer/conn.go | 4 +-- client/internal/peer/worker_relay.go | 8 +++--- encryption/letsencrypt.go | 4 +-- go.mod | 4 +-- go.sum | 4 +-- relay/client/client.go | 35 +++++++++++++----------- relay/client/client_test.go | 36 ++++++++++++------------- relay/client/dialer/ws/conn.go | 14 ++++------ relay/client/dialer/ws/ws.go | 17 +----------- relay/client/manager.go | 30 ++++++++++----------- relay/client/manager_test.go | 40 ++++++++++++++-------------- relay/cmd/main.go | 22 ++++++++------- relay/messages/message.go | 34 ++++++++++++++++++++--- relay/server/relay.go | 22 ++++++++++----- relay/server/server.go | 8 +++--- 15 files changed, 154 insertions(+), 128 deletions(-) diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index e8dcb069b..ac5c46f83 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -606,9 +606,9 @@ func (conn *Conn) doHandshake() error { err error ) ha.IceUFrag, ha.IcePwd = conn.workerICE.GetLocalUserCredentials() - addr, err := conn.workerRelay.RelayAddress() + addr, err := conn.workerRelay.RelayInstanceAddress() if err == nil { - ha.RelayAddr = addr.String() + ha.RelayAddr = addr } conn.log.Tracef("send new offer: %#v", ha) return conn.handshaker.SendOffer(ha) diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go index 5ad4a1fb5..beea912aa 100644 --- a/client/internal/peer/worker_relay.go +++ b/client/internal/peer/worker_relay.go @@ -46,13 +46,13 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { } // the relayManager will return with error in case if the connection has lost with relay server - currentRelayAddress, err := w.relayManager.RelayAddress() + currentRelayAddress, err := w.relayManager.RelayInstanceAddress() if err != nil { w.log.Infof("local Relay connection is lost, skipping connection attempt") return } - srv := w.preferredRelayServer(currentRelayAddress.String(), remoteOfferAnswer.RelaySrvAddress) + srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress) relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key, w.conn.OnDisconnected) if err != nil { @@ -73,8 +73,8 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { }) } -func (w *WorkerRelay) RelayAddress() (net.Addr, error) { - return w.relayManager.RelayAddress() +func (w *WorkerRelay) RelayInstanceAddress() (string, error) { + return w.relayManager.RelayInstanceAddress() } func (w *WorkerRelay) IsController() bool { diff --git a/encryption/letsencrypt.go b/encryption/letsencrypt.go index cfe54ec5a..27a5e3110 100644 --- a/encryption/letsencrypt.go +++ b/encryption/letsencrypt.go @@ -9,7 +9,7 @@ import ( ) // CreateCertManager wraps common logic of generating Let's encrypt certificate. -func CreateCertManager(datadir string, letsencryptDomain string) (*autocert.Manager, error) { +func CreateCertManager(datadir string, letsencryptDomain ...string) (*autocert.Manager, error) { certDir := filepath.Join(datadir, "letsencrypt") if _, err := os.Stat(certDir); os.IsNotExist(err) { @@ -24,7 +24,7 @@ func CreateCertManager(datadir string, letsencryptDomain string) (*autocert.Mana certManager := &autocert.Manager{ Prompt: autocert.AcceptTOS, Cache: autocert.DirCache(certDir), - HostPolicy: autocert.HostWhitelist(letsencryptDomain), + HostPolicy: autocert.HostWhitelist(letsencryptDomain...), } return certManager, nil diff --git a/go.mod b/go.mod index 67655fc7b..9955fe703 100644 --- a/go.mod +++ b/go.mod @@ -47,7 +47,6 @@ require ( github.com/google/martian/v3 v3.0.0 github.com/google/nftables v0.0.0-20220808154552-2eca00135732 github.com/gopacket/gopacket v1.1.1 - github.com/gorilla/websocket v1.5.1 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 @@ -64,6 +63,7 @@ require ( github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pion/logging v0.2.2 + github.com/pion/randutil v0.1.0 github.com/pion/stun/v2 v2.0.0 github.com/pion/transport/v3 v3.0.1 github.com/pion/turn/v3 v3.0.1 @@ -77,6 +77,7 @@ require ( github.com/things-go/go-socks5 v0.0.4 github.com/yusufpapurcu/wmi v1.2.4 github.com/zcalusic/sysinfo v1.0.2 + go.mongodb.org/mongo-driver v1.16.0 go.opentelemetry.io/otel v1.26.0 go.opentelemetry.io/otel/exporters/prometheus v0.48.0 go.opentelemetry.io/otel/metric v1.26.0 @@ -173,7 +174,6 @@ require ( github.com/pegasus-kv/thrift v0.13.0 // indirect github.com/pion/dtls/v2 v2.2.10 // indirect github.com/pion/mdns v0.0.12 // indirect - github.com/pion/randutil v0.1.0 // indirect github.com/pion/transport/v2 v2.2.4 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index 590915767..db50be34a 100644 --- a/go.sum +++ b/go.sum @@ -237,8 +237,6 @@ github.com/gopherjs/gopherjs v0.0.0-20220410123724-9e86199038b0 h1:fWY+zXdWhvWnd github.com/gopherjs/gopherjs v0.0.0-20220410123724-9e86199038b0/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= -github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= -github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 h1:Fkzd8ktnpOR9h47SXHe2AYPwelXLH2GjGsjlAloiWfo= github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357/go.mod h1:w9Y7gY31krpLmrVU5ZPG9H7l9fZuRu5/3R3S3FMtVQ4= github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 h1:YBftPWNWd4WwGqtY2yeZL2ef8rHAxPBD8KFhJpmcqms= @@ -513,6 +511,8 @@ github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= github.com/zcalusic/sysinfo v1.0.2 h1:nwTTo2a+WQ0NXwo0BGRojOJvJ/5XKvQih+2RrtWqfxc= github.com/zcalusic/sysinfo v1.0.2/go.mod h1:kluzTYflRWo6/tXVMJPdEjShsbPpsFRyy+p1mBQPC30= +go.mongodb.org/mongo-driver v1.16.0 h1:tpRsfBJMROVHKpdGyc1BBEzzjDUWjItxbVSZ8Ls4BQ4= +go.mongodb.org/mongo-driver v1.16.0/go.mod h1:oB6AhJQvFQL4LEHyXi6aJzQJtBiTQHiAd83l0GdFaiw= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI= diff --git a/relay/client/client.go b/relay/client/client.go index 2401f972e..9fc5b84b8 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -97,7 +97,7 @@ func (cc *connContainer) close() { type Client struct { log *log.Entry parentCtx context.Context - serverAddress string + connectionURL string hashedID []byte bufPool *sync.Pool @@ -108,20 +108,19 @@ type Client struct { mu sync.Mutex // protect serviceIsRunning and conns readLoopMutex sync.Mutex wgReadLoop sync.WaitGroup - - remoteAddr net.Addr + instanceURL string onDisconnectListener func() listenerMutex sync.Mutex } // NewClient creates a new client for the relay server. The client is not connected to the server until the Connect -func NewClient(ctx context.Context, serverAddress, peerID string) *Client { +func NewClient(ctx context.Context, serverURL, peerID string) *Client { hashedID, hashedStringId := messages.HashID(peerID) return &Client{ log: log.WithField("client_id", hashedStringId), parentCtx: ctx, - serverAddress: serverAddress, + connectionURL: serverURL, hashedID: hashedID, bufPool: &sync.Pool{ New: func() any { @@ -135,7 +134,7 @@ func NewClient(ctx context.Context, serverAddress, peerID string) *Client { // Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs. func (c *Client) Connect() error { - c.log.Infof("connecting to relay server: %s", c.serverAddress) + c.log.Infof("connecting to relay server: %s", c.connectionURL) c.readLoopMutex.Lock() defer c.readLoopMutex.Unlock() @@ -156,7 +155,7 @@ func (c *Client) Connect() error { c.wgReadLoop.Add(1) go c.readLoop(c.relayConn) - log.Infof("relay connection established with: %s", c.serverAddress) + log.Infof("relay connection established with: %s", c.connectionURL) return nil } @@ -186,14 +185,14 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) { return conn, nil } -// RelayRemoteAddress returns the IP address of the relay server. It could change after the close and reopen the connection. -func (c *Client) RelayRemoteAddress() (net.Addr, error) { +// ServerInstanceURL returns the address of the relay server. It could change after the close and reopen the connection. +func (c *Client) ServerInstanceURL() (string, error) { c.mu.Lock() defer c.mu.Unlock() - if c.remoteAddr == nil { - return nil, fmt.Errorf("relay connection is not established") + if c.instanceURL == "" { + return "", fmt.Errorf("relay connection is not established") } - return c.remoteAddr, nil + return c.instanceURL, nil } // SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed. @@ -215,7 +214,7 @@ func (c *Client) Close() error { } func (c *Client) connect() error { - conn, err := ws.Dial(c.serverAddress) + conn, err := ws.Dial(c.connectionURL) if err != nil { return err } @@ -231,8 +230,6 @@ func (c *Client) connect() error { return err } - c.remoteAddr = conn.RemoteAddr() - return nil } @@ -264,6 +261,12 @@ func (c *Client) handShake() error { log.Errorf("unexpected message type: %s", msgType) return fmt.Errorf("unexpected message type") } + + domain, err := messages.UnmarshalHelloResponse(buf[:n]) + if err != nil { + return err + } + c.instanceURL = domain return nil } @@ -435,7 +438,7 @@ func (c *Client) close(gracefullyExit bool) error { c.mu.Unlock() c.wgReadLoop.Wait() - c.log.Infof("relay connection closed with: %s", c.serverAddress) + c.log.Infof("relay connection closed with: %s", c.connectionURL) return err } diff --git a/relay/client/client_test.go b/relay/client/client_test.go index 278d46d08..7b1ee5c62 100644 --- a/relay/client/client_test.go +++ b/relay/client/client_test.go @@ -23,8 +23,8 @@ func TestMain(m *testing.M) { func TestClient(t *testing.T) { ctx := context.Background() - srvCfg := server.Config{Address: "localhost:1234"} - srv := server.NewServer() + srvCfg := server.ListenerConfig{Address: "localhost:1234"} + srv := server.NewServer(srvCfg.Address, false) go func() { err := srv.Listen(srvCfg) if err != nil { @@ -91,8 +91,8 @@ func TestClient(t *testing.T) { func TestRegistration(t *testing.T) { ctx := context.Background() - srvCfg := server.Config{Address: "localhost:1234"} - srv := server.NewServer() + srvCfg := server.ListenerConfig{Address: "localhost:1234"} + srv := server.NewServer(srvCfg.Address, false) go func() { err := srv.Listen(srvCfg) if err != nil { @@ -156,8 +156,8 @@ func TestEcho(t *testing.T) { ctx := context.Background() idAlice := "alice" idBob := "bob" - srvCfg := server.Config{Address: "localhost:1234"} - srv := server.NewServer() + srvCfg := server.ListenerConfig{Address: "localhost:1234"} + srv := server.NewServer(srvCfg.Address, false) go func() { err := srv.Listen(srvCfg) if err != nil { @@ -236,8 +236,8 @@ func TestEcho(t *testing.T) { func TestBindToUnavailabePeer(t *testing.T) { ctx := context.Background() - srvCfg := server.Config{Address: "localhost:1234"} - srv := server.NewServer() + srvCfg := server.ListenerConfig{Address: "localhost:1234"} + srv := server.NewServer(srvCfg.Address, false) go func() { err := srv.Listen(srvCfg) if err != nil { @@ -273,8 +273,8 @@ func TestBindToUnavailabePeer(t *testing.T) { func TestBindReconnect(t *testing.T) { ctx := context.Background() - srvCfg := server.Config{Address: "localhost:1234"} - srv := server.NewServer() + srvCfg := server.ListenerConfig{Address: "localhost:1234"} + srv := server.NewServer(srvCfg.Address, false) go func() { err := srv.Listen(srvCfg) if err != nil { @@ -355,8 +355,8 @@ func TestBindReconnect(t *testing.T) { func TestCloseConn(t *testing.T) { ctx := context.Background() - srvCfg := server.Config{Address: "localhost:1234"} - srv := server.NewServer() + srvCfg := server.ListenerConfig{Address: "localhost:1234"} + srv := server.NewServer(srvCfg.Address, false) go func() { err := srv.Listen(srvCfg) if err != nil { @@ -403,8 +403,8 @@ func TestCloseConn(t *testing.T) { func TestCloseRelayConn(t *testing.T) { ctx := context.Background() - srvCfg := server.Config{Address: "localhost:1234"} - srv := server.NewServer() + srvCfg := server.ListenerConfig{Address: "localhost:1234"} + srv := server.NewServer(srvCfg.Address, false) go func() { err := srv.Listen(srvCfg) if err != nil { @@ -446,8 +446,8 @@ func TestCloseRelayConn(t *testing.T) { func TestCloseByServer(t *testing.T) { ctx := context.Background() - srvCfg := server.Config{Address: "localhost:1234"} - srv1 := server.NewServer() + srvCfg := server.ListenerConfig{Address: "localhost:1234"} + srv1 := server.NewServer(srvCfg.Address, false) go func() { err := srv1.Listen(srvCfg) if err != nil { @@ -489,8 +489,8 @@ func TestCloseByServer(t *testing.T) { func TestCloseByClient(t *testing.T) { ctx := context.Background() - srvCfg := server.Config{Address: "localhost:1234"} - srv := server.NewServer() + srvCfg := server.ListenerConfig{Address: "localhost:1234"} + srv := server.NewServer(srvCfg.Address, false) go func() { err := srv.Listen(srvCfg) if err != nil { diff --git a/relay/client/dialer/ws/conn.go b/relay/client/dialer/ws/conn.go index 44d86c1bb..1632fcb45 100644 --- a/relay/client/dialer/ws/conn.go +++ b/relay/client/dialer/ws/conn.go @@ -12,16 +12,12 @@ import ( type Conn struct { ctx context.Context *websocket.Conn - srvAddr net.Addr - localAddr net.Addr } -func NewConn(wsConn *websocket.Conn, srvAddr, localAddr net.Addr) net.Conn { +func NewConn(wsConn *websocket.Conn) net.Conn { return &Conn{ - ctx: context.Background(), - Conn: wsConn, - srvAddr: srvAddr, - localAddr: localAddr, + ctx: context.Background(), + Conn: wsConn, } } @@ -44,11 +40,11 @@ func (c *Conn) Write(b []byte) (n int, err error) { } func (c *Conn) RemoteAddr() net.Addr { - return c.srvAddr + panic("not implemented") } func (c *Conn) LocalAddr() net.Addr { - return c.localAddr + panic("not implemented") } func (c *Conn) SetReadDeadline(t time.Time) error { diff --git a/relay/client/dialer/ws/ws.go b/relay/client/dialer/ws/ws.go index 070a02362..3a2bf48c8 100644 --- a/relay/client/dialer/ws/ws.go +++ b/relay/client/dialer/ws/ws.go @@ -29,22 +29,7 @@ func Dial(address string) (net.Conn, error) { return nil, err } - /* - response.Body.(net.Conn).LocalAddr() - unc, ok := response.Body.(net.Conn) - if !ok { - log.Errorf("failed to get local address: %s", err) - return nil, fmt.Errorf("failed to get local address") - } - - */ - // todo figure out the proper address - dummy := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 8080, - } - - conn := NewConn(wsConn, dummy, dummy) + conn := NewConn(wsConn) return conn, nil } diff --git a/relay/client/manager.go b/relay/client/manager.go index ae44be4b9..109d3b139 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -35,9 +35,9 @@ func NewRelayTrack() *RelayTrack { // relay servers will be closed if there is no active connection. Periodically the manager will check if there is any // unused relay connection and close it. type Manager struct { - ctx context.Context - srvAddress string - peerID string + ctx context.Context + serverURL string + peerID string relayClient *Client reconnectGuard *Guard @@ -49,10 +49,10 @@ type Manager struct { listenerLock sync.Mutex } -func NewManager(ctx context.Context, serverAddress string, peerID string) *Manager { +func NewManager(ctx context.Context, serverURL string, peerID string) *Manager { return &Manager{ ctx: ctx, - srvAddress: serverAddress, + serverURL: serverURL, peerID: peerID, relayClients: make(map[string]*RelayTrack), onDisconnectedListeners: make(map[string]map[*func()]struct{}), @@ -65,7 +65,7 @@ func (m *Manager) Serve() error { return fmt.Errorf("manager already serving") } - m.relayClient = NewClient(m.ctx, m.srvAddress, m.peerID) + m.relayClient = NewClient(m.ctx, m.serverURL, m.peerID) err := m.relayClient.Connect() if err != nil { log.Errorf("failed to connect to relay server: %s", err) @@ -74,7 +74,7 @@ func (m *Manager) Serve() error { m.reconnectGuard = NewGuard(m.ctx, m.relayClient) m.relayClient.SetOnDisconnectListener(func() { - m.onServerDisconnected(m.srvAddress) + m.onServerDisconnected(m.serverURL) }) m.startCleanupLoop() @@ -116,17 +116,17 @@ func (m *Manager) OpenConn(serverAddress, peerKey string, onClosedListener func( return netConn, err } -// RelayAddress returns the address of the permanent relay server. It could change if the network connection is lost. +// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is lost. // This address will be sent to the target peer to choose the common relay server for the communication. -func (m *Manager) RelayAddress() (net.Addr, error) { +func (m *Manager) RelayInstanceAddress() (string, error) { if m.relayClient == nil { - return nil, errRelayClientNotConnected + return "", errRelayClientNotConnected } - return m.relayClient.RelayRemoteAddress() + return m.relayClient.ServerInstanceURL() } func (m *Manager) HasRelayAddress() bool { - return m.srvAddress != "" + return m.serverURL != "" } func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { @@ -182,7 +182,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { } func (m *Manager) onServerDisconnected(serverAddress string) { - if serverAddress == m.srvAddress { + if serverAddress == m.serverURL { m.reconnectGuard.OnDisconnected() } @@ -190,11 +190,11 @@ func (m *Manager) onServerDisconnected(serverAddress string) { } func (m *Manager) isForeignServer(address string) (bool, error) { - rAddr, err := m.relayClient.RelayRemoteAddress() + rAddr, err := m.relayClient.ServerInstanceURL() if err != nil { return false, fmt.Errorf("relay client not connected") } - return rAddr.String() != address, nil + return rAddr != address, nil } func (m *Manager) startCleanupLoop() { diff --git a/relay/client/manager_test.go b/relay/client/manager_test.go index 71e4a416f..3192c1a09 100644 --- a/relay/client/manager_test.go +++ b/relay/client/manager_test.go @@ -13,10 +13,10 @@ import ( func TestForeignConn(t *testing.T) { ctx := context.Background() - srvCfg1 := server.Config{ + srvCfg1 := server.ListenerConfig{ Address: "localhost:1234", } - srv1 := server.NewServer() + srv1 := server.NewServer(srvCfg1.Address, false) go func() { err := srv1.Listen(srvCfg1) if err != nil { @@ -31,10 +31,10 @@ func TestForeignConn(t *testing.T) { } }() - srvCfg2 := server.Config{ + srvCfg2 := server.ListenerConfig{ Address: "localhost:2234", } - srv2 := server.NewServer() + srv2 := server.NewServer(srvCfg2.Address, false) go func() { err := srv2.Listen(srvCfg2) if err != nil { @@ -61,15 +61,15 @@ func TestForeignConn(t *testing.T) { clientBob := NewManager(mCtx, srvCfg2.Address, idBob) clientBob.Serve() - bobsSrvAddr, err := clientBob.RelayAddress() + bobsSrvAddr, err := clientBob.RelayInstanceAddress() if err != nil { t.Fatalf("failed to get relay address: %s", err) } - connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr.String(), idBob, nil) + connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob, nil) if err != nil { t.Fatalf("failed to bind channel: %s", err) } - connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr.String(), idAlice, nil) + connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice, nil) if err != nil { t.Fatalf("failed to bind channel: %s", err) } @@ -104,10 +104,10 @@ func TestForeignConn(t *testing.T) { func TestForeginConnClose(t *testing.T) { ctx := context.Background() - srvCfg1 := server.Config{ + srvCfg1 := server.ListenerConfig{ Address: "localhost:1234", } - srv1 := server.NewServer() + srv1 := server.NewServer(srvCfg1.Address, false) go func() { err := srv1.Listen(srvCfg1) if err != nil { @@ -122,10 +122,10 @@ func TestForeginConnClose(t *testing.T) { } }() - srvCfg2 := server.Config{ + srvCfg2 := server.ListenerConfig{ Address: "localhost:2234", } - srv2 := server.NewServer() + srv2 := server.NewServer(srvCfg2.Address, false) go func() { err := srv2.Listen(srvCfg2) if err != nil { @@ -161,10 +161,10 @@ func TestForeginConnClose(t *testing.T) { func TestForeginAutoClose(t *testing.T) { ctx := context.Background() relayCleanupInterval = 1 * time.Second - srvCfg1 := server.Config{ + srvCfg1 := server.ListenerConfig{ Address: "localhost:1234", } - srv1 := server.NewServer() + srv1 := server.NewServer(srvCfg1.Address, false) go func() { t.Log("binding server 1.") err := srv1.Listen(srvCfg1) @@ -182,10 +182,10 @@ func TestForeginAutoClose(t *testing.T) { t.Logf("server 1. closed") }() - srvCfg2 := server.Config{ + srvCfg2 := server.ListenerConfig{ Address: "localhost:2234", } - srv2 := server.NewServer() + srv2 := server.NewServer(srvCfg2.Address, false) go func() { t.Log("binding server 2.") err := srv2.Listen(srvCfg2) @@ -234,10 +234,10 @@ func TestAutoReconnect(t *testing.T) { ctx := context.Background() reconnectingTimeout = 2 * time.Second - srvCfg := server.Config{ + srvCfg := server.ListenerConfig{ Address: "localhost:1234", } - srv := server.NewServer() + srv := server.NewServer(srvCfg.Address, false) go func() { err := srv.Listen(srvCfg) if err != nil { @@ -256,11 +256,11 @@ func TestAutoReconnect(t *testing.T) { defer cancel() clientAlice := NewManager(mCtx, srvCfg.Address, "alice") clientAlice.Serve() - ra, err := clientAlice.RelayAddress() + ra, err := clientAlice.RelayInstanceAddress() if err != nil { t.Errorf("failed to get relay address: %s", err) } - conn, err := clientAlice.OpenConn(ra.String(), "bob", nil) + conn, err := clientAlice.OpenConn(ra, "bob", nil) if err != nil { t.Errorf("failed to bind channel: %s", err) } @@ -278,7 +278,7 @@ func TestAutoReconnect(t *testing.T) { time.Sleep(reconnectingTimeout + 1*time.Second) log.Infof("reopent the connection") - _, err = clientAlice.OpenConn(ra.String(), "bob", nil) + _, err = clientAlice.OpenConn(ra, "bob", nil) if err != nil { t.Errorf("failed to open channel: %s", err) } diff --git a/relay/cmd/main.go b/relay/cmd/main.go index 9d1802076..fb437e78d 100644 --- a/relay/cmd/main.go +++ b/relay/cmd/main.go @@ -16,9 +16,11 @@ import ( ) var ( - listenAddress string + listenAddress string + // in HA every peer connect to a common domain, the instance domain has been distributed during the p2p connection + exposedAddress string letsencryptDataDir string - letsencryptDomain string + letsencryptDomains []string rootCmd = &cobra.Command{ Use: "relay", @@ -31,8 +33,9 @@ var ( func init() { _ = util.InitLog("trace", "console") rootCmd.PersistentFlags().StringVarP(&listenAddress, "listen-address", "l", ":1235", "listen address") + rootCmd.PersistentFlags().StringVarP(&exposedAddress, "exposed-address", "e", "", "instance domain address (or ip) and port, it will be distributes between peers") rootCmd.PersistentFlags().StringVarP(&letsencryptDataDir, "letsencrypt-data-dir", "d", "", "a directory to store Let's Encrypt data. Required if Let's Encrypt is enabled.") - rootCmd.PersistentFlags().StringVarP(&letsencryptDomain, "letsencrypt-domain", "a", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") + rootCmd.PersistentFlags().StringArrayVarP(&letsencryptDomains, "letsencrypt-domains", "a", nil, "list of domains to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") } func waitForExitSignal() { @@ -42,7 +45,7 @@ func waitForExitSignal() { } func execute(cmd *cobra.Command, args []string) { - srvCfg := server.Config{ + srvListenerCfg := server.ListenerConfig{ Address: listenAddress, } if hasLetsEncrypt() { @@ -51,11 +54,12 @@ func execute(cmd *cobra.Command, args []string) { log.Errorf("%s", err) os.Exit(1) } - srvCfg.TLSConfig = tlscfg + srvListenerCfg.TLSConfig = tlscfg } - srv := server.NewServer() - err := srv.Listen(srvCfg) + tlsSupport := srvListenerCfg.TLSConfig != nil + srv := server.NewServer(exposedAddress, tlsSupport) + err := srv.Listen(srvListenerCfg) if err != nil { log.Errorf("failed to bind server: %s", err) os.Exit(1) @@ -71,11 +75,11 @@ func execute(cmd *cobra.Command, args []string) { } func hasLetsEncrypt() bool { - return letsencryptDataDir != "" && letsencryptDomain != "" + return letsencryptDataDir != "" && letsencryptDomains != nil && len(letsencryptDomains) > 0 } func setupTLS() (*tls.Config, error) { - certManager, err := encryption.CreateCertManager(letsencryptDataDir, letsencryptDomain) + certManager, err := encryption.CreateCertManager(letsencryptDataDir, letsencryptDomains...) if err != nil { return nil, fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err) } diff --git a/relay/messages/message.go b/relay/messages/message.go index aa62fe867..5177b691d 100644 --- a/relay/messages/message.go +++ b/relay/messages/message.go @@ -4,6 +4,8 @@ import ( "bytes" "fmt" + "go.mongodb.org/mongo-driver/bson" + log "github.com/sirupsen/logrus" ) @@ -45,6 +47,10 @@ func (m MsgType) String() string { } } +type HelloResponse struct { + DomainAddress string +} + func DetermineClientMsgType(msg []byte) (MsgType, error) { msgType := MsgType(msg[0]) switch msgType { @@ -97,10 +103,32 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, error) { return msg[5:], nil } -func MarshalHelloResponse() []byte { - msg := make([]byte, 1) +func MarshalHelloResponse(DomainAddress string) ([]byte, error) { + payload := HelloResponse{ + DomainAddress: DomainAddress, + } + helloResponse, err := bson.Marshal(payload) + if err != nil { + log.Errorf("failed to marshal hello response: %s", err) + return nil, err + } + msg := make([]byte, 1, 1+len(helloResponse)) msg[0] = byte(MsgTypeHelloResponse) - return msg + msg = append(msg, helloResponse...) + return msg, nil +} + +func UnmarshalHelloResponse(msg []byte) (string, error) { + if len(msg) < 2 { + return "", fmt.Errorf("invalid 'hello response' message") + } + payload := HelloResponse{} + err := bson.Unmarshal(msg[1:], &payload) + if err != nil { + log.Errorf("failed to unmarshal hello response: %s", err) + return "", err + } + return payload.DomainAddress, nil } // Close message diff --git a/relay/server/relay.go b/relay/server/relay.go index 39719f4a9..9f0528b25 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -12,16 +12,24 @@ import ( ) type Relay struct { - store *Store + store *Store + instaceURL string // domain:port closed bool closeMu sync.RWMutex } -func NewRelay() *Relay { - return &Relay{ +func NewRelay(exposedAddress string, tlsSupport bool) *Relay { + r := &Relay{ store: NewStore(), } + + if tlsSupport { + r.instaceURL = fmt.Sprintf("rels://%s", exposedAddress) + } else { + r.instaceURL = fmt.Sprintf("rel://%s", exposedAddress) + } + return r } func (r *Relay) Accept(conn net.Conn) { @@ -31,7 +39,7 @@ func (r *Relay) Accept(conn net.Conn) { return } - peerID, err := handShake(conn) + peerID, err := r.handShake(conn) if err != nil { log.Errorf("failed to handshake with %s: %s", conn.RemoteAddr(), err) cErr := conn.Close() @@ -68,7 +76,7 @@ func (r *Relay) Close(ctx context.Context) { r.closeMu.Unlock() } -func handShake(conn net.Conn) ([]byte, error) { +func (r *Relay) handShake(conn net.Conn) ([]byte, error) { buf := make([]byte, messages.MaxHandshakeSize) n, err := conn.Read(buf) if err != nil { @@ -79,18 +87,20 @@ func handShake(conn net.Conn) ([]byte, error) { if err != nil { return nil, err } + if msgType != messages.MsgTypeHello { tErr := fmt.Errorf("invalid message type") log.Errorf("failed to handshake: %s", tErr) return nil, tErr } + peerID, err := messages.UnmarshalHelloMsg(buf[:n]) if err != nil { log.Errorf("failed to handshake: %s", err) return nil, err } - msg := messages.MarshalHelloResponse() + msg, _ := messages.MarshalHelloResponse(r.instaceURL) _, err = conn.Write(msg) if err != nil { return nil, err diff --git a/relay/server/server.go b/relay/server/server.go index 687486138..449819e61 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -14,7 +14,7 @@ import ( "github.com/netbirdio/netbird/relay/server/listener/ws" ) -type Config struct { +type ListenerConfig struct { Address string TLSConfig *tls.Config } @@ -25,13 +25,13 @@ type Server struct { wSListener listener.Listener } -func NewServer() *Server { +func NewServer(exposedAddress string, tlsSupport bool) *Server { return &Server{ - relay: NewRelay(), + relay: NewRelay(exposedAddress, tlsSupport), } } -func (r *Server) Listen(cfg Config) error { +func (r *Server) Listen(cfg ListenerConfig) error { wg := sync.WaitGroup{} wg.Add(2)