From a208e7999ce886da584b995559f0534c0814f024 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zolt=C3=A1n=20Papp?= Date: Wed, 21 Aug 2024 16:05:04 +0200 Subject: [PATCH] Pass the ctx to the close function --- relay/client/client_test.go | 20 ++++++++++---------- relay/client/manager_test.go | 16 ++++++++-------- relay/cmd/main.go | 6 +++++- relay/server/listener/listener.go | 7 +++++-- relay/server/listener/ws/listener.go | 6 +----- relay/server/server.go | 8 ++------ relay/test/benchmark_test.go | 2 +- 7 files changed, 32 insertions(+), 33 deletions(-) diff --git a/relay/client/client_test.go b/relay/client/client_test.go index 592d5bb4d..dc7f70b43 100644 --- a/relay/client/client_test.go +++ b/relay/client/client_test.go @@ -47,7 +47,7 @@ func TestClient(t *testing.T) { }() defer func() { - err := srv.Close() + err := srv.Close(ctx) if err != nil { t.Errorf("failed to close server: %s", err) } @@ -135,14 +135,14 @@ func TestRegistration(t *testing.T) { clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") err = clientAlice.Connect() if err != nil { - _ = srv.Close() + _ = srv.Close(ctx) t.Fatalf("failed to connect to server: %s", err) } err = clientAlice.Close() if err != nil { t.Errorf("failed to close conn: %s", err) } - err = srv.Close() + err = srv.Close(ctx) if err != nil { t.Errorf("failed to close server: %s", err) } @@ -202,7 +202,7 @@ func TestEcho(t *testing.T) { }() defer func() { - err := srv.Close() + err := srv.Close(ctx) if err != nil { t.Errorf("failed to close server: %s", err) } @@ -292,7 +292,7 @@ func TestBindToUnavailabePeer(t *testing.T) { defer func() { log.Infof("closing server") - err := srv.Close() + err := srv.Close(ctx) if err != nil { t.Errorf("failed to close server: %s", err) } @@ -338,7 +338,7 @@ func TestBindReconnect(t *testing.T) { defer func() { log.Infof("closing server") - err := srv.Close() + err := srv.Close(ctx) if err != nil { t.Errorf("failed to close server: %s", err) } @@ -429,7 +429,7 @@ func TestCloseConn(t *testing.T) { defer func() { log.Infof("closing server") - err := srv.Close() + err := srv.Close(ctx) if err != nil { t.Errorf("failed to close server: %s", err) } @@ -485,7 +485,7 @@ func TestCloseRelayConn(t *testing.T) { }() defer func() { - err := srv.Close() + err := srv.Close(ctx) if err != nil { log.Errorf("failed to close server: %s", err) } @@ -556,7 +556,7 @@ func TestCloseByServer(t *testing.T) { close(disconnected) }) - err = srv1.Close() + err = srv1.Close(ctx) if err != nil { t.Fatalf("failed to close server: %s", err) } @@ -612,7 +612,7 @@ func TestCloseByClient(t *testing.T) { t.Errorf("unexpected opening connection to closed server") } - err = srv.Close() + err = srv.Close(ctx) if err != nil { t.Fatalf("failed to close server: %s", err) } diff --git a/relay/client/manager_test.go b/relay/client/manager_test.go index 53ac9bcd7..b11e49737 100644 --- a/relay/client/manager_test.go +++ b/relay/client/manager_test.go @@ -38,7 +38,7 @@ func TestForeignConn(t *testing.T) { }() defer func() { - err := srv1.Close() + err := srv1.Close(ctx) if err != nil { t.Errorf("failed to close server: %s", err) } @@ -64,7 +64,7 @@ func TestForeignConn(t *testing.T) { }() defer func() { - err := srv2.Close() + err := srv2.Close(ctx) if err != nil { t.Errorf("failed to close server: %s", err) } @@ -150,7 +150,7 @@ func TestForeginConnClose(t *testing.T) { }() defer func() { - err := srv1.Close() + err := srv1.Close(ctx) if err != nil { t.Errorf("failed to close server: %s", err) } @@ -176,7 +176,7 @@ func TestForeginConnClose(t *testing.T) { }() defer func() { - err := srv2.Close() + err := srv2.Close(ctx) if err != nil { t.Errorf("failed to close server: %s", err) } @@ -227,7 +227,7 @@ func TestForeginAutoClose(t *testing.T) { defer func() { t.Logf("closing server 1.") - err := srv1.Close() + err := srv1.Close(ctx) if err != nil { t.Errorf("failed to close server: %s", err) } @@ -255,7 +255,7 @@ func TestForeginAutoClose(t *testing.T) { }() defer func() { t.Logf("closing server 2.") - err := srv2.Close() + err := srv2.Close(ctx) if err != nil { t.Errorf("failed to close server: %s", err) } @@ -317,7 +317,7 @@ func TestAutoReconnect(t *testing.T) { }() defer func() { - err := srv.Close() + err := srv.Close(ctx) if err != nil { log.Errorf("failed to close server: %s", err) } @@ -381,7 +381,7 @@ func TestNotifierDoubleAdd(t *testing.T) { }() defer func() { - err := srv1.Close() + err := srv1.Close(ctx) if err != nil { t.Errorf("failed to close server: %s", err) } diff --git a/relay/cmd/main.go b/relay/cmd/main.go index dc39beaa5..c4fcbc994 100644 --- a/relay/cmd/main.go +++ b/relay/cmd/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "crypto/tls" "errors" "fmt" @@ -175,7 +176,10 @@ func execute(cmd *cobra.Command, args []string) error { // it will block until exit signal waitForExitSignal() - err = srv.Close() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + err = srv.Close(ctx) if err != nil { return fmt.Errorf("failed to close server: %s", err) } diff --git a/relay/server/listener/listener.go b/relay/server/listener/listener.go index 66e6d357e..45167ac1d 100644 --- a/relay/server/listener/listener.go +++ b/relay/server/listener/listener.go @@ -1,8 +1,11 @@ package listener -import "net" +import ( + "context" + "net" +) type Listener interface { Listen(func(conn net.Conn)) error - Close() error + Close(ctx context.Context) error } diff --git a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go index 673a5e19a..c429e4545 100644 --- a/relay/server/listener/ws/listener.go +++ b/relay/server/listener/ws/listener.go @@ -7,7 +7,6 @@ import ( "fmt" "net" "net/http" - "time" log "github.com/sirupsen/logrus" "nhooyr.io/websocket" @@ -47,14 +46,11 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { return err } -func (l *Listener) Close() error { +func (l *Listener) Close(ctx context.Context) error { if l.server == nil { return nil } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - log.Infof("stop WS listener") if err := l.server.Shutdown(ctx); err != nil { return fmt.Errorf("server shutdown failed: %v", err) diff --git a/relay/server/server.go b/relay/server/server.go index 2d74a5eef..c96e2c25d 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -3,7 +3,6 @@ package server import ( "context" "crypto/tls" - "time" log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/metric" @@ -61,15 +60,12 @@ func (r *Server) Listen(cfg ListenerConfig) error { // Close stops the relay server. If there are active connections, they will be closed gracefully. In case of a timeout, // the connections will be forcefully closed. -func (r *Server) Close() (err error) { +func (r *Server) Close(ctx context.Context) (err error) { // stop service new connections if r.wSListener != nil { - err = r.wSListener.Close() + err = r.wSListener.Close(ctx) } - // close accepted connections gracefully - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() r.relay.Close(ctx) return } diff --git a/relay/test/benchmark_test.go b/relay/test/benchmark_test.go index d6d71a765..620902a2a 100644 --- a/relay/test/benchmark_test.go +++ b/relay/test/benchmark_test.go @@ -85,7 +85,7 @@ func transfer(t *testing.T, testData []byte, peerPairs int) { }() defer func() { - err := srv.Close() + err := srv.Close(ctx) if err != nil { t.Errorf("failed to close server: %s", err) }