Fix blocked net.Conn Close call (#2600)

This commit is contained in:
Zoltan Papp 2024-09-14 10:27:37 +02:00 committed by GitHub
parent b4c8cf0a67
commit 9e041b7f82
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 105 additions and 3 deletions

View File

@ -63,32 +63,53 @@ type connContainer struct {
messages chan Msg messages chan Msg
msgChanLock sync.Mutex msgChanLock sync.Mutex
closed bool // flag to check if channel is closed closed bool // flag to check if channel is closed
ctx context.Context
cancel context.CancelFunc
} }
func newConnContainer(conn *Conn, messages chan Msg) *connContainer { func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
ctx, cancel := context.WithCancel(context.Background())
return &connContainer{ return &connContainer{
conn: conn, conn: conn,
messages: messages, messages: messages,
ctx: ctx,
cancel: cancel,
} }
} }
func (cc *connContainer) writeMsg(msg Msg) { func (cc *connContainer) writeMsg(msg Msg) {
cc.msgChanLock.Lock() cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock() defer cc.msgChanLock.Unlock()
if cc.closed { if cc.closed {
msg.Free()
return return
} }
cc.messages <- msg
select {
case cc.messages <- msg:
case <-cc.ctx.Done():
msg.Free()
}
} }
func (cc *connContainer) close() { func (cc *connContainer) close() {
cc.cancel()
cc.msgChanLock.Lock() cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock() defer cc.msgChanLock.Unlock()
if cc.closed { if cc.closed {
return return
} }
close(cc.messages)
cc.closed = true cc.closed = true
close(cc.messages)
for msg := range cc.messages {
msg.Free()
}
} }
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and // Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
@ -464,8 +485,8 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
if container.conn != connReference { if container.conn != connReference {
return fmt.Errorf("conn reference mismatch") return fmt.Errorf("conn reference mismatch")
} }
container.close()
delete(c.conns, id) delete(c.conns, id)
container.close()
return nil return nil
} }

View File

@ -618,6 +618,87 @@ func TestCloseByClient(t *testing.T) {
} }
} }
func TestCloseNotDrainedChannel(t *testing.T) {
ctx := context.Background()
idAlice := "alice"
idBob := "bob"
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
err = clientAlice.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer func() {
err := clientAlice.Close()
if err != nil {
t.Errorf("failed to close Alice client: %s", err)
}
}()
clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob)
err = clientBob.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer func() {
err := clientBob.Close()
if err != nil {
t.Errorf("failed to close Bob client: %s", err)
}
}()
connAliceToBob, err := clientAlice.OpenConn(idBob)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
connBobToAlice, err := clientBob.OpenConn(idAlice)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
payload := "hello bob, I am alice"
// the internal channel buffer size is 2. So we should overflow it
for i := 0; i < 5; i++ {
_, err = connAliceToBob.Write([]byte(payload))
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
}
// wait for delivery
time.Sleep(1 * time.Second)
err = connBobToAlice.Close()
if err != nil {
t.Errorf("failed to close channel: %s", err)
}
}
func waitForServerToStart(errChan chan error) error { func waitForServerToStart(errChan chan error) error {
select { select {
case err := <-errChan: case err := <-errChan: