diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 6b615f667..9f17a087d 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -452,7 +452,7 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon if conn.wgProxyICE != nil { if err := conn.wgProxyICE.CloseConn(); err != nil { - conn.log.Warnf("failed to close depracated wg proxy conn: %v", err) + conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) } } conn.wgProxyICE = wgProxy @@ -549,7 +549,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { if conn.wgProxyRelay != nil { if err := conn.wgProxyRelay.CloseConn(); err != nil { - conn.log.Warnf("failed to close depracated wg proxy conn: %v", err) + conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) } } conn.wgProxyRelay = wgProxy diff --git a/encryption/route53.go b/encryption/route53.go new file mode 100644 index 000000000..3c81ab103 --- /dev/null +++ b/encryption/route53.go @@ -0,0 +1,87 @@ +package encryption + +import ( + "context" + "crypto/tls" + "fmt" + "os" + "strings" + + "github.com/caddyserver/certmagic" + "github.com/libdns/route53" + log "github.com/sirupsen/logrus" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "golang.org/x/crypto/acme" +) + +// Route53TLS by default, loads the AWS configuration from the environment. +// env variables: AWS_REGION, AWS_PROFILE, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN +type Route53TLS struct { + DataDir string + Email string + Domains []string + CA string +} + +func (r *Route53TLS) GetCertificate() (*tls.Config, error) { + if len(r.Domains) == 0 { + return nil, fmt.Errorf("no domains provided") + } + + certmagic.Default.Logger = logger() + certmagic.Default.Storage = &certmagic.FileStorage{Path: r.DataDir} + certmagic.DefaultACME.Agreed = true + if r.Email != "" { + certmagic.DefaultACME.Email = r.Email + } else { + certmagic.DefaultACME.Email = emailFromDomain(r.Domains[0]) + } + + if r.CA == "" { + certmagic.DefaultACME.CA = certmagic.LetsEncryptProductionCA + } else { + certmagic.DefaultACME.CA = r.CA + } + + certmagic.DefaultACME.DNS01Solver = &certmagic.DNS01Solver{ + DNSManager: certmagic.DNSManager{ + DNSProvider: &route53.Provider{}, + }, + } + cm := certmagic.NewDefault() + if err := cm.ManageSync(context.Background(), r.Domains); err != nil { + log.Errorf("failed to manage certificate: %v", err) + return nil, err + } + + tlsConfig := &tls.Config{ + GetCertificate: cm.GetCertificate, + NextProtos: []string{"h2", "http/1.1", acme.ALPNProto}, + } + + return tlsConfig, nil +} + +func emailFromDomain(domain string) string { + if domain == "" { + return "" + } + + parts := strings.Split(domain, ".") + if len(parts) < 2 { + return "" + } + if parts[0] == "" { + return "" + } + return fmt.Sprintf("admin@%s.%s", parts[len(parts)-2], parts[len(parts)-1]) +} + +func logger() *zap.Logger { + return zap.New(zapcore.NewCore( + zapcore.NewConsoleEncoder(zap.NewProductionEncoderConfig()), + os.Stderr, + zap.ErrorLevel, + )) +} diff --git a/encryption/route53_test.go b/encryption/route53_test.go new file mode 100644 index 000000000..765b60f84 --- /dev/null +++ b/encryption/route53_test.go @@ -0,0 +1,84 @@ +package encryption + +import ( + "context" + "io" + "net/http" + "os" + "testing" + "time" +) + +func TestRoute53TLSConfig(t *testing.T) { + t.SkipNow() // This test requires AWS credentials + exampleString := "Hello, world!" + rtls := &Route53TLS{ + DataDir: t.TempDir(), + Email: os.Getenv("LE_EMAIL_ROUTE53"), + Domains: []string{os.Getenv("DOMAIN")}, + } + tlsConfig, err := rtls.GetCertificate() + if err != nil { + t.Errorf("Route53TLSConfig failed: %v", err) + } + + server := &http.Server{ + Addr: ":8443", + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(exampleString)) + }), + TLSConfig: tlsConfig, + } + + go func() { + err := server.ListenAndServeTLS("", "") + if err != http.ErrServerClosed { + t.Errorf("Failed to start server: %v", err) + } + }() + defer func() { + if err := server.Shutdown(context.Background()); err != nil { + t.Errorf("Failed to shutdown server: %v", err) + } + }() + + time.Sleep(1 * time.Second) + resp, err := http.Get("https://relay.godevltd.com:8443") + if err != nil { + t.Errorf("Failed to get response: %v", err) + return + } + defer func() { + _ = resp.Body.Close() + }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Errorf("Failed to read response body: %v", err) + } + if string(body) != exampleString { + t.Errorf("Unexpected response: %s", body) + } +} + +func Test_emailFromDomain(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"example.com", "admin@example.com"}, + {"x.example.com", "admin@example.com"}, + {"x.x.example.com", "admin@example.com"}, + {"*.example.com", "admin@example.com"}, + {"example", ""}, + {"", ""}, + {".com", ""}, + } + for _, tt := range tests { + t.Run("domain test", func(t *testing.T) { + if got := emailFromDomain(tt.input); got != tt.want { + t.Errorf("emailFromDomain() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/go.mod b/go.mod index b02e19a73..eade75ea5 100644 --- a/go.mod +++ b/go.mod @@ -34,6 +34,7 @@ require ( fyne.io/systray v1.11.0 github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible github.com/c-robinson/iplib v1.0.3 + github.com/caddyserver/certmagic v0.21.3 github.com/cilium/ebpf v0.15.0 github.com/coreos/go-iptables v0.7.0 github.com/creack/pty v1.1.18 @@ -50,11 +51,12 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-version v1.6.0 + github.com/libdns/route53 v1.5.0 github.com/libp2p/go-netroute v0.2.1 github.com/magiconair/properties v1.8.7 github.com/mattn/go-sqlite3 v1.14.19 github.com/mdlayher/socket v0.4.1 - github.com/miekg/dns v1.1.43 + github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e @@ -70,6 +72,7 @@ require ( github.com/rs/xid v1.3.0 github.com/shirou/gopsutil/v3 v3.24.4 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 + github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/stretchr/testify v1.9.0 github.com/testcontainers/testcontainers-go v0.31.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0 @@ -81,6 +84,7 @@ require ( go.opentelemetry.io/otel/exporters/prometheus v0.48.0 go.opentelemetry.io/otel/metric v1.26.0 go.opentelemetry.io/otel/sdk/metric v1.26.0 + go.uber.org/zap v1.27.0 goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a @@ -107,8 +111,23 @@ require ( github.com/Microsoft/hcsshim v0.12.3 // indirect github.com/XiaoMi/pegasus-go-client v0.0.0-20210427083443-f3b6b08bc4c2 // indirect github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect + github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect + github.com/aws/aws-sdk-go-v2/config v1.27.27 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.27 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 // indirect + github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 // indirect + github.com/aws/smithy-go v1.20.3 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d // indirect + github.com/caddyserver/zerossl v0.1.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/containerd/containerd v1.7.16 // indirect github.com/containerd/log v0.1.0 // indirect @@ -141,7 +160,7 @@ require ( github.com/googleapis/gax-go/v2 v2.12.3 // indirect github.com/gopherjs/gopherjs v1.17.2 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect - github.com/hashicorp/go-uuid v1.0.2 // indirect + github.com/hashicorp/go-uuid v1.0.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect @@ -150,13 +169,17 @@ require ( github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect + github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/josharian/native v1.1.0 // indirect github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect github.com/kelseyhightower/envconfig v1.4.0 // indirect github.com/klauspost/compress v1.17.8 // indirect + github.com/klauspost/cpuid/v2 v2.2.7 // indirect + github.com/libdns/libdns v0.2.2 // indirect github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect github.com/mdlayher/genetlink v1.3.2 // indirect github.com/mdlayher/netlink v1.7.2 // indirect + github.com/mholt/acmez/v2 v2.0.1 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/patternmatcher v0.6.0 // indirect github.com/moby/sys/sequential v0.5.0 // indirect @@ -187,10 +210,12 @@ require ( github.com/tklauser/numcpus v0.8.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect github.com/yuin/goldmark v1.7.1 // indirect + github.com/zeebo/blake3 v0.2.3 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect go.opentelemetry.io/otel/sdk v1.26.0 // indirect go.opentelemetry.io/otel/trace v1.26.0 // indirect + go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/mod v0.17.0 // indirect golang.org/x/text v0.16.0 // indirect diff --git a/go.sum b/go.sum index c5fc08bfd..8ab365ac8 100644 --- a/go.sum +++ b/go.sum @@ -79,6 +79,34 @@ github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kd github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= +github.com/aws/aws-sdk-go-v2 v1.30.3 h1:jUeBtG0Ih+ZIFH0F4UkmL9w3cSpaMv9tYYDbzILP8dY= +github.com/aws/aws-sdk-go-v2 v1.30.3/go.mod h1:nIQjQVp5sfpQcTc9mPSr1B0PaWK5ByX9MOoDadSN4lc= +github.com/aws/aws-sdk-go-v2/config v1.27.27 h1:HdqgGt1OAP0HkEDDShEl0oSYa9ZZBSOmKpdpsDMdO90= +github.com/aws/aws-sdk-go-v2/config v1.27.27/go.mod h1:MVYamCg76dFNINkZFu4n4RjDixhVr51HLj4ErWzrVwg= +github.com/aws/aws-sdk-go-v2/credentials v1.17.27 h1:2raNba6gr2IfA0eqqiP2XiQ0UVOpGPgDSi0I9iAP+UI= +github.com/aws/aws-sdk-go-v2/credentials v1.17.27/go.mod h1:gniiwbGahQByxan6YjQUMcW4Aov6bLC3m+evgcoN4r4= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 h1:KreluoV8FZDEtI6Co2xuNk/UqI9iwMrOx/87PBNIKqw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11/go.mod h1:SeSUYBLsMYFoRvHE0Tjvn7kbxaUhl75CJi1sbfhMxkU= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 h1:SoNJ4RlFEQEbtDcCEt+QG56MY4fm4W8rYirAmq+/DdU= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15/go.mod h1:U9ke74k1n2bf+RIgoX1SXFed1HLs51OgUSs+Ph0KJP8= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 h1:C6WHdGnTDIYETAm5iErQUiVNsclNx9qbJVPIt03B6bI= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15/go.mod h1:ZQLZqhcu+JhSrA9/NXRm8SkDvsycE+JkV3WGY41e+IM= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 h1:dT3MqvGhSoaIhRseqw2I0yH81l7wiR2vjs57O51EAm8= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3/go.mod h1:GlAeCkHwugxdHaueRr4nhPuY+WW+gR8UjlcqzPr1SPI= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 h1:HGErhhrxZlQ044RiM+WdoZxp0p+EGM62y3L6pwA4olE= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17/go.mod h1:RkZEx4l0EHYDJpWppMJ3nD9wZJAa8/0lq9aVC+r2UII= +github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3 h1:MmLCRqP4U4Cw9gJ4bNrCG0mWqEtBlmAVleyelcHARMU= +github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3/go.mod h1:AMPjK2YnRh0YgOID3PqhJA1BRNfXDfGOnSsKHtAe8yA= +github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 h1:BXx0ZIxvrJdSgSvKTZ+yRBeSqqgPM89VPlulEcl37tM= +github.com/aws/aws-sdk-go-v2/service/sso v1.22.4/go.mod h1:ooyCOXjvJEsUw7x+ZDHeISPMhtwI3ZCB7ggFMcFfWLU= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 h1:yiwVzJW2ZxZTurVbYWA7QOrAaCYQR72t0wrSBfoesUE= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4/go.mod h1:0oxfLkpz3rQ/CHlx5hB7H69YUpFiI1tql6Q6Ne+1bCw= +github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 h1:ZsDKRLXGWHk8WdtyYMoGNO7bTudrvuKpDKgMVRlepGE= +github.com/aws/aws-sdk-go-v2/service/sts v1.30.3/go.mod h1:zwySh8fpFyXp9yOr/KVzxOl8SRqgf/IDw5aUt9UKFcQ= +github.com/aws/smithy-go v1.20.3 h1:ryHwveWzPV5BIof6fyDvor6V3iUL7nTfiTKXHiW05nE= +github.com/aws/smithy-go v1.20.3/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= @@ -87,6 +115,10 @@ github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d h1:pVrfxiGfwel github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d/go.mod h1:H0wQNHz2YrLsuXOZozoeDmnHXkNCRmMW0gwFWDfEZDA= github.com/c-robinson/iplib v1.0.3 h1:NG0UF0GoEsrC1/vyfX1Lx2Ss7CySWl3KqqXh3q4DdPU= github.com/c-robinson/iplib v1.0.3/go.mod h1:i3LuuFL1hRT5gFpBRnEydzw8R6yhGkF4szNDIbF8pgo= +github.com/caddyserver/certmagic v0.21.3 h1:pqRRry3yuB4CWBVq9+cUqu+Y6E2z8TswbhNx1AZeYm0= +github.com/caddyserver/certmagic v0.21.3/go.mod h1:Zq6pklO9nVRl3DIFUw9gVUfXKdpc/0qwTUAQMBlfgtI= +github.com/caddyserver/zerossl v0.1.3 h1:onS+pxp3M8HnHpN5MMbOMyNjmTheJyWRaZYwn+YTAyA= +github.com/caddyserver/zerossl v0.1.3/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4= github.com/cenkalti/backoff/v4 v4.1.0/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= @@ -352,8 +384,9 @@ github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerX github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdvsLplgctolz4= github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= -github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE= github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= +github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-version v1.6.0 h1:feTTfFNnjP967rlCxM/I9g701jU+RN74YKx2mOkIeek= github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90= @@ -384,6 +417,10 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= github.com/json-iterator/go v0.0.0-20180612202835-f2b4162afba3/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= @@ -403,6 +440,9 @@ github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU= github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= +github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= +github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -415,6 +455,10 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s= +github.com/libdns/libdns v0.2.2/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ= +github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA= +github.com/libdns/route53 v1.5.0/go.mod h1:joT4hKmaTNKHEwb7GmZ65eoDz1whTu7KKYPS8ZqIh6Q= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tAFlj1FYZl8ztUZ13bdq+PLY+NOfbyI= github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k= @@ -433,9 +477,11 @@ github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/ github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U= github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= +github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k= +github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= -github.com/miekg/dns v1.1.43 h1:JKfpVSCB84vrAmHzyrsxB5NAr5kLoMXZArPSw7Qlgyg= -github.com/miekg/dns v1.1.43/go.mod h1:+evo5L0630/F6ca/Z9+GAqzhjGyn8/c+TBaOyfEl0V4= +github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs= +github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= @@ -592,6 +638,8 @@ github.com/smartystreets/assertions v1.13.0/go.mod h1:wDmR7qL282YbGsPy6H/yAsesrx github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/smartystreets/goconvey v1.7.2 h1:9RBaZCeXEQ3UselpuwUQHltGVXvdwm6cv1hgR6gDIPg= github.com/smartystreets/goconvey v1.7.2/go.mod h1:Vw0tHAZW6lzCRk3xgdin6fKYcG+G3Pg9vgXWeJpQFMM= +github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8= +github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= github.com/spf13/afero v1.6.0/go.mod h1:Ai8FlHk4v/PARR026UzYexafAt9roJ7LcLMAmO6Z93I= github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/spf13/cast v1.5.0 h1:rj3WzYc11XZaIZMPKmwP96zkFEnnAmV8s6XbB2aY32w= @@ -660,6 +708,12 @@ github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= github.com/zcalusic/sysinfo v1.0.2 h1:nwTTo2a+WQ0NXwo0BGRojOJvJ/5XKvQih+2RrtWqfxc= github.com/zcalusic/sysinfo v1.0.2/go.mod h1:kluzTYflRWo6/tXVMJPdEjShsbPpsFRyy+p1mBQPC30= +github.com/zeebo/assert v1.1.0 h1:hU1L1vLTHsnO8x8c9KAR5GmM5QscxHg5RNU5z5qbUWY= +github.com/zeebo/assert v1.1.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= +github.com/zeebo/blake3 v0.2.3 h1:TFoLXsjeXqRNFxSbk35Dk4YtszE/MQQGK10BH4ptoTg= +github.com/zeebo/blake3 v0.2.3/go.mod h1:mjJjZpnsyIVtVgTOSpJ9vmRE4wgDeyt2HU3qXvvKCaQ= +github.com/zeebo/pcg v1.0.1 h1:lyqfGeWiv4ahac6ttHs+I5hwtH/+1mrhlCtVNQM2kHo= +github.com/zeebo/pcg v1.0.1/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4= go.etcd.io/etcd/api/v3 v3.5.0/go.mod h1:cbVKeC6lCfl7j/8jBhAK6aIYO9XOjdptoxU/nLQcPvs= go.etcd.io/etcd/client/pkg/v3 v3.5.0/go.mod h1:IJHfcCEKxYu1Os13ZdwCwIUTUVGYTSAM3YSwc9/Ac1g= go.etcd.io/etcd/client/v2 v2.305.0/go.mod h1:h9puh54ZTgAKtEbut2oe9P4L/oqKCVB6xsXlzd7alYQ= @@ -695,8 +749,14 @@ go.opentelemetry.io/otel/trace v1.26.0/go.mod h1:4iDxvGDQuUkHve82hJJ8UqrwswHYsZu go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I= go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo= +go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= +go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= goauthentik.io/api/v3 v3.2023051.3 h1:NebAhD/TeTWNo/9X3/Uj+rM5fG1HaiLOlKTNLQv9Qq4= goauthentik.io/api/v3 v3.2023051.3/go.mod h1:nYECml4jGbp/541hj8GcylKQG1gVBsKppHy4+7G8u4U= golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -890,7 +950,6 @@ golang.org/x/sys v0.0.0-20210104204734-6f8348627aad/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210220050731-9a76102bfb43/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210305230114-8fe3ee5dd75b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/relay/client/client.go b/relay/client/client.go index 4a20a3b00..aba940b41 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -46,7 +46,7 @@ func (isf *internalStopFlag) isSet() bool { return isf.stop } -// Msg carry the payload from the server to the client. With this sturct, the net.Conn can free the buffer. +// Msg carry the payload from the server to the client. With this struct, the net.Conn can free the buffer. type Msg struct { Payload []byte diff --git a/relay/cmd/env.go b/relay/cmd/env.go new file mode 100644 index 000000000..85d3e922b --- /dev/null +++ b/relay/cmd/env.go @@ -0,0 +1,35 @@ +package main + +import ( + "os" + "strings" + + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "github.com/spf13/pflag" +) + +// setFlagsFromEnvVars reads and updates flag values from environment variables with prefix NB_ +func setFlagsFromEnvVars(cmd *cobra.Command) { + flags := cmd.PersistentFlags() + flags.VisitAll(func(f *pflag.Flag) { + newEnvVar := flagNameToEnvVar(f.Name, "NB_") + value, present := os.LookupEnv(newEnvVar) + if !present { + return + } + + err := flags.Set(f.Name, value) + if err != nil { + log.Infof("unable to configure flag %s using variable %s, err: %v", f.Name, newEnvVar, err) + } + }) +} + +// flagNameToEnvVar converts flag name to environment var name adding a prefix, +// replacing dashes and making all uppercase (e.g. setup-keys is converted to NB_SETUP_KEYS according to the input prefix) +func flagNameToEnvVar(cmdFlag string, prefix string) string { + parsed := strings.ReplaceAll(cmdFlag, "-", "_") + upper := strings.ToUpper(parsed) + return prefix + upper +} diff --git a/relay/cmd/main.go b/relay/cmd/main.go index 7bb7d700a..b61f61d1a 100644 --- a/relay/cmd/main.go +++ b/relay/cmd/main.go @@ -31,11 +31,17 @@ type Config struct { // in HA every peer connect to a common domain, the instance domain has been distributed during the p2p connection // it is a domain:port or ip:port ExposedAddress string + LetsencryptEmail string LetsencryptDataDir string LetsencryptDomains []string - TlsCertFile string - TlsKeyFile string - AuthSecret string + // in case of using Route 53 for DNS challenge the credentials should be provided in the environment variables or + // in the AWS credentials file + LetsencryptAWSRoute53 bool + TlsCertFile string + TlsKeyFile string + AuthSecret string + LogLevel string + LogFile string } func (c Config) Validate() error { @@ -58,7 +64,6 @@ func (c Config) HasLetsEncrypt() bool { var ( cobraConfig *Config - cfgFile string rootCmd = &cobra.Command{ Use: "relay", Short: "Relay service", @@ -72,14 +77,19 @@ var ( func init() { _ = util.InitLog("trace", "console") cobraConfig = &Config{} - rootCmd.PersistentFlags().StringVarP(&cfgFile, "config-file", "f", "/etc/netbird/relay.json", "Relay server config file location") rootCmd.PersistentFlags().StringVarP(&cobraConfig.ListenAddress, "listen-address", "l", ":443", "listen address") rootCmd.PersistentFlags().StringVarP(&cobraConfig.ExposedAddress, "exposed-address", "e", "", "instance domain address (or ip) and port, it will be distributes between peers") rootCmd.PersistentFlags().StringVarP(&cobraConfig.LetsencryptDataDir, "letsencrypt-data-dir", "d", "", "a directory to store Let's Encrypt data. Required if Let's Encrypt is enabled.") - rootCmd.PersistentFlags().StringArrayVarP(&cobraConfig.LetsencryptDomains, "letsencrypt-domains", "a", nil, "list of domains to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") + rootCmd.PersistentFlags().StringSliceVarP(&cobraConfig.LetsencryptDomains, "letsencrypt-domains", "a", nil, "list of domains to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") + rootCmd.PersistentFlags().StringVar(&cobraConfig.LetsencryptEmail, "letsencrypt-email", "", "email address to use for Let's Encrypt certificate registration") + rootCmd.PersistentFlags().BoolVar(&cobraConfig.LetsencryptAWSRoute53, "letsencrypt-aws-route53", false, "use AWS Route 53 for Let's Encrypt DNS challenge") rootCmd.PersistentFlags().StringVarP(&cobraConfig.TlsCertFile, "tls-cert-file", "c", "", "") rootCmd.PersistentFlags().StringVarP(&cobraConfig.TlsKeyFile, "tls-key-file", "k", "", "") - rootCmd.PersistentFlags().StringVarP(&cobraConfig.AuthSecret, "auth-secret", "s", "", "log level") + rootCmd.PersistentFlags().StringVarP(&cobraConfig.AuthSecret, "auth-secret", "s", "", "auth secret") + rootCmd.PersistentFlags().StringVar(&cobraConfig.LogLevel, "log-level", "info", "log level") + rootCmd.PersistentFlags().StringVar(&cobraConfig.LogFile, "log-file", "console", "log file") + + setFlagsFromEnvVars(rootCmd) } func waitForExitSignal() { @@ -88,49 +98,15 @@ func waitForExitSignal() { <-osSigs } -func loadConfig(configFile string) (*Config, error) { - log.Infof("loading config from: %s", configFile) - loadedConfig := &Config{} - _, err := util.ReadJson(configFile, loadedConfig) - if err != nil { - return nil, err - } - if cobraConfig.ListenAddress != "" { - loadedConfig.ListenAddress = cobraConfig.ListenAddress - } - - if cobraConfig.ExposedAddress != "" { - loadedConfig.ExposedAddress = cobraConfig.ExposedAddress - } - if cobraConfig.LetsencryptDataDir != "" { - loadedConfig.LetsencryptDataDir = cobraConfig.LetsencryptDataDir - } - if len(cobraConfig.LetsencryptDomains) > 0 { - loadedConfig.LetsencryptDomains = cobraConfig.LetsencryptDomains - } - if cobraConfig.TlsCertFile != "" { - loadedConfig.TlsCertFile = cobraConfig.TlsCertFile - } - if cobraConfig.TlsKeyFile != "" { - loadedConfig.TlsKeyFile = cobraConfig.TlsKeyFile - } - if cobraConfig.AuthSecret != "" { - loadedConfig.AuthSecret = cobraConfig.AuthSecret - } - - return loadedConfig, err -} - func execute(cmd *cobra.Command, args []string) error { - cfg, err := loadConfig(cfgFile) + err := cobraConfig.Validate() if err != nil { - return fmt.Errorf("failed to load config: %s", err) + return fmt.Errorf("invalid config: %s", err) } - err = cfg.Validate() + err = util.InitLog(cobraConfig.LogLevel, cobraConfig.LogFile) if err != nil { - log.Errorf("invalid config: %s", err) - return err + return fmt.Errorf("failed to initialize log: %s", err) } metricsServer, err := metrics.NewServer(metricsPort, "") @@ -146,26 +122,17 @@ func execute(cmd *cobra.Command, args []string) error { }() srvListenerCfg := server.ListenerConfig{ - Address: cfg.ListenAddress, - } - if cfg.HasLetsEncrypt() { - tlsCfg, err := setupTLSCertManager(cfg.LetsencryptDataDir, cfg.LetsencryptDomains...) - if err != nil { - return fmt.Errorf("%s", err) - } - srvListenerCfg.TLSConfig = tlsCfg - } else if cfg.HasCertConfig() { - tlsCfg, err := encryption.LoadTLSConfig(cfg.TlsCertFile, cfg.TlsKeyFile) - if err != nil { - return fmt.Errorf("%s", err) - } - srvListenerCfg.TLSConfig = tlsCfg + Address: cobraConfig.ListenAddress, } - tlsSupport := srvListenerCfg.TLSConfig != nil + tlsConfig, tlsSupport, err := handleTLSConfig(cobraConfig) + if err != nil { + return fmt.Errorf("failed to setup TLS config: %s", err) + } + srvListenerCfg.TLSConfig = tlsConfig - authenticator := auth.NewTimedHMACValidator(cfg.AuthSecret, 24*time.Hour) - srv, err := server.NewServer(metricsServer.Meter, cfg.ExposedAddress, tlsSupport, authenticator) + authenticator := auth.NewTimedHMACValidator(cobraConfig.AuthSecret, 24*time.Hour) + srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator) if err != nil { return fmt.Errorf("failed to create relay server: %v", err) } @@ -194,6 +161,41 @@ func execute(cmd *cobra.Command, args []string) error { return shutDownErrors } +func handleTLSConfig(cfg *Config) (*tls.Config, bool, error) { + if cfg.LetsencryptAWSRoute53 { + log.Debugf("using Let's Encrypt DNS resolver with Route 53 support") + r53 := encryption.Route53TLS{ + DataDir: cfg.LetsencryptDataDir, + Email: cfg.LetsencryptEmail, + Domains: cfg.LetsencryptDomains, + } + tlsCfg, err := r53.GetCertificate() + if err != nil { + return nil, false, fmt.Errorf("%s", err) + } + return tlsCfg, true, nil + } + + if cfg.HasLetsEncrypt() { + log.Infof("setting up TLS with Let's Encrypt.") + tlsCfg, err := setupTLSCertManager(cfg.LetsencryptDataDir, cfg.LetsencryptDomains...) + if err != nil { + return nil, false, fmt.Errorf("%s", err) + } + return tlsCfg, true, nil + } + + if cfg.HasCertConfig() { + log.Debugf("using file based TLS config") + tlsCfg, err := encryption.LoadTLSConfig(cfg.TlsCertFile, cfg.TlsKeyFile) + if err != nil { + return nil, false, fmt.Errorf("%s", err) + } + return tlsCfg, true, nil + } + return nil, false, nil +} + func setupTLSCertManager(letsencryptDataDir string, letsencryptDomains ...string) (*tls.Config, error) { certManager, err := encryption.CreateCertManager(letsencryptDataDir, letsencryptDomains...) if err != nil { diff --git a/relay/messages/message.go b/relay/messages/message.go index 3770f6398..387d87a94 100644 --- a/relay/messages/message.go +++ b/relay/messages/message.go @@ -111,7 +111,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) { // authenticate the client with the server. func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) { if len(msg) < headerSizeHello { - return nil, nil, fmt.Errorf("invalid 'hello' messge") + return nil, nil, fmt.Errorf("invalid 'hello' message") } if !bytes.Equal(msg[1:5], magicHeader) { return nil, nil, fmt.Errorf("invalid magic header") diff --git a/relay/server/store.go b/relay/server/store.go index 79b8aeb5d..96879dae1 100644 --- a/relay/server/store.go +++ b/relay/server/store.go @@ -19,7 +19,7 @@ func NewStore() *Store { } // AddPeer adds a peer to the store -// It distinguishes the peers by their ID +// todo: consider to close peer conn if the peer already exists func (s *Store) AddPeer(peer *Peer) { s.peersLock.Lock() defer s.peersLock.Unlock() @@ -30,6 +30,15 @@ func (s *Store) AddPeer(peer *Peer) { func (s *Store) DeletePeer(peer *Peer) { s.peersLock.Lock() defer s.peersLock.Unlock() + + dp, ok := s.peers[peer.String()] + if !ok { + return + } + if dp != peer { + return + } + delete(s.peers, peer.String()) } diff --git a/relay/server/store_test.go b/relay/server/store_test.go new file mode 100644 index 000000000..4a30bc131 --- /dev/null +++ b/relay/server/store_test.go @@ -0,0 +1,40 @@ +package server + +import ( + "context" + "testing" + + "go.opentelemetry.io/otel" + + "github.com/netbirdio/netbird/relay/metrics" +) + +func TestStore_DeletePeer(t *testing.T) { + s := NewStore() + + m, _ := metrics.NewMetrics(context.Background(), otel.Meter("")) + + p := NewPeer(m, []byte("peer_one"), nil, nil) + s.AddPeer(p) + s.DeletePeer(p) + if _, ok := s.Peer(p.String()); ok { + t.Errorf("peer was not deleted") + } +} + +func TestStore_DeleteDeprecatedPeer(t *testing.T) { + s := NewStore() + + m, _ := metrics.NewMetrics(context.Background(), otel.Meter("")) + + p1 := NewPeer(m, []byte("peer_id"), nil, nil) + p2 := NewPeer(m, []byte("peer_id"), nil, nil) + + s.AddPeer(p1) + s.AddPeer(p2) + s.DeletePeer(p1) + + if _, ok := s.Peer(p2.String()); !ok { + t.Errorf("second peer was deleted") + } +} diff --git a/relay/testec2/main.go b/relay/testec2/main.go new file mode 100644 index 000000000..0c8099a5e --- /dev/null +++ b/relay/testec2/main.go @@ -0,0 +1,258 @@ +//go:build linux || darwin + +package main + +import ( + "crypto/rand" + "flag" + "fmt" + "net" + "os" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/util" +) + +const ( + errMsgFailedReadTCP = "failed to read from tcp: %s" +) + +var ( + dataSize = 1024 * 1024 * 50 // 50MB + pairs = []int{1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100} + signalListenAddress = ":8081" + + relaySrvAddress string + turnSrvAddress string + signalURL string + udpListener string // used for TURN test +) + +type testResult struct { + numOfPairs int + duration time.Duration + speed float64 +} + +func (tr testResult) Speed() string { + speed := tr.speed + var unit string + + switch { + case speed < 1024: + unit = "B/s" + case speed < 1048576: + speed /= 1024 + unit = "KB/s" + case speed < 1073741824: + speed /= 1048576 + unit = "MB/s" + default: + speed /= 1073741824 + unit = "GB/s" + } + + return fmt.Sprintf("%.2f %s", speed, unit) +} + +func seedRandomData(size int) ([]byte, error) { + token := make([]byte, size) + _, err := rand.Read(token) + if err != nil { + return nil, err + } + return token, nil +} + +func avg(transferDuration []time.Duration) (time.Duration, float64) { + var totalDuration time.Duration + for _, d := range transferDuration { + totalDuration += d + } + avgDuration := totalDuration / time.Duration(len(transferDuration)) + bps := float64(dataSize) / avgDuration.Seconds() + return avgDuration, bps +} + +func RelayReceiverMain() []testResult { + testResults := make([]testResult, 0, len(pairs)) + for _, p := range pairs { + tr := testResult{numOfPairs: p} + td := relayReceive(relaySrvAddress, p) + tr.duration, tr.speed = avg(td) + + testResults = append(testResults, tr) + } + + return testResults +} + +func RelaySenderMain() { + log.Infof("starting sender") + log.Infof("starting seed phase") + + testData, err := seedRandomData(dataSize) + if err != nil { + log.Fatalf("failed to seed random data: %s", err) + } + + log.Infof("data size: %d", len(testData)) + + for n, p := range pairs { + log.Infof("running test with %d pairs", p) + relayTransfer(relaySrvAddress, testData, p) + + // grant time to prepare new receivers + if n < len(pairs)-1 { + time.Sleep(3 * time.Second) + } + } + +} + +// TRUNSenderMain is the sender +// - allocate turn clients +// - send relayed addresses to signal server in batch +// - wait for signal server to send back addresses in a map +// - send test data to each address in parallel +func TRUNSenderMain() { + log.Infof("starting TURN sender test") + + log.Infof("starting seed random data: %d", dataSize) + testData, err := seedRandomData(dataSize) + if err != nil { + log.Fatalf("failed to seed random data: %s", err) + } + + ss := SignalClient{signalURL} + + for _, p := range pairs { + log.Infof("running test with %d pairs", p) + turnSender := &TurnSender{} + + createTurnConns(p, turnSender) + + log.Infof("send addresses via signal server: %d", len(turnSender.addresses)) + clientAddresses, err := ss.SendAddress(turnSender.addresses) + if err != nil { + log.Fatalf("failed to send address: %s", err) + } + log.Infof("received addresses: %v", clientAddresses.Address) + + createSenderDevices(turnSender, clientAddresses) + + log.Infof("waiting for tcpListeners to be ready") + time.Sleep(2 * time.Second) + + tcpConns := make([]net.Conn, 0, len(turnSender.devices)) + for i := range turnSender.devices { + addr := fmt.Sprintf("10.0.%d.2:9999", i) + log.Infof("dialing: %s", addr) + tcpConn, err := net.Dial("tcp", addr) + if err != nil { + log.Fatalf("failed to dial tcp: %s", err) + } + tcpConns = append(tcpConns, tcpConn) + } + + log.Infof("start test data transfer for %d pairs", p) + testDataLen := len(testData) + wg := sync.WaitGroup{} + wg.Add(len(tcpConns)) + for i, tcpConn := range tcpConns { + log.Infof("sending test data to device: %d", i) + go runTurnWriting(tcpConn, testData, testDataLen, &wg) + } + wg.Wait() + + for _, d := range turnSender.devices { + _ = d.Close() + } + + log.Infof("test finished with %d pairs", p) + } +} + +func TURNReaderMain() []testResult { + log.Infof("starting TURN receiver test") + si := NewSignalService() + go func() { + log.Infof("starting signal server") + err := si.Listen(signalListenAddress) + if err != nil { + log.Errorf("failed to listen: %s", err) + } + }() + + testResults := make([]testResult, 0, len(pairs)) + for range pairs { + addresses := <-si.AddressesChan + instanceNumber := len(addresses) + log.Infof("received addresses: %d", instanceNumber) + + turnReceiver := &TurnReceiver{} + err := createDevices(addresses, turnReceiver) + if err != nil { + log.Fatalf("%s", err) + } + + // send client addresses back via signal server + si.ClientAddressChan <- turnReceiver.clientAddresses + + durations := make(chan time.Duration, instanceNumber) + for _, device := range turnReceiver.devices { + go runTurnReading(device, durations) + } + + durationsList := make([]time.Duration, 0, instanceNumber) + for d := range durations { + durationsList = append(durationsList, d) + if len(durationsList) == instanceNumber { + close(durations) + } + } + + avgDuration, avgSpeed := avg(durationsList) + ts := testResult{ + numOfPairs: len(durationsList), + duration: avgDuration, + speed: avgSpeed, + } + testResults = append(testResults, ts) + + for _, d := range turnReceiver.devices { + _ = d.Close() + } + } + return testResults +} + +func main() { + var mode string + + _ = util.InitLog("debug", "console") + flag.StringVar(&mode, "mode", "sender", "sender or receiver mode") + flag.Parse() + + relaySrvAddress = os.Getenv("TEST_RELAY_SERVER") // rel://ip:port + turnSrvAddress = os.Getenv("TEST_TURN_SERVER") // ip:3478 + signalURL = os.Getenv("TEST_SIGNAL_URL") // http://receiver_ip:8081 + udpListener = os.Getenv("TEST_UDP_LISTENER") // IP:0 + + if mode == "receiver" { + relayResult := RelayReceiverMain() + turnResults := TURNReaderMain() + for i := 0; i < len(turnResults); i++ { + log.Infof("pairs: %d,\tRelay speed:\t%s,\trelay duration:\t%s", relayResult[i].numOfPairs, relayResult[i].Speed(), relayResult[i].duration) + log.Infof("pairs: %d,\tTURN speed:\t%s,\tturn duration:\t%s", turnResults[i].numOfPairs, turnResults[i].Speed(), turnResults[i].duration) + } + } else { + RelaySenderMain() + // grant time for receiver to start + time.Sleep(3 * time.Second) + TRUNSenderMain() + } +} diff --git a/relay/testec2/relay.go b/relay/testec2/relay.go new file mode 100644 index 000000000..93d084387 --- /dev/null +++ b/relay/testec2/relay.go @@ -0,0 +1,176 @@ +//go:build linux || darwin + +package main + +import ( + "context" + "fmt" + "net" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/relay/auth/hmac" + "github.com/netbirdio/netbird/relay/client" +) + +var ( + hmacTokenStore = &hmac.TokenStore{} +) + +func relayTransfer(serverConnURL string, testData []byte, peerPairs int) { + connsSender := prepareConnsSender(serverConnURL, peerPairs) + defer func() { + for i := 0; i < len(connsSender); i++ { + err := connsSender[i].Close() + if err != nil { + log.Errorf("failed to close connection: %s", err) + } + } + }() + + wg := sync.WaitGroup{} + wg.Add(len(connsSender)) + for _, conn := range connsSender { + go func(conn net.Conn) { + defer wg.Done() + runWriter(conn, testData) + }(conn) + } + wg.Wait() +} + +func runWriter(conn net.Conn, testData []byte) { + si := NewStartInidication(time.Now(), len(testData)) + _, err := conn.Write(si) + if err != nil { + log.Errorf("failed to write to channel: %s", err) + return + } + log.Infof("sent start indication") + + pieceSize := 1024 + testDataLen := len(testData) + + for j := 0; j < testDataLen; j += pieceSize { + end := j + pieceSize + if end > testDataLen { + end = testDataLen + } + _, writeErr := conn.Write(testData[j:end]) + if writeErr != nil { + log.Errorf("failed to write to channel: %s", writeErr) + return + } + } +} + +func prepareConnsSender(serverConnURL string, peerPairs int) []net.Conn { + ctx := context.Background() + clientsSender := make([]*client.Client, peerPairs) + for i := 0; i < cap(clientsSender); i++ { + c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i)) + if err := c.Connect(); err != nil { + log.Fatalf("failed to connect to server: %s", err) + } + clientsSender[i] = c + } + + connsSender := make([]net.Conn, 0, peerPairs) + for i := 0; i < len(clientsSender); i++ { + conn, err := clientsSender[i].OpenConn("receiver-" + fmt.Sprint(i)) + if err != nil { + log.Fatalf("failed to bind channel: %s", err) + } + connsSender = append(connsSender, conn) + } + return connsSender +} + +func relayReceive(serverConnURL string, peerPairs int) []time.Duration { + connsReceiver := prepareConnsReceiver(serverConnURL, peerPairs) + defer func() { + for i := 0; i < len(connsReceiver); i++ { + if err := connsReceiver[i].Close(); err != nil { + log.Errorf("failed to close connection: %s", err) + } + } + }() + + durations := make(chan time.Duration, len(connsReceiver)) + wg := sync.WaitGroup{} + for _, conn := range connsReceiver { + wg.Add(1) + go func(conn net.Conn) { + defer wg.Done() + duration := runReader(conn) + durations <- duration + }(conn) + } + wg.Wait() + + durationsList := make([]time.Duration, 0, len(connsReceiver)) + for d := range durations { + durationsList = append(durationsList, d) + if len(durationsList) == len(connsReceiver) { + close(durations) + } + } + + return durationsList +} + +func runReader(conn net.Conn) time.Duration { + buf := make([]byte, 8192) + + n, readErr := conn.Read(buf) + if readErr != nil { + log.Errorf("failed to read from channel: %s", readErr) + return 0 + } + + si := DecodeStartIndication(buf[:n]) + log.Infof("received start indication: %v", si) + + receivedSize, err := conn.Read(buf) + if err != nil { + log.Fatalf("failed to read from relay: %s", err) + } + now := time.Now() + + rcv := 0 + for receivedSize < si.TransferSize { + n, readErr = conn.Read(buf) + if readErr != nil { + log.Errorf("failed to read from channel: %s", readErr) + return 0 + } + + receivedSize += n + rcv += n + } + return time.Since(now) +} + +func prepareConnsReceiver(serverConnURL string, peerPairs int) []net.Conn { + clientsReceiver := make([]*client.Client, peerPairs) + for i := 0; i < cap(clientsReceiver); i++ { + c := client.NewClient(context.Background(), serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i)) + err := c.Connect() + if err != nil { + log.Fatalf("failed to connect to server: %s", err) + } + clientsReceiver[i] = c + } + + connsReceiver := make([]net.Conn, 0, peerPairs) + for i := 0; i < len(clientsReceiver); i++ { + conn, err := clientsReceiver[i].OpenConn("sender-" + fmt.Sprint(i)) + if err != nil { + log.Fatalf("failed to bind channel: %s", err) + } + connsReceiver = append(connsReceiver, conn) + } + return connsReceiver +} diff --git a/relay/testec2/signal.go b/relay/testec2/signal.go new file mode 100644 index 000000000..fe93a2fe2 --- /dev/null +++ b/relay/testec2/signal.go @@ -0,0 +1,91 @@ +//go:build linux || darwin + +package main + +import ( + "bytes" + "encoding/json" + "net/http" + + log "github.com/sirupsen/logrus" +) + +type PeerAddr struct { + Address []string +} + +type ClientPeerAddr struct { + Address map[string]string +} + +type Signal struct { + AddressesChan chan []string + ClientAddressChan chan map[string]string +} + +func NewSignalService() *Signal { + return &Signal{ + AddressesChan: make(chan []string), + ClientAddressChan: make(chan map[string]string), + } +} + +func (rs *Signal) Listen(listenAddr string) error { + http.HandleFunc("/", rs.onNewAddresses) + return http.ListenAndServe(listenAddr, nil) +} + +func (rs *Signal) onNewAddresses(w http.ResponseWriter, r *http.Request) { + var msg PeerAddr + err := json.NewDecoder(r.Body).Decode(&msg) + if err != nil { + log.Errorf("Error decoding message: %v", err) + } + + log.Infof("received addresses: %d", len(msg.Address)) + rs.AddressesChan <- msg.Address + clientAddresses := <-rs.ClientAddressChan + + respMsg := ClientPeerAddr{ + Address: clientAddresses, + } + data, err := json.Marshal(respMsg) + if err != nil { + log.Errorf("Error marshalling message: %v", err) + return + } + + _, err = w.Write(data) + if err != nil { + log.Errorf("Error writing response: %v", err) + } +} + +type SignalClient struct { + SignalURL string +} + +func (ss SignalClient) SendAddress(addresses []string) (*ClientPeerAddr, error) { + msg := PeerAddr{ + Address: addresses, + } + data, err := json.Marshal(msg) + if err != nil { + return nil, err + } + + response, err := http.Post(ss.SignalURL, "application/json", bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + + defer response.Body.Close() + + log.Debugf("wait for signal response") + var respPeerAddress ClientPeerAddr + err = json.NewDecoder(response.Body).Decode(&respPeerAddress) + if err != nil { + return nil, err + } + return &respPeerAddress, nil +} diff --git a/relay/testec2/start_msg.go b/relay/testec2/start_msg.go new file mode 100644 index 000000000..19b65380b --- /dev/null +++ b/relay/testec2/start_msg.go @@ -0,0 +1,39 @@ +//go:build linux || darwin + +package main + +import ( + "bytes" + "encoding/gob" + "time" + + log "github.com/sirupsen/logrus" +) + +type StartIndication struct { + Started time.Time + TransferSize int +} + +func NewStartInidication(started time.Time, transferSize int) []byte { + si := StartIndication{ + Started: started, + TransferSize: transferSize, + } + + var data bytes.Buffer + err := gob.NewEncoder(&data).Encode(si) + if err != nil { + log.Fatal("encode error:", err) + } + return data.Bytes() +} + +func DecodeStartIndication(data []byte) StartIndication { + var si StartIndication + err := gob.NewDecoder(bytes.NewReader(data)).Decode(&si) + if err != nil { + log.Fatal("decode error:", err) + } + return si +} diff --git a/relay/testec2/tun/proxy.go b/relay/testec2/tun/proxy.go new file mode 100644 index 000000000..7d84bece7 --- /dev/null +++ b/relay/testec2/tun/proxy.go @@ -0,0 +1,72 @@ +//go:build linux || darwin + +package tun + +import ( + "net" + "sync/atomic" + + log "github.com/sirupsen/logrus" +) + +type Proxy struct { + Device *Device + PConn net.PacketConn + DstAddr net.Addr + shutdownFlag atomic.Bool +} + +func (p *Proxy) Start() { + go p.readFromDevice() + go p.readFromConn() +} + +func (p *Proxy) Close() { + p.shutdownFlag.Store(true) +} + +func (p *Proxy) readFromDevice() { + buf := make([]byte, 1500) + for { + n, err := p.Device.Read(buf) + if err != nil { + if p.shutdownFlag.Load() { + return + } + log.Errorf("failed to read from device: %s", err) + return + } + + _, err = p.PConn.WriteTo(buf[:n], p.DstAddr) + if err != nil { + if p.shutdownFlag.Load() { + return + } + log.Errorf("failed to write to conn: %s", err) + return + } + } +} + +func (p *Proxy) readFromConn() { + buf := make([]byte, 1500) + for { + n, _, err := p.PConn.ReadFrom(buf) + if err != nil { + if p.shutdownFlag.Load() { + return + } + log.Errorf("failed to read from conn: %s", err) + return + } + + _, err = p.Device.Write(buf[:n]) + if err != nil { + if p.shutdownFlag.Load() { + return + } + log.Errorf("failed to write to device: %s", err) + return + } + } +} diff --git a/relay/testec2/tun/tun.go b/relay/testec2/tun/tun.go new file mode 100644 index 000000000..5580785ce --- /dev/null +++ b/relay/testec2/tun/tun.go @@ -0,0 +1,110 @@ +//go:build linux || darwin + +package tun + +import ( + "net" + + log "github.com/sirupsen/logrus" + "github.com/songgao/water" + "github.com/vishvananda/netlink" +) + +type Device struct { + Name string + IP string + PConn net.PacketConn + DstAddr net.Addr + + iFace *water.Interface + proxy *Proxy +} + +func (d *Device) Up() error { + cfg := water.Config{ + DeviceType: water.TUN, + PlatformSpecificParams: water.PlatformSpecificParams{ + Name: d.Name, + }, + } + iFace, err := water.New(cfg) + if err != nil { + return err + } + d.iFace = iFace + + err = d.assignIP() + if err != nil { + return err + } + + err = d.bringUp() + if err != nil { + return err + } + + d.proxy = &Proxy{ + Device: d, + PConn: d.PConn, + DstAddr: d.DstAddr, + } + d.proxy.Start() + return nil +} + +func (d *Device) Close() error { + if d.proxy != nil { + d.proxy.Close() + } + if d.iFace != nil { + return d.iFace.Close() + } + return nil +} + +func (d *Device) Read(b []byte) (int, error) { + return d.iFace.Read(b) +} + +func (d *Device) Write(b []byte) (int, error) { + return d.iFace.Write(b) +} + +func (d *Device) assignIP() error { + iface, err := netlink.LinkByName(d.Name) + if err != nil { + log.Errorf("failed to get TUN device: %v", err) + return err + } + + ip := net.IPNet{ + IP: net.ParseIP(d.IP), + Mask: net.CIDRMask(24, 32), + } + + addr := &netlink.Addr{ + IPNet: &ip, + } + err = netlink.AddrAdd(iface, addr) + if err != nil { + log.Errorf("failed to add IP address: %v", err) + return err + } + return nil +} + +func (d *Device) bringUp() error { + iface, err := netlink.LinkByName(d.Name) + if err != nil { + log.Errorf("failed to get device: %v", err) + return err + } + + // Bring the interface up + err = netlink.LinkSetUp(iface) + if err != nil { + log.Errorf("failed to set device up: %v", err) + return err + } + return nil +} diff --git a/relay/testec2/turn.go b/relay/testec2/turn.go new file mode 100644 index 000000000..8beb40423 --- /dev/null +++ b/relay/testec2/turn.go @@ -0,0 +1,181 @@ +//go:build linux || darwin + +package main + +import ( + "fmt" + "net" + "sync" + "time" + + "github.com/netbirdio/netbird/relay/testec2/tun" + + log "github.com/sirupsen/logrus" +) + +type TurnReceiver struct { + conns []*net.UDPConn + clientAddresses map[string]string + devices []*tun.Device +} + +type TurnSender struct { + turnConns map[string]*TurnConn + addresses []string + devices []*tun.Device +} + +func runTurnWriting(tcpConn net.Conn, testData []byte, testDataLen int, wg *sync.WaitGroup) { + defer wg.Done() + defer tcpConn.Close() + + log.Infof("start to sending test data: %s", tcpConn.RemoteAddr()) + + si := NewStartInidication(time.Now(), testDataLen) + _, err := tcpConn.Write(si) + if err != nil { + log.Errorf("failed to write to tcp: %s", err) + return + } + + pieceSize := 1024 + for j := 0; j < testDataLen; j += pieceSize { + end := j + pieceSize + if end > testDataLen { + end = testDataLen + } + _, writeErr := tcpConn.Write(testData[j:end]) + if writeErr != nil { + log.Errorf("failed to write to tcp conn: %s", writeErr) + return + } + } + + // grant time to flush out packages + time.Sleep(3 * time.Second) +} + +func createSenderDevices(sender *TurnSender, clientAddresses *ClientPeerAddr) { + var i int + devices := make([]*tun.Device, 0, len(clientAddresses.Address)) + for k, v := range clientAddresses.Address { + tc, ok := sender.turnConns[k] + if !ok { + log.Fatalf("failed to find turn conn: %s", k) + } + + addr, err := net.ResolveUDPAddr("udp", v) + if err != nil { + log.Fatalf("failed to resolve udp address: %s", err) + } + device := &tun.Device{ + Name: fmt.Sprintf("mtun-sender-%d", i), + IP: fmt.Sprintf("10.0.%d.1", i), + PConn: tc.relayConn, + DstAddr: addr, + } + + err = device.Up() + if err != nil { + log.Fatalf("failed to bring up device: %s", err) + } + + devices = append(devices, device) + i++ + } + sender.devices = devices +} + +func createTurnConns(p int, sender *TurnSender) { + turnConns := make(map[string]*TurnConn) + addresses := make([]string, 0, len(pairs)) + for i := 0; i < p; i++ { + tc := AllocateTurnClient(turnSrvAddress) + log.Infof("allocated turn client: %s", tc.Address().String()) + turnConns[tc.Address().String()] = tc + addresses = append(addresses, tc.Address().String()) + } + + sender.turnConns = turnConns + sender.addresses = addresses +} + +func runTurnReading(d *tun.Device, durations chan time.Duration) { + tcpListener, err := net.Listen("tcp", d.IP+":9999") + if err != nil { + log.Fatalf("failed to listen on tcp: %s", err) + } + log := log.WithField("device", tcpListener.Addr()) + + tcpConn, err := tcpListener.Accept() + if err != nil { + _ = tcpListener.Close() + log.Fatalf("failed to accept connection: %s", err) + } + log.Infof("remote peer connected") + + buf := make([]byte, 103) + n, err := tcpConn.Read(buf) + if err != nil { + _ = tcpListener.Close() + log.Fatalf(errMsgFailedReadTCP, err) + } + + si := DecodeStartIndication(buf[:n]) + log.Infof("received start indication: %v, %d", si, n) + + buf = make([]byte, 8192) + i, err := tcpConn.Read(buf) + if err != nil { + _ = tcpListener.Close() + log.Fatalf(errMsgFailedReadTCP, err) + } + now := time.Now() + for i < si.TransferSize { + n, err := tcpConn.Read(buf) + if err != nil { + _ = tcpListener.Close() + log.Fatalf(errMsgFailedReadTCP, err) + } + i += n + } + durations <- time.Since(now) +} + +func createDevices(addresses []string, receiver *TurnReceiver) error { + receiver.conns = make([]*net.UDPConn, 0, len(addresses)) + receiver.clientAddresses = make(map[string]string, len(addresses)) + receiver.devices = make([]*tun.Device, 0, len(addresses)) + for i, addr := range addresses { + localAddr, err := net.ResolveUDPAddr("udp", udpListener) + if err != nil { + return fmt.Errorf("failed to resolve UDP address: %s", err) + } + + conn, err := net.ListenUDP("udp", localAddr) + if err != nil { + return fmt.Errorf("failed to create UDP connection: %s", err) + } + + receiver.conns = append(receiver.conns, conn) + receiver.clientAddresses[addr] = conn.LocalAddr().String() + + dstAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return fmt.Errorf("failed to resolve address: %s", err) + } + + device := &tun.Device{ + Name: fmt.Sprintf("mtun-%d", i), + IP: fmt.Sprintf("10.0.%d.2", i), + PConn: conn, + DstAddr: dstAddr, + } + + if err = device.Up(); err != nil { + return fmt.Errorf("failed to bring up device: %s, %s", device.Name, err) + } + receiver.devices = append(receiver.devices, device) + } + return nil +} diff --git a/relay/testec2/turn_allocator.go b/relay/testec2/turn_allocator.go new file mode 100644 index 000000000..fd86208df --- /dev/null +++ b/relay/testec2/turn_allocator.go @@ -0,0 +1,83 @@ +//go:build linux || darwin + +package main + +import ( + "fmt" + "net" + + "github.com/pion/logging" + "github.com/pion/turn/v3" + log "github.com/sirupsen/logrus" +) + +type TurnConn struct { + conn net.Conn + turnClient *turn.Client + relayConn net.PacketConn +} + +func (tc *TurnConn) Address() net.Addr { + return tc.relayConn.LocalAddr() +} + +func (tc *TurnConn) Close() { + _ = tc.relayConn.Close() + tc.turnClient.Close() + _ = tc.conn.Close() +} + +func AllocateTurnClient(serverAddr string) *TurnConn { + conn, err := net.Dial("tcp", serverAddr) + if err != nil { + log.Fatal(err) + } + + turnClient, err := getTurnClient(serverAddr, conn) + if err != nil { + log.Fatal(err) + } + + relayConn, err := turnClient.Allocate() + if err != nil { + log.Fatal(err) + } + + return &TurnConn{ + conn: conn, + turnClient: turnClient, + relayConn: relayConn, + } +} + +func getTurnClient(address string, conn net.Conn) (*turn.Client, error) { + // Dial TURN Server + addrStr := fmt.Sprintf("%s:%d", address, 443) + + fac := logging.NewDefaultLoggerFactory() + //fac.DefaultLogLevel = logging.LogLevelTrace + + // Start a new TURN Client and wrap our net.Conn in a STUNConn + // This allows us to simulate datagram based communication over a net.Conn + cfg := &turn.ClientConfig{ + TURNServerAddr: address, + Conn: turn.NewSTUNConn(conn), + Username: "test", + Password: "test", + LoggerFactory: fac, + } + + client, err := turn.NewClient(cfg) + if err != nil { + return nil, fmt.Errorf("failed to create TURN client for server %s: %s", addrStr, err) + } + + // Start listening on the conn provided. + err = client.Listen() + if err != nil { + client.Close() + return nil, fmt.Errorf("failed to listen on TURN client for server %s: %s", addrStr, err) + } + + return client, nil +}