Pass the ctx to the close function

This commit is contained in:
Zoltán Papp 2024-08-21 16:05:04 +02:00
parent 7633cca3b1
commit a208e7999c
7 changed files with 32 additions and 33 deletions

View File

@ -47,7 +47,7 @@ func TestClient(t *testing.T) {
}() }()
defer func() { defer func() {
err := srv.Close() err := srv.Close(ctx)
if err != nil { if err != nil {
t.Errorf("failed to close server: %s", err) t.Errorf("failed to close server: %s", err)
} }
@ -135,14 +135,14 @@ func TestRegistration(t *testing.T) {
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect()
if err != nil { if err != nil {
_ = srv.Close() _ = srv.Close(ctx)
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
err = clientAlice.Close() err = clientAlice.Close()
if err != nil { if err != nil {
t.Errorf("failed to close conn: %s", err) t.Errorf("failed to close conn: %s", err)
} }
err = srv.Close() err = srv.Close(ctx)
if err != nil { if err != nil {
t.Errorf("failed to close server: %s", err) t.Errorf("failed to close server: %s", err)
} }
@ -202,7 +202,7 @@ func TestEcho(t *testing.T) {
}() }()
defer func() { defer func() {
err := srv.Close() err := srv.Close(ctx)
if err != nil { if err != nil {
t.Errorf("failed to close server: %s", err) t.Errorf("failed to close server: %s", err)
} }
@ -292,7 +292,7 @@ func TestBindToUnavailabePeer(t *testing.T) {
defer func() { defer func() {
log.Infof("closing server") log.Infof("closing server")
err := srv.Close() err := srv.Close(ctx)
if err != nil { if err != nil {
t.Errorf("failed to close server: %s", err) t.Errorf("failed to close server: %s", err)
} }
@ -338,7 +338,7 @@ func TestBindReconnect(t *testing.T) {
defer func() { defer func() {
log.Infof("closing server") log.Infof("closing server")
err := srv.Close() err := srv.Close(ctx)
if err != nil { if err != nil {
t.Errorf("failed to close server: %s", err) t.Errorf("failed to close server: %s", err)
} }
@ -429,7 +429,7 @@ func TestCloseConn(t *testing.T) {
defer func() { defer func() {
log.Infof("closing server") log.Infof("closing server")
err := srv.Close() err := srv.Close(ctx)
if err != nil { if err != nil {
t.Errorf("failed to close server: %s", err) t.Errorf("failed to close server: %s", err)
} }
@ -485,7 +485,7 @@ func TestCloseRelayConn(t *testing.T) {
}() }()
defer func() { defer func() {
err := srv.Close() err := srv.Close(ctx)
if err != nil { if err != nil {
log.Errorf("failed to close server: %s", err) log.Errorf("failed to close server: %s", err)
} }
@ -556,7 +556,7 @@ func TestCloseByServer(t *testing.T) {
close(disconnected) close(disconnected)
}) })
err = srv1.Close() err = srv1.Close(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to close server: %s", err) t.Fatalf("failed to close server: %s", err)
} }
@ -612,7 +612,7 @@ func TestCloseByClient(t *testing.T) {
t.Errorf("unexpected opening connection to closed server") t.Errorf("unexpected opening connection to closed server")
} }
err = srv.Close() err = srv.Close(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to close server: %s", err) t.Fatalf("failed to close server: %s", err)
} }

View File

@ -38,7 +38,7 @@ func TestForeignConn(t *testing.T) {
}() }()
defer func() { defer func() {
err := srv1.Close() err := srv1.Close(ctx)
if err != nil { if err != nil {
t.Errorf("failed to close server: %s", err) t.Errorf("failed to close server: %s", err)
} }
@ -64,7 +64,7 @@ func TestForeignConn(t *testing.T) {
}() }()
defer func() { defer func() {
err := srv2.Close() err := srv2.Close(ctx)
if err != nil { if err != nil {
t.Errorf("failed to close server: %s", err) t.Errorf("failed to close server: %s", err)
} }
@ -150,7 +150,7 @@ func TestForeginConnClose(t *testing.T) {
}() }()
defer func() { defer func() {
err := srv1.Close() err := srv1.Close(ctx)
if err != nil { if err != nil {
t.Errorf("failed to close server: %s", err) t.Errorf("failed to close server: %s", err)
} }
@ -176,7 +176,7 @@ func TestForeginConnClose(t *testing.T) {
}() }()
defer func() { defer func() {
err := srv2.Close() err := srv2.Close(ctx)
if err != nil { if err != nil {
t.Errorf("failed to close server: %s", err) t.Errorf("failed to close server: %s", err)
} }
@ -227,7 +227,7 @@ func TestForeginAutoClose(t *testing.T) {
defer func() { defer func() {
t.Logf("closing server 1.") t.Logf("closing server 1.")
err := srv1.Close() err := srv1.Close(ctx)
if err != nil { if err != nil {
t.Errorf("failed to close server: %s", err) t.Errorf("failed to close server: %s", err)
} }
@ -255,7 +255,7 @@ func TestForeginAutoClose(t *testing.T) {
}() }()
defer func() { defer func() {
t.Logf("closing server 2.") t.Logf("closing server 2.")
err := srv2.Close() err := srv2.Close(ctx)
if err != nil { if err != nil {
t.Errorf("failed to close server: %s", err) t.Errorf("failed to close server: %s", err)
} }
@ -317,7 +317,7 @@ func TestAutoReconnect(t *testing.T) {
}() }()
defer func() { defer func() {
err := srv.Close() err := srv.Close(ctx)
if err != nil { if err != nil {
log.Errorf("failed to close server: %s", err) log.Errorf("failed to close server: %s", err)
} }
@ -381,7 +381,7 @@ func TestNotifierDoubleAdd(t *testing.T) {
}() }()
defer func() { defer func() {
err := srv1.Close() err := srv1.Close(ctx)
if err != nil { if err != nil {
t.Errorf("failed to close server: %s", err) t.Errorf("failed to close server: %s", err)
} }

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
@ -175,7 +176,10 @@ func execute(cmd *cobra.Command, args []string) error {
// it will block until exit signal // it will block until exit signal
waitForExitSignal() waitForExitSignal()
err = srv.Close() ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
err = srv.Close(ctx)
if err != nil { if err != nil {
return fmt.Errorf("failed to close server: %s", err) return fmt.Errorf("failed to close server: %s", err)
} }

View File

@ -1,8 +1,11 @@
package listener package listener
import "net" import (
"context"
"net"
)
type Listener interface { type Listener interface {
Listen(func(conn net.Conn)) error Listen(func(conn net.Conn)) error
Close() error Close(ctx context.Context) error
} }

View File

@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"nhooyr.io/websocket" "nhooyr.io/websocket"
@ -47,14 +46,11 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
return err return err
} }
func (l *Listener) Close() error { func (l *Listener) Close(ctx context.Context) error {
if l.server == nil { if l.server == nil {
return nil return nil
} }
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
log.Infof("stop WS listener") log.Infof("stop WS listener")
if err := l.server.Shutdown(ctx); err != nil { if err := l.server.Shutdown(ctx); err != nil {
return fmt.Errorf("server shutdown failed: %v", err) return fmt.Errorf("server shutdown failed: %v", err)

View File

@ -3,7 +3,6 @@ package server
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/metric"
@ -61,15 +60,12 @@ func (r *Server) Listen(cfg ListenerConfig) error {
// Close stops the relay server. If there are active connections, they will be closed gracefully. In case of a timeout, // Close stops the relay server. If there are active connections, they will be closed gracefully. In case of a timeout,
// the connections will be forcefully closed. // the connections will be forcefully closed.
func (r *Server) Close() (err error) { func (r *Server) Close(ctx context.Context) (err error) {
// stop service new connections // stop service new connections
if r.wSListener != nil { if r.wSListener != nil {
err = r.wSListener.Close() err = r.wSListener.Close(ctx)
} }
// close accepted connections gracefully
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
r.relay.Close(ctx) r.relay.Close(ctx)
return return
} }

View File

@ -85,7 +85,7 @@ func transfer(t *testing.T, testData []byte, peerPairs int) {
}() }()
defer func() { defer func() {
err := srv.Close() err := srv.Close(ctx)
if err != nil { if err != nil {
t.Errorf("failed to close server: %s", err) t.Errorf("failed to close server: %s", err)
} }