mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-21 18:22:37 +02:00
Fix ssl configuration
This commit is contained in:
parent
ed82ef7fe4
commit
d3785dc1fa
@ -23,10 +23,10 @@ func TestMain(m *testing.M) {
|
|||||||
func TestClient(t *testing.T) {
|
func TestClient(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
addr := "localhost:1234"
|
srvCfg := server.Config{Address: "localhost:1234"}
|
||||||
srv := server.NewServer()
|
srv := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
err := srv.Listen(addr)
|
err := srv.Listen(srvCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind server: %s", err)
|
t.Fatalf("failed to bind server: %s", err)
|
||||||
}
|
}
|
||||||
@ -39,21 +39,21 @@ func TestClient(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, addr, "alice")
|
clientAlice := NewClient(ctx, srvCfg.Address, "alice")
|
||||||
err := clientAlice.Connect()
|
err := clientAlice.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
defer clientAlice.Close()
|
defer clientAlice.Close()
|
||||||
|
|
||||||
clientPlaceHolder := NewClient(ctx, addr, "clientPlaceHolder")
|
clientPlaceHolder := NewClient(ctx, srvCfg.Address, "clientPlaceHolder")
|
||||||
err = clientPlaceHolder.Connect()
|
err = clientPlaceHolder.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
defer clientPlaceHolder.Close()
|
defer clientPlaceHolder.Close()
|
||||||
|
|
||||||
clientBob := NewClient(ctx, addr, "bob")
|
clientBob := NewClient(ctx, srvCfg.Address, "bob")
|
||||||
err = clientBob.Connect()
|
err = clientBob.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
@ -91,16 +91,16 @@ func TestClient(t *testing.T) {
|
|||||||
|
|
||||||
func TestRegistration(t *testing.T) {
|
func TestRegistration(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
addr := "localhost:1234"
|
srvCfg := server.Config{Address: "localhost:1234"}
|
||||||
srv := server.NewServer()
|
srv := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
err := srv.Listen(addr)
|
err := srv.Listen(srvCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind server: %s", err)
|
t.Fatalf("failed to bind server: %s", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, addr, "alice")
|
clientAlice := NewClient(ctx, srvCfg.Address, "alice")
|
||||||
err := clientAlice.Connect()
|
err := clientAlice.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = srv.Close()
|
_ = srv.Close()
|
||||||
@ -156,10 +156,10 @@ func TestEcho(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
idAlice := "alice"
|
idAlice := "alice"
|
||||||
idBob := "bob"
|
idBob := "bob"
|
||||||
addr := "localhost:1234"
|
srvCfg := server.Config{Address: "localhost:1234"}
|
||||||
srv := server.NewServer()
|
srv := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
err := srv.Listen(addr)
|
err := srv.Listen(srvCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind server: %s", err)
|
t.Fatalf("failed to bind server: %s", err)
|
||||||
}
|
}
|
||||||
@ -172,7 +172,7 @@ func TestEcho(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, addr, idAlice)
|
clientAlice := NewClient(ctx, srvCfg.Address, idAlice)
|
||||||
err := clientAlice.Connect()
|
err := clientAlice.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
@ -184,7 +184,7 @@ func TestEcho(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
clientBob := NewClient(ctx, addr, idBob)
|
clientBob := NewClient(ctx, srvCfg.Address, idBob)
|
||||||
err = clientBob.Connect()
|
err = clientBob.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
@ -236,10 +236,10 @@ func TestEcho(t *testing.T) {
|
|||||||
func TestBindToUnavailabePeer(t *testing.T) {
|
func TestBindToUnavailabePeer(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
addr := "localhost:1234"
|
srvCfg := server.Config{Address: "localhost:1234"}
|
||||||
srv := server.NewServer()
|
srv := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
err := srv.Listen(addr)
|
err := srv.Listen(srvCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind server: %s", err)
|
t.Fatalf("failed to bind server: %s", err)
|
||||||
}
|
}
|
||||||
@ -253,7 +253,7 @@ func TestBindToUnavailabePeer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, addr, "alice")
|
clientAlice := NewClient(ctx, srvCfg.Address, "alice")
|
||||||
err := clientAlice.Connect()
|
err := clientAlice.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to connect to server: %s", err)
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
@ -273,10 +273,10 @@ func TestBindToUnavailabePeer(t *testing.T) {
|
|||||||
func TestBindReconnect(t *testing.T) {
|
func TestBindReconnect(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
addr := "localhost:1234"
|
srvCfg := server.Config{Address: "localhost:1234"}
|
||||||
srv := server.NewServer()
|
srv := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
err := srv.Listen(addr)
|
err := srv.Listen(srvCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to bind server: %s", err)
|
t.Errorf("failed to bind server: %s", err)
|
||||||
}
|
}
|
||||||
@ -290,7 +290,7 @@ func TestBindReconnect(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, addr, "alice")
|
clientAlice := NewClient(ctx, srvCfg.Address, "alice")
|
||||||
err := clientAlice.Connect()
|
err := clientAlice.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to connect to server: %s", err)
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
@ -301,7 +301,7 @@ func TestBindReconnect(t *testing.T) {
|
|||||||
t.Errorf("failed to bind channel: %s", err)
|
t.Errorf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clientBob := NewClient(ctx, addr, "bob")
|
clientBob := NewClient(ctx, srvCfg.Address, "bob")
|
||||||
err = clientBob.Connect()
|
err = clientBob.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to connect to server: %s", err)
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
@ -318,7 +318,7 @@ func TestBindReconnect(t *testing.T) {
|
|||||||
t.Errorf("failed to close client: %s", err)
|
t.Errorf("failed to close client: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clientAlice = NewClient(ctx, addr, "alice")
|
clientAlice = NewClient(ctx, srvCfg.Address, "alice")
|
||||||
err = clientAlice.Connect()
|
err = clientAlice.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to connect to server: %s", err)
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
@ -355,10 +355,10 @@ func TestBindReconnect(t *testing.T) {
|
|||||||
func TestCloseConn(t *testing.T) {
|
func TestCloseConn(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
addr := "localhost:1234"
|
srvCfg := server.Config{Address: "localhost:1234"}
|
||||||
srv := server.NewServer()
|
srv := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
err := srv.Listen(addr)
|
err := srv.Listen(srvCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to bind server: %s", err)
|
t.Errorf("failed to bind server: %s", err)
|
||||||
}
|
}
|
||||||
@ -372,7 +372,7 @@ func TestCloseConn(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, addr, "alice")
|
clientAlice := NewClient(ctx, srvCfg.Address, "alice")
|
||||||
err := clientAlice.Connect()
|
err := clientAlice.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to connect to server: %s", err)
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
@ -403,10 +403,10 @@ func TestCloseConn(t *testing.T) {
|
|||||||
func TestCloseRelayConn(t *testing.T) {
|
func TestCloseRelayConn(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
addr := "localhost:1234"
|
srvCfg := server.Config{Address: "localhost:1234"}
|
||||||
srv := server.NewServer()
|
srv := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
err := srv.Listen(addr)
|
err := srv.Listen(srvCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to bind server: %s", err)
|
t.Errorf("failed to bind server: %s", err)
|
||||||
}
|
}
|
||||||
@ -419,7 +419,7 @@ func TestCloseRelayConn(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, addr, "alice")
|
clientAlice := NewClient(ctx, srvCfg.Address, "alice")
|
||||||
err := clientAlice.Connect()
|
err := clientAlice.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
@ -446,10 +446,10 @@ func TestCloseRelayConn(t *testing.T) {
|
|||||||
func TestCloseByServer(t *testing.T) {
|
func TestCloseByServer(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
addr1 := "localhost:1234"
|
srvCfg := server.Config{Address: "localhost:1234"}
|
||||||
srv1 := server.NewServer()
|
srv1 := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
err := srv1.Listen(addr1)
|
err := srv1.Listen(srvCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind server: %s", err)
|
t.Fatalf("failed to bind server: %s", err)
|
||||||
}
|
}
|
||||||
@ -457,7 +457,7 @@ func TestCloseByServer(t *testing.T) {
|
|||||||
|
|
||||||
idAlice := "alice"
|
idAlice := "alice"
|
||||||
log.Debugf("connect by alice")
|
log.Debugf("connect by alice")
|
||||||
relayClient := NewClient(ctx, addr1, idAlice)
|
relayClient := NewClient(ctx, srvCfg.Address, idAlice)
|
||||||
err := relayClient.Connect()
|
err := relayClient.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to connect to server: %s", err)
|
log.Fatalf("failed to connect to server: %s", err)
|
||||||
@ -489,10 +489,10 @@ func TestCloseByServer(t *testing.T) {
|
|||||||
func TestCloseByClient(t *testing.T) {
|
func TestCloseByClient(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
addr1 := "localhost:1234"
|
srvCfg := server.Config{Address: "localhost:1234"}
|
||||||
srv := server.NewServer()
|
srv := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
err := srv.Listen(addr1)
|
err := srv.Listen(srvCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind server: %s", err)
|
t.Fatalf("failed to bind server: %s", err)
|
||||||
}
|
}
|
||||||
@ -500,7 +500,7 @@ func TestCloseByClient(t *testing.T) {
|
|||||||
|
|
||||||
idAlice := "alice"
|
idAlice := "alice"
|
||||||
log.Debugf("connect by alice")
|
log.Debugf("connect by alice")
|
||||||
relayClient := NewClient(ctx, addr1, idAlice)
|
relayClient := NewClient(ctx, srvCfg.Address, idAlice)
|
||||||
err := relayClient.Connect()
|
err := relayClient.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to connect to server: %s", err)
|
log.Fatalf("failed to connect to server: %s", err)
|
||||||
|
@ -12,14 +12,16 @@ import (
|
|||||||
type Conn struct {
|
type Conn struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
*websocket.Conn
|
*websocket.Conn
|
||||||
srvAddr *net.TCPAddr
|
srvAddr net.Addr
|
||||||
|
localAddr net.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConn(wsConn *websocket.Conn, srvAddr *net.TCPAddr) net.Conn {
|
func NewConn(wsConn *websocket.Conn, srvAddr, localAddr net.Addr) net.Conn {
|
||||||
return &Conn{
|
return &Conn{
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
Conn: wsConn,
|
Conn: wsConn,
|
||||||
srvAddr: srvAddr,
|
srvAddr: srvAddr,
|
||||||
|
localAddr: localAddr,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -46,8 +48,7 @@ func (c *Conn) RemoteAddr() net.Addr {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) LocalAddr() net.Addr {
|
func (c *Conn) LocalAddr() net.Addr {
|
||||||
// todo: implement me
|
return c.localAddr
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"nhooyr.io/websocket"
|
"nhooyr.io/websocket"
|
||||||
@ -13,32 +14,48 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func Dial(address string) (net.Conn, error) {
|
func Dial(address string) (net.Conn, error) {
|
||||||
|
wsURL, err := prepareURL(address)
|
||||||
hostName, _, err := net.SplitHostPort(address)
|
|
||||||
|
|
||||||
addr, err := net.ResolveTCPAddr("tcp", address)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to resolve address of Relay server: %s", address)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
url := fmt.Sprintf("ws://%s:%d", addr.IP.String(), addr.Port)
|
|
||||||
opts := &websocket.DialOptions{
|
opts := &websocket.DialOptions{
|
||||||
Host: hostName,
|
|
||||||
HTTPClient: httpClientNbDialer(),
|
HTTPClient: httpClientNbDialer(),
|
||||||
}
|
}
|
||||||
|
|
||||||
wsConn, _, err := websocket.Dial(context.Background(), url, opts)
|
wsConn, _, err := websocket.Dial(context.Background(), wsURL, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to dial to Relay server '%s': %s", url, err)
|
log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
conn := NewConn(wsConn, addr)
|
/*
|
||||||
|
response.Body.(net.Conn).LocalAddr()
|
||||||
|
unc, ok := response.Body.(net.Conn)
|
||||||
|
if !ok {
|
||||||
|
log.Errorf("failed to get local address: %s", err)
|
||||||
|
return nil, fmt.Errorf("failed to get local address")
|
||||||
|
}
|
||||||
|
|
||||||
|
*/
|
||||||
|
// todo figure out the proper address
|
||||||
|
dummy := &net.TCPAddr{
|
||||||
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
|
Port: 8080,
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := NewConn(wsConn, dummy, dummy)
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func prepareURL(address string) (string, error) {
|
||||||
|
if !strings.HasPrefix(address, "rel") {
|
||||||
|
return "", fmt.Errorf("unsupported scheme: %s", address)
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Replace(address, "rel", "ws", 1), nil
|
||||||
|
}
|
||||||
|
|
||||||
func httpClientNbDialer() *http.Client {
|
func httpClientNbDialer() *http.Client {
|
||||||
customDialer := nbnet.NewDialer()
|
customDialer := nbnet.NewDialer()
|
||||||
|
|
||||||
|
@ -13,10 +13,12 @@ import (
|
|||||||
func TestForeignConn(t *testing.T) {
|
func TestForeignConn(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
addr1 := "localhost:1234"
|
srvCfg1 := server.Config{
|
||||||
|
Address: "localhost:1234",
|
||||||
|
}
|
||||||
srv1 := server.NewServer()
|
srv1 := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
err := srv1.Listen(addr1)
|
err := srv1.Listen(srvCfg1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind server: %s", err)
|
t.Fatalf("failed to bind server: %s", err)
|
||||||
}
|
}
|
||||||
@ -29,10 +31,12 @@ func TestForeignConn(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
addr2 := "localhost:2234"
|
srvCfg2 := server.Config{
|
||||||
|
Address: "localhost:2234",
|
||||||
|
}
|
||||||
srv2 := server.NewServer()
|
srv2 := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
err := srv2.Listen(addr2)
|
err := srv2.Listen(srvCfg2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind server: %s", err)
|
t.Fatalf("failed to bind server: %s", err)
|
||||||
}
|
}
|
||||||
@ -49,12 +53,12 @@ func TestForeignConn(t *testing.T) {
|
|||||||
log.Debugf("connect by alice")
|
log.Debugf("connect by alice")
|
||||||
mCtx, cancel := context.WithCancel(ctx)
|
mCtx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
clientAlice := NewManager(mCtx, addr1, idAlice)
|
clientAlice := NewManager(mCtx, srvCfg1.Address, idAlice)
|
||||||
clientAlice.Serve()
|
clientAlice.Serve()
|
||||||
|
|
||||||
idBob := "bob"
|
idBob := "bob"
|
||||||
log.Debugf("connect by bob")
|
log.Debugf("connect by bob")
|
||||||
clientBob := NewManager(mCtx, addr2, idBob)
|
clientBob := NewManager(mCtx, srvCfg2.Address, idBob)
|
||||||
clientBob.Serve()
|
clientBob.Serve()
|
||||||
|
|
||||||
bobsSrvAddr, err := clientBob.RelayAddress()
|
bobsSrvAddr, err := clientBob.RelayAddress()
|
||||||
@ -100,10 +104,12 @@ func TestForeignConn(t *testing.T) {
|
|||||||
func TestForeginConnClose(t *testing.T) {
|
func TestForeginConnClose(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
addr1 := "localhost:1234"
|
srvCfg1 := server.Config{
|
||||||
|
Address: "localhost:1234",
|
||||||
|
}
|
||||||
srv1 := server.NewServer()
|
srv1 := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
err := srv1.Listen(addr1)
|
err := srv1.Listen(srvCfg1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind server: %s", err)
|
t.Fatalf("failed to bind server: %s", err)
|
||||||
}
|
}
|
||||||
@ -116,10 +122,12 @@ func TestForeginConnClose(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
addr2 := "localhost:2234"
|
srvCfg2 := server.Config{
|
||||||
|
Address: "localhost:2234",
|
||||||
|
}
|
||||||
srv2 := server.NewServer()
|
srv2 := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
err := srv2.Listen(addr2)
|
err := srv2.Listen(srvCfg2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind server: %s", err)
|
t.Fatalf("failed to bind server: %s", err)
|
||||||
}
|
}
|
||||||
@ -136,10 +144,10 @@ func TestForeginConnClose(t *testing.T) {
|
|||||||
log.Debugf("connect by alice")
|
log.Debugf("connect by alice")
|
||||||
mCtx, cancel := context.WithCancel(ctx)
|
mCtx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
mgr := NewManager(mCtx, addr1, idAlice)
|
mgr := NewManager(mCtx, srvCfg1.Address, idAlice)
|
||||||
mgr.Serve()
|
mgr.Serve()
|
||||||
|
|
||||||
conn, err := mgr.OpenConn(addr2, "anotherpeer", nil)
|
conn, err := mgr.OpenConn(srvCfg2.Address, "anotherpeer", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
@ -153,11 +161,13 @@ func TestForeginConnClose(t *testing.T) {
|
|||||||
func TestForeginAutoClose(t *testing.T) {
|
func TestForeginAutoClose(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
relayCleanupInterval = 1 * time.Second
|
relayCleanupInterval = 1 * time.Second
|
||||||
addr1 := "localhost:1234"
|
srvCfg1 := server.Config{
|
||||||
|
Address: "localhost:1234",
|
||||||
|
}
|
||||||
srv1 := server.NewServer()
|
srv1 := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
t.Log("binding server 1.")
|
t.Log("binding server 1.")
|
||||||
err := srv1.Listen(addr1)
|
err := srv1.Listen(srvCfg1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind server: %s", err)
|
t.Fatalf("failed to bind server: %s", err)
|
||||||
}
|
}
|
||||||
@ -172,11 +182,13 @@ func TestForeginAutoClose(t *testing.T) {
|
|||||||
t.Logf("server 1. closed")
|
t.Logf("server 1. closed")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
addr2 := "localhost:2234"
|
srvCfg2 := server.Config{
|
||||||
|
Address: "localhost:2234",
|
||||||
|
}
|
||||||
srv2 := server.NewServer()
|
srv2 := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
t.Log("binding server 2.")
|
t.Log("binding server 2.")
|
||||||
err := srv2.Listen(addr2)
|
err := srv2.Listen(srvCfg2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind server: %s", err)
|
t.Fatalf("failed to bind server: %s", err)
|
||||||
}
|
}
|
||||||
@ -194,11 +206,11 @@ func TestForeginAutoClose(t *testing.T) {
|
|||||||
t.Log("connect to server 1.")
|
t.Log("connect to server 1.")
|
||||||
mCtx, cancel := context.WithCancel(ctx)
|
mCtx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
mgr := NewManager(mCtx, addr1, idAlice)
|
mgr := NewManager(mCtx, srvCfg1.Address, idAlice)
|
||||||
mgr.Serve()
|
mgr.Serve()
|
||||||
|
|
||||||
t.Log("open connection to another peer")
|
t.Log("open connection to another peer")
|
||||||
conn, err := mgr.OpenConn(addr2, "anotherpeer", nil)
|
conn, err := mgr.OpenConn(srvCfg2.Address, "anotherpeer", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
@ -222,10 +234,12 @@ func TestAutoReconnect(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
reconnectingTimeout = 2 * time.Second
|
reconnectingTimeout = 2 * time.Second
|
||||||
|
|
||||||
addr := "localhost:1234"
|
srvCfg := server.Config{
|
||||||
|
Address: "localhost:1234",
|
||||||
|
}
|
||||||
srv := server.NewServer()
|
srv := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
err := srv.Listen(addr)
|
err := srv.Listen(srvCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to bind server: %s", err)
|
t.Errorf("failed to bind server: %s", err)
|
||||||
}
|
}
|
||||||
@ -240,7 +254,7 @@ func TestAutoReconnect(t *testing.T) {
|
|||||||
|
|
||||||
mCtx, cancel := context.WithCancel(ctx)
|
mCtx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
clientAlice := NewManager(mCtx, addr, "alice")
|
clientAlice := NewManager(mCtx, srvCfg.Address, "alice")
|
||||||
clientAlice.Serve()
|
clientAlice.Serve()
|
||||||
ra, err := clientAlice.RelayAddress()
|
ra, err := clientAlice.RelayAddress()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"syscall"
|
"syscall"
|
||||||
@ -8,12 +10,15 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/encryption"
|
||||||
"github.com/netbirdio/netbird/relay/server"
|
"github.com/netbirdio/netbird/relay/server"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
listenAddress string
|
listenAddress string
|
||||||
|
letsencryptDataDir string
|
||||||
|
letsencryptDomain string
|
||||||
|
|
||||||
rootCmd = &cobra.Command{
|
rootCmd = &cobra.Command{
|
||||||
Use: "relay",
|
Use: "relay",
|
||||||
@ -26,7 +31,8 @@ var (
|
|||||||
func init() {
|
func init() {
|
||||||
_ = util.InitLog("trace", "console")
|
_ = util.InitLog("trace", "console")
|
||||||
rootCmd.PersistentFlags().StringVarP(&listenAddress, "listen-address", "l", ":1235", "listen address")
|
rootCmd.PersistentFlags().StringVarP(&listenAddress, "listen-address", "l", ":1235", "listen address")
|
||||||
|
rootCmd.PersistentFlags().StringVarP(&letsencryptDataDir, "letsencrypt-data-dir", "d", "", "a directory to store Let's Encrypt data. Required if Let's Encrypt is enabled.")
|
||||||
|
rootCmd.PersistentFlags().StringVarP(&letsencryptDomain, "letsencrypt-domain", "a", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
|
||||||
}
|
}
|
||||||
|
|
||||||
func waitForExitSignal() {
|
func waitForExitSignal() {
|
||||||
@ -36,8 +42,20 @@ func waitForExitSignal() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func execute(cmd *cobra.Command, args []string) {
|
func execute(cmd *cobra.Command, args []string) {
|
||||||
|
srvCfg := server.Config{
|
||||||
|
Address: listenAddress,
|
||||||
|
}
|
||||||
|
if hasLetsEncrypt() {
|
||||||
|
tlscfg, err := setupTLS()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("%s", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
srvCfg.TLSConfig = tlscfg
|
||||||
|
}
|
||||||
|
|
||||||
srv := server.NewServer()
|
srv := server.NewServer()
|
||||||
err := srv.Listen(listenAddress)
|
err := srv.Listen(srvCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to bind server: %s", err)
|
log.Errorf("failed to bind server: %s", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
@ -52,6 +70,18 @@ func execute(cmd *cobra.Command, args []string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hasLetsEncrypt() bool {
|
||||||
|
return letsencryptDataDir != "" && letsencryptDomain != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTLS() (*tls.Config, error) {
|
||||||
|
certManager, err := encryption.CreateCertManager(letsencryptDataDir, letsencryptDomain)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err)
|
||||||
|
}
|
||||||
|
return certManager.TLSConfig(), nil
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
err := rootCmd.Execute()
|
err := rootCmd.Execute()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -2,6 +2,7 @@ package ws
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
@ -10,35 +11,37 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"nhooyr.io/websocket"
|
"nhooyr.io/websocket"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/relay/server/listener"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Listener struct {
|
type Listener struct {
|
||||||
address string
|
// Address is the address to listen on.
|
||||||
|
Address string
|
||||||
|
// TLSConfig is the TLS configuration for the server.
|
||||||
|
TLSConfig *tls.Config
|
||||||
|
|
||||||
server *http.Server
|
server *http.Server
|
||||||
acceptFn func(conn net.Conn)
|
acceptFn func(conn net.Conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewListener(address string) listener.Listener {
|
|
||||||
return &Listener{
|
|
||||||
address: address,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
|
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
|
||||||
l.acceptFn = acceptFn
|
l.acceptFn = acceptFn
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.HandleFunc("/", l.onAccept)
|
mux.HandleFunc("/", l.onAccept)
|
||||||
|
|
||||||
l.server = &http.Server{
|
l.server = &http.Server{
|
||||||
Addr: l.address,
|
Addr: l.Address,
|
||||||
Handler: mux,
|
Handler: mux,
|
||||||
|
TLSConfig: l.TLSConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("WS server is listening on address: %s", l.address)
|
log.Infof("WS server is listening on address: %s", l.Address)
|
||||||
err := l.server.ListenAndServe()
|
var err error
|
||||||
|
if l.TLSConfig != nil {
|
||||||
|
err = l.server.ListenAndServeTLS("", "")
|
||||||
|
|
||||||
|
} else {
|
||||||
|
err = l.server.ListenAndServe()
|
||||||
|
}
|
||||||
if errors.Is(err, http.ErrServerClosed) {
|
if errors.Is(err, http.ErrServerClosed) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -13,6 +14,11 @@ import (
|
|||||||
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Address string
|
||||||
|
TLSConfig *tls.Config
|
||||||
|
}
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
relay *Relay
|
relay *Relay
|
||||||
uDPListener listener.Listener
|
uDPListener listener.Listener
|
||||||
@ -25,11 +31,15 @@ func NewServer() *Server {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Server) Listen(address string) error {
|
func (r *Server) Listen(cfg Config) error {
|
||||||
wg := sync.WaitGroup{}
|
wg := sync.WaitGroup{}
|
||||||
wg.Add(2)
|
wg.Add(2)
|
||||||
|
|
||||||
r.wSListener = ws.NewListener(address)
|
r.wSListener = &ws.Listener{
|
||||||
|
Address: cfg.Address,
|
||||||
|
TLSConfig: cfg.TLSConfig,
|
||||||
|
}
|
||||||
|
|
||||||
var wslErr error
|
var wslErr error
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
@ -39,7 +49,7 @@ func (r *Server) Listen(address string) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
r.uDPListener = udp.NewListener(address)
|
r.uDPListener = udp.NewListener(cfg.Address)
|
||||||
var udpLErr error
|
var udpLErr error
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user