From 9e041b7f824d26de2a4dd8207acd83e2db5cc08c Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Sat, 14 Sep 2024 10:27:37 +0200 Subject: [PATCH] Fix blocked net.Conn Close call (#2600) --- relay/client/client.go | 27 +++++++++++-- relay/client/client_test.go | 81 +++++++++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 3 deletions(-) diff --git a/relay/client/client.go b/relay/client/client.go index 6560c81e1..7ff17944f 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -63,32 +63,53 @@ type connContainer struct { messages chan Msg msgChanLock sync.Mutex closed bool // flag to check if channel is closed + ctx context.Context + cancel context.CancelFunc } func newConnContainer(conn *Conn, messages chan Msg) *connContainer { + ctx, cancel := context.WithCancel(context.Background()) + return &connContainer{ conn: conn, messages: messages, + ctx: ctx, + cancel: cancel, } } func (cc *connContainer) writeMsg(msg Msg) { cc.msgChanLock.Lock() defer cc.msgChanLock.Unlock() + if cc.closed { + msg.Free() return } - cc.messages <- msg + + select { + case cc.messages <- msg: + case <-cc.ctx.Done(): + msg.Free() + } } func (cc *connContainer) close() { + cc.cancel() + cc.msgChanLock.Lock() defer cc.msgChanLock.Unlock() + if cc.closed { return } - close(cc.messages) + cc.closed = true + close(cc.messages) + + for msg := range cc.messages { + msg.Free() + } } // Client is a client for the relay server. It is responsible for establishing a connection to the relay server and @@ -464,8 +485,8 @@ func (c *Client) closeConn(connReference *Conn, id string) error { if container.conn != connReference { return fmt.Errorf("conn reference mismatch") } - container.close() delete(c.conns, id) + container.close() return nil } diff --git a/relay/client/client_test.go b/relay/client/client_test.go index b7f1a63ca..ef28203e9 100644 --- a/relay/client/client_test.go +++ b/relay/client/client_test.go @@ -618,6 +618,87 @@ func TestCloseByClient(t *testing.T) { } } +func TestCloseNotDrainedChannel(t *testing.T) { + ctx := context.Background() + idAlice := "alice" + idBob := "bob" + srvCfg := server.ListenerConfig{Address: serverListenAddr} + srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + if err != nil { + t.Fatalf("failed to create server: %s", err) + } + errChan := make(chan error, 1) + go func() { + err := srv.Listen(srvCfg) + if err != nil { + errChan <- err + } + }() + + defer func() { + err := srv.Shutdown(ctx) + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + // wait for servers to start + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("failed to start server: %s", err) + } + + clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice) + err = clientAlice.Connect() + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + defer func() { + err := clientAlice.Close() + if err != nil { + t.Errorf("failed to close Alice client: %s", err) + } + }() + + clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob) + err = clientBob.Connect() + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + defer func() { + err := clientBob.Close() + if err != nil { + t.Errorf("failed to close Bob client: %s", err) + } + }() + + connAliceToBob, err := clientAlice.OpenConn(idBob) + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + + connBobToAlice, err := clientBob.OpenConn(idAlice) + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + + payload := "hello bob, I am alice" + // the internal channel buffer size is 2. So we should overflow it + for i := 0; i < 5; i++ { + _, err = connAliceToBob.Write([]byte(payload)) + if err != nil { + t.Fatalf("failed to write to channel: %s", err) + } + + } + + // wait for delivery + time.Sleep(1 * time.Second) + err = connBobToAlice.Close() + if err != nil { + t.Errorf("failed to close channel: %s", err) + } +} + func waitForServerToStart(errChan chan error) error { select { case err := <-errChan: