mirror of
https://github.com/netbirdio/netbird.git
synced 2025-03-05 02:11:13 +01:00
Fix ssl configuration
This commit is contained in:
parent
ed82ef7fe4
commit
d3785dc1fa
@ -23,10 +23,10 @@ func TestMain(m *testing.M) {
|
||||
func TestClient(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
addr := "localhost:1234"
|
||||
srvCfg := server.Config{Address: "localhost:1234"}
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
err := srv.Listen(addr)
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
@ -39,21 +39,21 @@ func TestClient(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
clientAlice := NewClient(ctx, addr, "alice")
|
||||
clientAlice := NewClient(ctx, srvCfg.Address, "alice")
|
||||
err := clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer clientAlice.Close()
|
||||
|
||||
clientPlaceHolder := NewClient(ctx, addr, "clientPlaceHolder")
|
||||
clientPlaceHolder := NewClient(ctx, srvCfg.Address, "clientPlaceHolder")
|
||||
err = clientPlaceHolder.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer clientPlaceHolder.Close()
|
||||
|
||||
clientBob := NewClient(ctx, addr, "bob")
|
||||
clientBob := NewClient(ctx, srvCfg.Address, "bob")
|
||||
err = clientBob.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
@ -91,16 +91,16 @@ func TestClient(t *testing.T) {
|
||||
|
||||
func TestRegistration(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
addr := "localhost:1234"
|
||||
srvCfg := server.Config{Address: "localhost:1234"}
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
err := srv.Listen(addr)
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
clientAlice := NewClient(ctx, addr, "alice")
|
||||
clientAlice := NewClient(ctx, srvCfg.Address, "alice")
|
||||
err := clientAlice.Connect()
|
||||
if err != nil {
|
||||
_ = srv.Close()
|
||||
@ -156,10 +156,10 @@ func TestEcho(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
idAlice := "alice"
|
||||
idBob := "bob"
|
||||
addr := "localhost:1234"
|
||||
srvCfg := server.Config{Address: "localhost:1234"}
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
err := srv.Listen(addr)
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
@ -172,7 +172,7 @@ func TestEcho(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
clientAlice := NewClient(ctx, addr, idAlice)
|
||||
clientAlice := NewClient(ctx, srvCfg.Address, idAlice)
|
||||
err := clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
@ -184,7 +184,7 @@ func TestEcho(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
clientBob := NewClient(ctx, addr, idBob)
|
||||
clientBob := NewClient(ctx, srvCfg.Address, idBob)
|
||||
err = clientBob.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
@ -236,10 +236,10 @@ func TestEcho(t *testing.T) {
|
||||
func TestBindToUnavailabePeer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
addr := "localhost:1234"
|
||||
srvCfg := server.Config{Address: "localhost:1234"}
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
err := srv.Listen(addr)
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
@ -253,7 +253,7 @@ func TestBindToUnavailabePeer(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
clientAlice := NewClient(ctx, addr, "alice")
|
||||
clientAlice := NewClient(ctx, srvCfg.Address, "alice")
|
||||
err := clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
@ -273,10 +273,10 @@ func TestBindToUnavailabePeer(t *testing.T) {
|
||||
func TestBindReconnect(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
addr := "localhost:1234"
|
||||
srvCfg := server.Config{Address: "localhost:1234"}
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
err := srv.Listen(addr)
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind server: %s", err)
|
||||
}
|
||||
@ -290,7 +290,7 @@ func TestBindReconnect(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
clientAlice := NewClient(ctx, addr, "alice")
|
||||
clientAlice := NewClient(ctx, srvCfg.Address, "alice")
|
||||
err := clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
@ -301,7 +301,7 @@ func TestBindReconnect(t *testing.T) {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
clientBob := NewClient(ctx, addr, "bob")
|
||||
clientBob := NewClient(ctx, srvCfg.Address, "bob")
|
||||
err = clientBob.Connect()
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
@ -318,7 +318,7 @@ func TestBindReconnect(t *testing.T) {
|
||||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
|
||||
clientAlice = NewClient(ctx, addr, "alice")
|
||||
clientAlice = NewClient(ctx, srvCfg.Address, "alice")
|
||||
err = clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
@ -355,10 +355,10 @@ func TestBindReconnect(t *testing.T) {
|
||||
func TestCloseConn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
addr := "localhost:1234"
|
||||
srvCfg := server.Config{Address: "localhost:1234"}
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
err := srv.Listen(addr)
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind server: %s", err)
|
||||
}
|
||||
@ -372,7 +372,7 @@ func TestCloseConn(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
clientAlice := NewClient(ctx, addr, "alice")
|
||||
clientAlice := NewClient(ctx, srvCfg.Address, "alice")
|
||||
err := clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
@ -403,10 +403,10 @@ func TestCloseConn(t *testing.T) {
|
||||
func TestCloseRelayConn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
addr := "localhost:1234"
|
||||
srvCfg := server.Config{Address: "localhost:1234"}
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
err := srv.Listen(addr)
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind server: %s", err)
|
||||
}
|
||||
@ -419,7 +419,7 @@ func TestCloseRelayConn(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
clientAlice := NewClient(ctx, addr, "alice")
|
||||
clientAlice := NewClient(ctx, srvCfg.Address, "alice")
|
||||
err := clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
@ -446,10 +446,10 @@ func TestCloseRelayConn(t *testing.T) {
|
||||
func TestCloseByServer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
addr1 := "localhost:1234"
|
||||
srvCfg := server.Config{Address: "localhost:1234"}
|
||||
srv1 := server.NewServer()
|
||||
go func() {
|
||||
err := srv1.Listen(addr1)
|
||||
err := srv1.Listen(srvCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
@ -457,7 +457,7 @@ func TestCloseByServer(t *testing.T) {
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
relayClient := NewClient(ctx, addr1, idAlice)
|
||||
relayClient := NewClient(ctx, srvCfg.Address, idAlice)
|
||||
err := relayClient.Connect()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to connect to server: %s", err)
|
||||
@ -489,10 +489,10 @@ func TestCloseByServer(t *testing.T) {
|
||||
func TestCloseByClient(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
addr1 := "localhost:1234"
|
||||
srvCfg := server.Config{Address: "localhost:1234"}
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
err := srv.Listen(addr1)
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
@ -500,7 +500,7 @@ func TestCloseByClient(t *testing.T) {
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
relayClient := NewClient(ctx, addr1, idAlice)
|
||||
relayClient := NewClient(ctx, srvCfg.Address, idAlice)
|
||||
err := relayClient.Connect()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to connect to server: %s", err)
|
||||
|
@ -12,14 +12,16 @@ import (
|
||||
type Conn struct {
|
||||
ctx context.Context
|
||||
*websocket.Conn
|
||||
srvAddr *net.TCPAddr
|
||||
srvAddr net.Addr
|
||||
localAddr net.Addr
|
||||
}
|
||||
|
||||
func NewConn(wsConn *websocket.Conn, srvAddr *net.TCPAddr) net.Conn {
|
||||
func NewConn(wsConn *websocket.Conn, srvAddr, localAddr net.Addr) net.Conn {
|
||||
return &Conn{
|
||||
ctx: context.Background(),
|
||||
Conn: wsConn,
|
||||
srvAddr: srvAddr,
|
||||
ctx: context.Background(),
|
||||
Conn: wsConn,
|
||||
srvAddr: srvAddr,
|
||||
localAddr: localAddr,
|
||||
}
|
||||
}
|
||||
|
||||
@ -46,8 +48,7 @@ func (c *Conn) RemoteAddr() net.Addr {
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
// todo: implement me
|
||||
return nil
|
||||
return c.localAddr
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"nhooyr.io/websocket"
|
||||
@ -13,32 +14,48 @@ import (
|
||||
)
|
||||
|
||||
func Dial(address string) (net.Conn, error) {
|
||||
|
||||
hostName, _, err := net.SplitHostPort(address)
|
||||
|
||||
addr, err := net.ResolveTCPAddr("tcp", address)
|
||||
wsURL, err := prepareURL(address)
|
||||
if err != nil {
|
||||
log.Errorf("failed to resolve address of Relay server: %s", address)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("ws://%s:%d", addr.IP.String(), addr.Port)
|
||||
opts := &websocket.DialOptions{
|
||||
Host: hostName,
|
||||
HTTPClient: httpClientNbDialer(),
|
||||
}
|
||||
|
||||
wsConn, _, err := websocket.Dial(context.Background(), url, opts)
|
||||
wsConn, _, err := websocket.Dial(context.Background(), wsURL, opts)
|
||||
if err != nil {
|
||||
log.Errorf("failed to dial to Relay server '%s': %s", url, err)
|
||||
log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn := NewConn(wsConn, addr)
|
||||
/*
|
||||
response.Body.(net.Conn).LocalAddr()
|
||||
unc, ok := response.Body.(net.Conn)
|
||||
if !ok {
|
||||
log.Errorf("failed to get local address: %s", err)
|
||||
return nil, fmt.Errorf("failed to get local address")
|
||||
}
|
||||
|
||||
*/
|
||||
// todo figure out the proper address
|
||||
dummy := &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 8080,
|
||||
}
|
||||
|
||||
conn := NewConn(wsConn, dummy, dummy)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func prepareURL(address string) (string, error) {
|
||||
if !strings.HasPrefix(address, "rel") {
|
||||
return "", fmt.Errorf("unsupported scheme: %s", address)
|
||||
}
|
||||
|
||||
return strings.Replace(address, "rel", "ws", 1), nil
|
||||
}
|
||||
|
||||
func httpClientNbDialer() *http.Client {
|
||||
customDialer := nbnet.NewDialer()
|
||||
|
||||
|
@ -13,10 +13,12 @@ import (
|
||||
func TestForeignConn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
addr1 := "localhost:1234"
|
||||
srvCfg1 := server.Config{
|
||||
Address: "localhost:1234",
|
||||
}
|
||||
srv1 := server.NewServer()
|
||||
go func() {
|
||||
err := srv1.Listen(addr1)
|
||||
err := srv1.Listen(srvCfg1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
@ -29,10 +31,12 @@ func TestForeignConn(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
addr2 := "localhost:2234"
|
||||
srvCfg2 := server.Config{
|
||||
Address: "localhost:2234",
|
||||
}
|
||||
srv2 := server.NewServer()
|
||||
go func() {
|
||||
err := srv2.Listen(addr2)
|
||||
err := srv2.Listen(srvCfg2)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
@ -49,12 +53,12 @@ func TestForeignConn(t *testing.T) {
|
||||
log.Debugf("connect by alice")
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
clientAlice := NewManager(mCtx, addr1, idAlice)
|
||||
clientAlice := NewManager(mCtx, srvCfg1.Address, idAlice)
|
||||
clientAlice.Serve()
|
||||
|
||||
idBob := "bob"
|
||||
log.Debugf("connect by bob")
|
||||
clientBob := NewManager(mCtx, addr2, idBob)
|
||||
clientBob := NewManager(mCtx, srvCfg2.Address, idBob)
|
||||
clientBob.Serve()
|
||||
|
||||
bobsSrvAddr, err := clientBob.RelayAddress()
|
||||
@ -100,10 +104,12 @@ func TestForeignConn(t *testing.T) {
|
||||
func TestForeginConnClose(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
addr1 := "localhost:1234"
|
||||
srvCfg1 := server.Config{
|
||||
Address: "localhost:1234",
|
||||
}
|
||||
srv1 := server.NewServer()
|
||||
go func() {
|
||||
err := srv1.Listen(addr1)
|
||||
err := srv1.Listen(srvCfg1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
@ -116,10 +122,12 @@ func TestForeginConnClose(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
addr2 := "localhost:2234"
|
||||
srvCfg2 := server.Config{
|
||||
Address: "localhost:2234",
|
||||
}
|
||||
srv2 := server.NewServer()
|
||||
go func() {
|
||||
err := srv2.Listen(addr2)
|
||||
err := srv2.Listen(srvCfg2)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
@ -136,10 +144,10 @@ func TestForeginConnClose(t *testing.T) {
|
||||
log.Debugf("connect by alice")
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
mgr := NewManager(mCtx, addr1, idAlice)
|
||||
mgr := NewManager(mCtx, srvCfg1.Address, idAlice)
|
||||
mgr.Serve()
|
||||
|
||||
conn, err := mgr.OpenConn(addr2, "anotherpeer", nil)
|
||||
conn, err := mgr.OpenConn(srvCfg2.Address, "anotherpeer", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
@ -153,11 +161,13 @@ func TestForeginConnClose(t *testing.T) {
|
||||
func TestForeginAutoClose(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
relayCleanupInterval = 1 * time.Second
|
||||
addr1 := "localhost:1234"
|
||||
srvCfg1 := server.Config{
|
||||
Address: "localhost:1234",
|
||||
}
|
||||
srv1 := server.NewServer()
|
||||
go func() {
|
||||
t.Log("binding server 1.")
|
||||
err := srv1.Listen(addr1)
|
||||
err := srv1.Listen(srvCfg1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
@ -172,11 +182,13 @@ func TestForeginAutoClose(t *testing.T) {
|
||||
t.Logf("server 1. closed")
|
||||
}()
|
||||
|
||||
addr2 := "localhost:2234"
|
||||
srvCfg2 := server.Config{
|
||||
Address: "localhost:2234",
|
||||
}
|
||||
srv2 := server.NewServer()
|
||||
go func() {
|
||||
t.Log("binding server 2.")
|
||||
err := srv2.Listen(addr2)
|
||||
err := srv2.Listen(srvCfg2)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
@ -194,11 +206,11 @@ func TestForeginAutoClose(t *testing.T) {
|
||||
t.Log("connect to server 1.")
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
mgr := NewManager(mCtx, addr1, idAlice)
|
||||
mgr := NewManager(mCtx, srvCfg1.Address, idAlice)
|
||||
mgr.Serve()
|
||||
|
||||
t.Log("open connection to another peer")
|
||||
conn, err := mgr.OpenConn(addr2, "anotherpeer", nil)
|
||||
conn, err := mgr.OpenConn(srvCfg2.Address, "anotherpeer", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
@ -222,10 +234,12 @@ func TestAutoReconnect(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
reconnectingTimeout = 2 * time.Second
|
||||
|
||||
addr := "localhost:1234"
|
||||
srvCfg := server.Config{
|
||||
Address: "localhost:1234",
|
||||
}
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
err := srv.Listen(addr)
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind server: %s", err)
|
||||
}
|
||||
@ -240,7 +254,7 @@ func TestAutoReconnect(t *testing.T) {
|
||||
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
clientAlice := NewManager(mCtx, addr, "alice")
|
||||
clientAlice := NewManager(mCtx, srvCfg.Address, "alice")
|
||||
clientAlice.Serve()
|
||||
ra, err := clientAlice.RelayAddress()
|
||||
if err != nil {
|
||||
|
@ -1,6 +1,8 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
@ -8,12 +10,15 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/relay/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var (
|
||||
listenAddress string
|
||||
listenAddress string
|
||||
letsencryptDataDir string
|
||||
letsencryptDomain string
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "relay",
|
||||
@ -26,7 +31,8 @@ var (
|
||||
func init() {
|
||||
_ = util.InitLog("trace", "console")
|
||||
rootCmd.PersistentFlags().StringVarP(&listenAddress, "listen-address", "l", ":1235", "listen address")
|
||||
|
||||
rootCmd.PersistentFlags().StringVarP(&letsencryptDataDir, "letsencrypt-data-dir", "d", "", "a directory to store Let's Encrypt data. Required if Let's Encrypt is enabled.")
|
||||
rootCmd.PersistentFlags().StringVarP(&letsencryptDomain, "letsencrypt-domain", "a", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
|
||||
}
|
||||
|
||||
func waitForExitSignal() {
|
||||
@ -36,8 +42,20 @@ func waitForExitSignal() {
|
||||
}
|
||||
|
||||
func execute(cmd *cobra.Command, args []string) {
|
||||
srvCfg := server.Config{
|
||||
Address: listenAddress,
|
||||
}
|
||||
if hasLetsEncrypt() {
|
||||
tlscfg, err := setupTLS()
|
||||
if err != nil {
|
||||
log.Errorf("%s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
srvCfg.TLSConfig = tlscfg
|
||||
}
|
||||
|
||||
srv := server.NewServer()
|
||||
err := srv.Listen(listenAddress)
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
log.Errorf("failed to bind server: %s", err)
|
||||
os.Exit(1)
|
||||
@ -52,6 +70,18 @@ func execute(cmd *cobra.Command, args []string) {
|
||||
}
|
||||
}
|
||||
|
||||
func hasLetsEncrypt() bool {
|
||||
return letsencryptDataDir != "" && letsencryptDomain != ""
|
||||
}
|
||||
|
||||
func setupTLS() (*tls.Config, error) {
|
||||
certManager, err := encryption.CreateCertManager(letsencryptDataDir, letsencryptDomain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err)
|
||||
}
|
||||
return certManager.TLSConfig(), nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
err := rootCmd.Execute()
|
||||
if err != nil {
|
||||
|
@ -2,6 +2,7 @@ package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@ -10,35 +11,37 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"nhooyr.io/websocket"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/server/listener"
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
address string
|
||||
// Address is the address to listen on.
|
||||
Address string
|
||||
// TLSConfig is the TLS configuration for the server.
|
||||
TLSConfig *tls.Config
|
||||
|
||||
server *http.Server
|
||||
acceptFn func(conn net.Conn)
|
||||
}
|
||||
|
||||
func NewListener(address string) listener.Listener {
|
||||
return &Listener{
|
||||
address: address,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
|
||||
l.acceptFn = acceptFn
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", l.onAccept)
|
||||
|
||||
l.server = &http.Server{
|
||||
Addr: l.address,
|
||||
Handler: mux,
|
||||
Addr: l.Address,
|
||||
Handler: mux,
|
||||
TLSConfig: l.TLSConfig,
|
||||
}
|
||||
|
||||
log.Infof("WS server is listening on address: %s", l.address)
|
||||
err := l.server.ListenAndServe()
|
||||
log.Infof("WS server is listening on address: %s", l.Address)
|
||||
var err error
|
||||
if l.TLSConfig != nil {
|
||||
err = l.server.ListenAndServeTLS("", "")
|
||||
|
||||
} else {
|
||||
err = l.server.ListenAndServe()
|
||||
}
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
return nil
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
@ -13,6 +14,11 @@ import (
|
||||
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Address string
|
||||
TLSConfig *tls.Config
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
relay *Relay
|
||||
uDPListener listener.Listener
|
||||
@ -25,11 +31,15 @@ func NewServer() *Server {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Server) Listen(address string) error {
|
||||
func (r *Server) Listen(cfg Config) error {
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
r.wSListener = ws.NewListener(address)
|
||||
r.wSListener = &ws.Listener{
|
||||
Address: cfg.Address,
|
||||
TLSConfig: cfg.TLSConfig,
|
||||
}
|
||||
|
||||
var wslErr error
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
@ -39,7 +49,7 @@ func (r *Server) Listen(address string) error {
|
||||
}
|
||||
}()
|
||||
|
||||
r.uDPListener = udp.NewListener(address)
|
||||
r.uDPListener = udp.NewListener(cfg.Address)
|
||||
var udpLErr error
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
Loading…
Reference in New Issue
Block a user