Pass the ctx to the close function

This commit is contained in:
Zoltán Papp 2024-08-21 16:05:04 +02:00
parent 7633cca3b1
commit a208e7999c
7 changed files with 32 additions and 33 deletions

View File

@ -47,7 +47,7 @@ func TestClient(t *testing.T) {
}()
defer func() {
err := srv.Close()
err := srv.Close(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
@ -135,14 +135,14 @@ func TestRegistration(t *testing.T) {
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect()
if err != nil {
_ = srv.Close()
_ = srv.Close(ctx)
t.Fatalf("failed to connect to server: %s", err)
}
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close conn: %s", err)
}
err = srv.Close()
err = srv.Close(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
@ -202,7 +202,7 @@ func TestEcho(t *testing.T) {
}()
defer func() {
err := srv.Close()
err := srv.Close(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
@ -292,7 +292,7 @@ func TestBindToUnavailabePeer(t *testing.T) {
defer func() {
log.Infof("closing server")
err := srv.Close()
err := srv.Close(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
@ -338,7 +338,7 @@ func TestBindReconnect(t *testing.T) {
defer func() {
log.Infof("closing server")
err := srv.Close()
err := srv.Close(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
@ -429,7 +429,7 @@ func TestCloseConn(t *testing.T) {
defer func() {
log.Infof("closing server")
err := srv.Close()
err := srv.Close(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
@ -485,7 +485,7 @@ func TestCloseRelayConn(t *testing.T) {
}()
defer func() {
err := srv.Close()
err := srv.Close(ctx)
if err != nil {
log.Errorf("failed to close server: %s", err)
}
@ -556,7 +556,7 @@ func TestCloseByServer(t *testing.T) {
close(disconnected)
})
err = srv1.Close()
err = srv1.Close(ctx)
if err != nil {
t.Fatalf("failed to close server: %s", err)
}
@ -612,7 +612,7 @@ func TestCloseByClient(t *testing.T) {
t.Errorf("unexpected opening connection to closed server")
}
err = srv.Close()
err = srv.Close(ctx)
if err != nil {
t.Fatalf("failed to close server: %s", err)
}

View File

@ -38,7 +38,7 @@ func TestForeignConn(t *testing.T) {
}()
defer func() {
err := srv1.Close()
err := srv1.Close(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
@ -64,7 +64,7 @@ func TestForeignConn(t *testing.T) {
}()
defer func() {
err := srv2.Close()
err := srv2.Close(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
@ -150,7 +150,7 @@ func TestForeginConnClose(t *testing.T) {
}()
defer func() {
err := srv1.Close()
err := srv1.Close(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
@ -176,7 +176,7 @@ func TestForeginConnClose(t *testing.T) {
}()
defer func() {
err := srv2.Close()
err := srv2.Close(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
@ -227,7 +227,7 @@ func TestForeginAutoClose(t *testing.T) {
defer func() {
t.Logf("closing server 1.")
err := srv1.Close()
err := srv1.Close(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
@ -255,7 +255,7 @@ func TestForeginAutoClose(t *testing.T) {
}()
defer func() {
t.Logf("closing server 2.")
err := srv2.Close()
err := srv2.Close(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
@ -317,7 +317,7 @@ func TestAutoReconnect(t *testing.T) {
}()
defer func() {
err := srv.Close()
err := srv.Close(ctx)
if err != nil {
log.Errorf("failed to close server: %s", err)
}
@ -381,7 +381,7 @@ func TestNotifierDoubleAdd(t *testing.T) {
}()
defer func() {
err := srv1.Close()
err := srv1.Close(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}

View File

@ -1,6 +1,7 @@
package main
import (
"context"
"crypto/tls"
"errors"
"fmt"
@ -175,7 +176,10 @@ func execute(cmd *cobra.Command, args []string) error {
// it will block until exit signal
waitForExitSignal()
err = srv.Close()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
err = srv.Close(ctx)
if err != nil {
return fmt.Errorf("failed to close server: %s", err)
}

View File

@ -1,8 +1,11 @@
package listener
import "net"
import (
"context"
"net"
)
type Listener interface {
Listen(func(conn net.Conn)) error
Close() error
Close(ctx context.Context) error
}

View File

@ -7,7 +7,6 @@ import (
"fmt"
"net"
"net/http"
"time"
log "github.com/sirupsen/logrus"
"nhooyr.io/websocket"
@ -47,14 +46,11 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
return err
}
func (l *Listener) Close() error {
func (l *Listener) Close(ctx context.Context) error {
if l.server == nil {
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
log.Infof("stop WS listener")
if err := l.server.Shutdown(ctx); err != nil {
return fmt.Errorf("server shutdown failed: %v", err)

View File

@ -3,7 +3,6 @@ package server
import (
"context"
"crypto/tls"
"time"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
@ -61,15 +60,12 @@ func (r *Server) Listen(cfg ListenerConfig) error {
// Close stops the relay server. If there are active connections, they will be closed gracefully. In case of a timeout,
// the connections will be forcefully closed.
func (r *Server) Close() (err error) {
func (r *Server) Close(ctx context.Context) (err error) {
// stop service new connections
if r.wSListener != nil {
err = r.wSListener.Close()
err = r.wSListener.Close(ctx)
}
// close accepted connections gracefully
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
r.relay.Close(ctx)
return
}

View File

@ -85,7 +85,7 @@ func transfer(t *testing.T, testData []byte, peerPairs int) {
}()
defer func() {
err := srv.Close()
err := srv.Close(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}