mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-24 17:13:30 +01:00
Fix blocked net.Conn Close call (#2600)
This commit is contained in:
parent
b4c8cf0a67
commit
9e041b7f82
@ -63,32 +63,53 @@ type connContainer struct {
|
|||||||
messages chan Msg
|
messages chan Msg
|
||||||
msgChanLock sync.Mutex
|
msgChanLock sync.Mutex
|
||||||
closed bool // flag to check if channel is closed
|
closed bool // flag to check if channel is closed
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
|
func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
return &connContainer{
|
return &connContainer{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
messages: messages,
|
messages: messages,
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cc *connContainer) writeMsg(msg Msg) {
|
func (cc *connContainer) writeMsg(msg Msg) {
|
||||||
cc.msgChanLock.Lock()
|
cc.msgChanLock.Lock()
|
||||||
defer cc.msgChanLock.Unlock()
|
defer cc.msgChanLock.Unlock()
|
||||||
|
|
||||||
if cc.closed {
|
if cc.closed {
|
||||||
|
msg.Free()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cc.messages <- msg
|
|
||||||
|
select {
|
||||||
|
case cc.messages <- msg:
|
||||||
|
case <-cc.ctx.Done():
|
||||||
|
msg.Free()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cc *connContainer) close() {
|
func (cc *connContainer) close() {
|
||||||
|
cc.cancel()
|
||||||
|
|
||||||
cc.msgChanLock.Lock()
|
cc.msgChanLock.Lock()
|
||||||
defer cc.msgChanLock.Unlock()
|
defer cc.msgChanLock.Unlock()
|
||||||
|
|
||||||
if cc.closed {
|
if cc.closed {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
close(cc.messages)
|
|
||||||
cc.closed = true
|
cc.closed = true
|
||||||
|
close(cc.messages)
|
||||||
|
|
||||||
|
for msg := range cc.messages {
|
||||||
|
msg.Free()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
|
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
|
||||||
@ -464,8 +485,8 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
|
|||||||
if container.conn != connReference {
|
if container.conn != connReference {
|
||||||
return fmt.Errorf("conn reference mismatch")
|
return fmt.Errorf("conn reference mismatch")
|
||||||
}
|
}
|
||||||
container.close()
|
|
||||||
delete(c.conns, id)
|
delete(c.conns, id)
|
||||||
|
container.close()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -618,6 +618,87 @@ func TestCloseByClient(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCloseNotDrainedChannel(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
idAlice := "alice"
|
||||||
|
idBob := "bob"
|
||||||
|
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||||
|
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create server: %s", err)
|
||||||
|
}
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
err := srv.Listen(srvCfg)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := srv.Shutdown(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close server: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// wait for servers to start
|
||||||
|
if err := waitForServerToStart(errChan); err != nil {
|
||||||
|
t.Fatalf("failed to start server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
|
||||||
|
err = clientAlice.Connect()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := clientAlice.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close Alice client: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob)
|
||||||
|
err = clientBob.Connect()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := clientBob.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close Bob client: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
connAliceToBob, err := clientAlice.OpenConn(idBob)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
connBobToAlice, err := clientBob.OpenConn(idAlice)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := "hello bob, I am alice"
|
||||||
|
// the internal channel buffer size is 2. So we should overflow it
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
_, err = connAliceToBob.Write([]byte(payload))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to write to channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait for delivery
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
err = connBobToAlice.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close channel: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func waitForServerToStart(errChan chan error) error {
|
func waitForServerToStart(errChan chan error) error {
|
||||||
select {
|
select {
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
|
Loading…
Reference in New Issue
Block a user