Fix ssl configuration

This commit is contained in:
Zoltán Papp 2024-07-01 11:50:18 +02:00
parent ed82ef7fe4
commit d3785dc1fa
7 changed files with 164 additions and 89 deletions

View File

@ -23,10 +23,10 @@ func TestMain(m *testing.M) {
func TestClient(t *testing.T) {
ctx := context.Background()
addr := "localhost:1234"
srvCfg := server.Config{Address: "localhost:1234"}
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
err := srv.Listen(srvCfg)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
@ -39,21 +39,21 @@ func TestClient(t *testing.T) {
}
}()
clientAlice := NewClient(ctx, addr, "alice")
clientAlice := NewClient(ctx, srvCfg.Address, "alice")
err := clientAlice.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer clientAlice.Close()
clientPlaceHolder := NewClient(ctx, addr, "clientPlaceHolder")
clientPlaceHolder := NewClient(ctx, srvCfg.Address, "clientPlaceHolder")
err = clientPlaceHolder.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer clientPlaceHolder.Close()
clientBob := NewClient(ctx, addr, "bob")
clientBob := NewClient(ctx, srvCfg.Address, "bob")
err = clientBob.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
@ -91,16 +91,16 @@ func TestClient(t *testing.T) {
func TestRegistration(t *testing.T) {
ctx := context.Background()
addr := "localhost:1234"
srvCfg := server.Config{Address: "localhost:1234"}
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
err := srv.Listen(srvCfg)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
clientAlice := NewClient(ctx, addr, "alice")
clientAlice := NewClient(ctx, srvCfg.Address, "alice")
err := clientAlice.Connect()
if err != nil {
_ = srv.Close()
@ -156,10 +156,10 @@ func TestEcho(t *testing.T) {
ctx := context.Background()
idAlice := "alice"
idBob := "bob"
addr := "localhost:1234"
srvCfg := server.Config{Address: "localhost:1234"}
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
err := srv.Listen(srvCfg)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
@ -172,7 +172,7 @@ func TestEcho(t *testing.T) {
}
}()
clientAlice := NewClient(ctx, addr, idAlice)
clientAlice := NewClient(ctx, srvCfg.Address, idAlice)
err := clientAlice.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
@ -184,7 +184,7 @@ func TestEcho(t *testing.T) {
}
}()
clientBob := NewClient(ctx, addr, idBob)
clientBob := NewClient(ctx, srvCfg.Address, idBob)
err = clientBob.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
@ -236,10 +236,10 @@ func TestEcho(t *testing.T) {
func TestBindToUnavailabePeer(t *testing.T) {
ctx := context.Background()
addr := "localhost:1234"
srvCfg := server.Config{Address: "localhost:1234"}
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
err := srv.Listen(srvCfg)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
@ -253,7 +253,7 @@ func TestBindToUnavailabePeer(t *testing.T) {
}
}()
clientAlice := NewClient(ctx, addr, "alice")
clientAlice := NewClient(ctx, srvCfg.Address, "alice")
err := clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
@ -273,10 +273,10 @@ func TestBindToUnavailabePeer(t *testing.T) {
func TestBindReconnect(t *testing.T) {
ctx := context.Background()
addr := "localhost:1234"
srvCfg := server.Config{Address: "localhost:1234"}
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
err := srv.Listen(srvCfg)
if err != nil {
t.Errorf("failed to bind server: %s", err)
}
@ -290,7 +290,7 @@ func TestBindReconnect(t *testing.T) {
}
}()
clientAlice := NewClient(ctx, addr, "alice")
clientAlice := NewClient(ctx, srvCfg.Address, "alice")
err := clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
@ -301,7 +301,7 @@ func TestBindReconnect(t *testing.T) {
t.Errorf("failed to bind channel: %s", err)
}
clientBob := NewClient(ctx, addr, "bob")
clientBob := NewClient(ctx, srvCfg.Address, "bob")
err = clientBob.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
@ -318,7 +318,7 @@ func TestBindReconnect(t *testing.T) {
t.Errorf("failed to close client: %s", err)
}
clientAlice = NewClient(ctx, addr, "alice")
clientAlice = NewClient(ctx, srvCfg.Address, "alice")
err = clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
@ -355,10 +355,10 @@ func TestBindReconnect(t *testing.T) {
func TestCloseConn(t *testing.T) {
ctx := context.Background()
addr := "localhost:1234"
srvCfg := server.Config{Address: "localhost:1234"}
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
err := srv.Listen(srvCfg)
if err != nil {
t.Errorf("failed to bind server: %s", err)
}
@ -372,7 +372,7 @@ func TestCloseConn(t *testing.T) {
}
}()
clientAlice := NewClient(ctx, addr, "alice")
clientAlice := NewClient(ctx, srvCfg.Address, "alice")
err := clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
@ -403,10 +403,10 @@ func TestCloseConn(t *testing.T) {
func TestCloseRelayConn(t *testing.T) {
ctx := context.Background()
addr := "localhost:1234"
srvCfg := server.Config{Address: "localhost:1234"}
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
err := srv.Listen(srvCfg)
if err != nil {
t.Errorf("failed to bind server: %s", err)
}
@ -419,7 +419,7 @@ func TestCloseRelayConn(t *testing.T) {
}
}()
clientAlice := NewClient(ctx, addr, "alice")
clientAlice := NewClient(ctx, srvCfg.Address, "alice")
err := clientAlice.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
@ -446,10 +446,10 @@ func TestCloseRelayConn(t *testing.T) {
func TestCloseByServer(t *testing.T) {
ctx := context.Background()
addr1 := "localhost:1234"
srvCfg := server.Config{Address: "localhost:1234"}
srv1 := server.NewServer()
go func() {
err := srv1.Listen(addr1)
err := srv1.Listen(srvCfg)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
@ -457,7 +457,7 @@ func TestCloseByServer(t *testing.T) {
idAlice := "alice"
log.Debugf("connect by alice")
relayClient := NewClient(ctx, addr1, idAlice)
relayClient := NewClient(ctx, srvCfg.Address, idAlice)
err := relayClient.Connect()
if err != nil {
log.Fatalf("failed to connect to server: %s", err)
@ -489,10 +489,10 @@ func TestCloseByServer(t *testing.T) {
func TestCloseByClient(t *testing.T) {
ctx := context.Background()
addr1 := "localhost:1234"
srvCfg := server.Config{Address: "localhost:1234"}
srv := server.NewServer()
go func() {
err := srv.Listen(addr1)
err := srv.Listen(srvCfg)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
@ -500,7 +500,7 @@ func TestCloseByClient(t *testing.T) {
idAlice := "alice"
log.Debugf("connect by alice")
relayClient := NewClient(ctx, addr1, idAlice)
relayClient := NewClient(ctx, srvCfg.Address, idAlice)
err := relayClient.Connect()
if err != nil {
log.Fatalf("failed to connect to server: %s", err)

View File

@ -12,14 +12,16 @@ import (
type Conn struct {
ctx context.Context
*websocket.Conn
srvAddr *net.TCPAddr
srvAddr net.Addr
localAddr net.Addr
}
func NewConn(wsConn *websocket.Conn, srvAddr *net.TCPAddr) net.Conn {
func NewConn(wsConn *websocket.Conn, srvAddr, localAddr net.Addr) net.Conn {
return &Conn{
ctx: context.Background(),
Conn: wsConn,
srvAddr: srvAddr,
ctx: context.Background(),
Conn: wsConn,
srvAddr: srvAddr,
localAddr: localAddr,
}
}
@ -46,8 +48,7 @@ func (c *Conn) RemoteAddr() net.Addr {
}
func (c *Conn) LocalAddr() net.Addr {
// todo: implement me
return nil
return c.localAddr
}
func (c *Conn) SetReadDeadline(t time.Time) error {

View File

@ -5,6 +5,7 @@ import (
"fmt"
"net"
"net/http"
"strings"
log "github.com/sirupsen/logrus"
"nhooyr.io/websocket"
@ -13,32 +14,48 @@ import (
)
func Dial(address string) (net.Conn, error) {
hostName, _, err := net.SplitHostPort(address)
addr, err := net.ResolveTCPAddr("tcp", address)
wsURL, err := prepareURL(address)
if err != nil {
log.Errorf("failed to resolve address of Relay server: %s", address)
return nil, err
}
url := fmt.Sprintf("ws://%s:%d", addr.IP.String(), addr.Port)
opts := &websocket.DialOptions{
Host: hostName,
HTTPClient: httpClientNbDialer(),
}
wsConn, _, err := websocket.Dial(context.Background(), url, opts)
wsConn, _, err := websocket.Dial(context.Background(), wsURL, opts)
if err != nil {
log.Errorf("failed to dial to Relay server '%s': %s", url, err)
log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err)
return nil, err
}
conn := NewConn(wsConn, addr)
/*
response.Body.(net.Conn).LocalAddr()
unc, ok := response.Body.(net.Conn)
if !ok {
log.Errorf("failed to get local address: %s", err)
return nil, fmt.Errorf("failed to get local address")
}
*/
// todo figure out the proper address
dummy := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 8080,
}
conn := NewConn(wsConn, dummy, dummy)
return conn, nil
}
func prepareURL(address string) (string, error) {
if !strings.HasPrefix(address, "rel") {
return "", fmt.Errorf("unsupported scheme: %s", address)
}
return strings.Replace(address, "rel", "ws", 1), nil
}
func httpClientNbDialer() *http.Client {
customDialer := nbnet.NewDialer()

View File

@ -13,10 +13,12 @@ import (
func TestForeignConn(t *testing.T) {
ctx := context.Background()
addr1 := "localhost:1234"
srvCfg1 := server.Config{
Address: "localhost:1234",
}
srv1 := server.NewServer()
go func() {
err := srv1.Listen(addr1)
err := srv1.Listen(srvCfg1)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
@ -29,10 +31,12 @@ func TestForeignConn(t *testing.T) {
}
}()
addr2 := "localhost:2234"
srvCfg2 := server.Config{
Address: "localhost:2234",
}
srv2 := server.NewServer()
go func() {
err := srv2.Listen(addr2)
err := srv2.Listen(srvCfg2)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
@ -49,12 +53,12 @@ func TestForeignConn(t *testing.T) {
log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
clientAlice := NewManager(mCtx, addr1, idAlice)
clientAlice := NewManager(mCtx, srvCfg1.Address, idAlice)
clientAlice.Serve()
idBob := "bob"
log.Debugf("connect by bob")
clientBob := NewManager(mCtx, addr2, idBob)
clientBob := NewManager(mCtx, srvCfg2.Address, idBob)
clientBob.Serve()
bobsSrvAddr, err := clientBob.RelayAddress()
@ -100,10 +104,12 @@ func TestForeignConn(t *testing.T) {
func TestForeginConnClose(t *testing.T) {
ctx := context.Background()
addr1 := "localhost:1234"
srvCfg1 := server.Config{
Address: "localhost:1234",
}
srv1 := server.NewServer()
go func() {
err := srv1.Listen(addr1)
err := srv1.Listen(srvCfg1)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
@ -116,10 +122,12 @@ func TestForeginConnClose(t *testing.T) {
}
}()
addr2 := "localhost:2234"
srvCfg2 := server.Config{
Address: "localhost:2234",
}
srv2 := server.NewServer()
go func() {
err := srv2.Listen(addr2)
err := srv2.Listen(srvCfg2)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
@ -136,10 +144,10 @@ func TestForeginConnClose(t *testing.T) {
log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
mgr := NewManager(mCtx, addr1, idAlice)
mgr := NewManager(mCtx, srvCfg1.Address, idAlice)
mgr.Serve()
conn, err := mgr.OpenConn(addr2, "anotherpeer", nil)
conn, err := mgr.OpenConn(srvCfg2.Address, "anotherpeer", nil)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
@ -153,11 +161,13 @@ func TestForeginConnClose(t *testing.T) {
func TestForeginAutoClose(t *testing.T) {
ctx := context.Background()
relayCleanupInterval = 1 * time.Second
addr1 := "localhost:1234"
srvCfg1 := server.Config{
Address: "localhost:1234",
}
srv1 := server.NewServer()
go func() {
t.Log("binding server 1.")
err := srv1.Listen(addr1)
err := srv1.Listen(srvCfg1)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
@ -172,11 +182,13 @@ func TestForeginAutoClose(t *testing.T) {
t.Logf("server 1. closed")
}()
addr2 := "localhost:2234"
srvCfg2 := server.Config{
Address: "localhost:2234",
}
srv2 := server.NewServer()
go func() {
t.Log("binding server 2.")
err := srv2.Listen(addr2)
err := srv2.Listen(srvCfg2)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
@ -194,11 +206,11 @@ func TestForeginAutoClose(t *testing.T) {
t.Log("connect to server 1.")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
mgr := NewManager(mCtx, addr1, idAlice)
mgr := NewManager(mCtx, srvCfg1.Address, idAlice)
mgr.Serve()
t.Log("open connection to another peer")
conn, err := mgr.OpenConn(addr2, "anotherpeer", nil)
conn, err := mgr.OpenConn(srvCfg2.Address, "anotherpeer", nil)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
@ -222,10 +234,12 @@ func TestAutoReconnect(t *testing.T) {
ctx := context.Background()
reconnectingTimeout = 2 * time.Second
addr := "localhost:1234"
srvCfg := server.Config{
Address: "localhost:1234",
}
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
err := srv.Listen(srvCfg)
if err != nil {
t.Errorf("failed to bind server: %s", err)
}
@ -240,7 +254,7 @@ func TestAutoReconnect(t *testing.T) {
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
clientAlice := NewManager(mCtx, addr, "alice")
clientAlice := NewManager(mCtx, srvCfg.Address, "alice")
clientAlice.Serve()
ra, err := clientAlice.RelayAddress()
if err != nil {

View File

@ -1,6 +1,8 @@
package main
import (
"crypto/tls"
"fmt"
"os"
"os/signal"
"syscall"
@ -8,12 +10,15 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/util"
)
var (
listenAddress string
listenAddress string
letsencryptDataDir string
letsencryptDomain string
rootCmd = &cobra.Command{
Use: "relay",
@ -26,7 +31,8 @@ var (
func init() {
_ = util.InitLog("trace", "console")
rootCmd.PersistentFlags().StringVarP(&listenAddress, "listen-address", "l", ":1235", "listen address")
rootCmd.PersistentFlags().StringVarP(&letsencryptDataDir, "letsencrypt-data-dir", "d", "", "a directory to store Let's Encrypt data. Required if Let's Encrypt is enabled.")
rootCmd.PersistentFlags().StringVarP(&letsencryptDomain, "letsencrypt-domain", "a", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
}
func waitForExitSignal() {
@ -36,8 +42,20 @@ func waitForExitSignal() {
}
func execute(cmd *cobra.Command, args []string) {
srvCfg := server.Config{
Address: listenAddress,
}
if hasLetsEncrypt() {
tlscfg, err := setupTLS()
if err != nil {
log.Errorf("%s", err)
os.Exit(1)
}
srvCfg.TLSConfig = tlscfg
}
srv := server.NewServer()
err := srv.Listen(listenAddress)
err := srv.Listen(srvCfg)
if err != nil {
log.Errorf("failed to bind server: %s", err)
os.Exit(1)
@ -52,6 +70,18 @@ func execute(cmd *cobra.Command, args []string) {
}
}
func hasLetsEncrypt() bool {
return letsencryptDataDir != "" && letsencryptDomain != ""
}
func setupTLS() (*tls.Config, error) {
certManager, err := encryption.CreateCertManager(letsencryptDataDir, letsencryptDomain)
if err != nil {
return nil, fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err)
}
return certManager.TLSConfig(), nil
}
func main() {
err := rootCmd.Execute()
if err != nil {

View File

@ -2,6 +2,7 @@ package ws
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
@ -10,35 +11,37 @@ import (
log "github.com/sirupsen/logrus"
"nhooyr.io/websocket"
"github.com/netbirdio/netbird/relay/server/listener"
)
type Listener struct {
address string
// Address is the address to listen on.
Address string
// TLSConfig is the TLS configuration for the server.
TLSConfig *tls.Config
server *http.Server
acceptFn func(conn net.Conn)
}
func NewListener(address string) listener.Listener {
return &Listener{
address: address,
}
}
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
l.acceptFn = acceptFn
mux := http.NewServeMux()
mux.HandleFunc("/", l.onAccept)
l.server = &http.Server{
Addr: l.address,
Handler: mux,
Addr: l.Address,
Handler: mux,
TLSConfig: l.TLSConfig,
}
log.Infof("WS server is listening on address: %s", l.address)
err := l.server.ListenAndServe()
log.Infof("WS server is listening on address: %s", l.Address)
var err error
if l.TLSConfig != nil {
err = l.server.ListenAndServeTLS("", "")
} else {
err = l.server.ListenAndServe()
}
if errors.Is(err, http.ErrServerClosed) {
return nil
}

View File

@ -2,6 +2,7 @@ package server
import (
"context"
"crypto/tls"
"errors"
"sync"
"time"
@ -13,6 +14,11 @@ import (
"github.com/netbirdio/netbird/relay/server/listener/ws"
)
type Config struct {
Address string
TLSConfig *tls.Config
}
type Server struct {
relay *Relay
uDPListener listener.Listener
@ -25,11 +31,15 @@ func NewServer() *Server {
}
}
func (r *Server) Listen(address string) error {
func (r *Server) Listen(cfg Config) error {
wg := sync.WaitGroup{}
wg.Add(2)
r.wSListener = ws.NewListener(address)
r.wSListener = &ws.Listener{
Address: cfg.Address,
TLSConfig: cfg.TLSConfig,
}
var wslErr error
go func() {
defer wg.Done()
@ -39,7 +49,7 @@ func (r *Server) Listen(address string) error {
}
}()
r.uDPListener = udp.NewListener(address)
r.uDPListener = udp.NewListener(cfg.Address)
var udpLErr error
go func() {
defer wg.Done()