mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-21 18:22:37 +02:00
Fix writing/reading to a closed conn
This commit is contained in:
parent
b4aa7e50f9
commit
645a1f31a7
@ -39,6 +39,7 @@ type Client struct {
|
|||||||
|
|
||||||
relayConn net.Conn
|
relayConn net.Conn
|
||||||
relayConnState bool
|
relayConnState bool
|
||||||
|
wgRelayConn sync.WaitGroup
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -81,6 +82,9 @@ func (c *Client) Connect() error {
|
|||||||
c.relayConnState = true
|
c.relayConnState = true
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
c.wgRelayConn.Add(1)
|
||||||
|
go c.readLoop()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
<-c.ctx.Done()
|
<-c.ctx.Done()
|
||||||
cErr := c.close()
|
cErr := c.close()
|
||||||
@ -88,18 +92,6 @@ func (c *Client) Connect() error {
|
|||||||
log.Errorf("failed to close relay connection: %s", cErr)
|
log.Errorf("failed to close relay connection: %s", cErr)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
// blocking function
|
|
||||||
c.readLoop()
|
|
||||||
|
|
||||||
c.mu.Lock()
|
|
||||||
|
|
||||||
// close all Conn types
|
|
||||||
for _, container := range c.conns {
|
|
||||||
close(container.messages)
|
|
||||||
}
|
|
||||||
c.conns = make(map[string]*connContainer)
|
|
||||||
|
|
||||||
c.mu.Unlock()
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -114,7 +106,7 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
|
|||||||
hashedID, hashedStringID := messages.HashID(dstPeerID)
|
hashedID, hashedStringID := messages.HashID(dstPeerID)
|
||||||
log.Infof("open connection to peer: %s", hashedStringID)
|
log.Infof("open connection to peer: %s", hashedStringID)
|
||||||
messageBuffer := make(chan Msg, 2)
|
messageBuffer := make(chan Msg, 2)
|
||||||
conn := NewConn(c, hashedID, c.generateConnReaderFN(messageBuffer))
|
conn := NewConn(c, hashedID, hashedStringID, c.generateConnReaderFN(messageBuffer))
|
||||||
|
|
||||||
c.conns[hashedStringID] = &connContainer{
|
c.conns[hashedStringID] = &connContainer{
|
||||||
conn,
|
conn,
|
||||||
@ -140,6 +132,14 @@ func (c *Client) close() error {
|
|||||||
|
|
||||||
err := c.relayConn.Close()
|
err := c.relayConn.Close()
|
||||||
|
|
||||||
|
c.wgRelayConn.Wait()
|
||||||
|
|
||||||
|
// close all Conn types
|
||||||
|
for _, container := range c.conns {
|
||||||
|
close(container.messages)
|
||||||
|
}
|
||||||
|
c.conns = make(map[string]*connContainer)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -191,6 +191,7 @@ func (c *Client) handShake() error {
|
|||||||
func (c *Client) readLoop() {
|
func (c *Client) readLoop() {
|
||||||
defer func() {
|
defer func() {
|
||||||
c.log.Tracef("exit from read loop")
|
c.log.Tracef("exit from read loop")
|
||||||
|
c.wgRelayConn.Done()
|
||||||
}()
|
}()
|
||||||
var errExit error
|
var errExit error
|
||||||
var n int
|
var n int
|
||||||
@ -237,7 +238,14 @@ func (c *Client) readLoop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) writeTo(dstID []byte, payload []byte) (int, error) {
|
func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
_, ok := c.conns[id]
|
||||||
|
if !ok {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
msg := messages.MarshalTransportMsg(dstID, payload)
|
msg := messages.MarshalTransportMsg(dstID, payload)
|
||||||
n, err := c.relayConn.Write(msg)
|
n, err := c.relayConn.Write(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -262,3 +270,17 @@ func (c *Client) generateConnReaderFN(msgChannel chan Msg) func(b []byte) (n int
|
|||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) closeConn(id string) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
conn, ok := c.conns[id]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("connection already closed")
|
||||||
|
}
|
||||||
|
close(conn.messages)
|
||||||
|
delete(c.conns, id)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -6,23 +6,25 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
client *Client
|
client *Client
|
||||||
dstID []byte
|
dstID []byte
|
||||||
readerFn func(b []byte) (n int, err error)
|
dstStringID string
|
||||||
|
readerFn func(b []byte) (n int, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConn(client *Client, dstID []byte, readerFn func(b []byte) (n int, err error)) *Conn {
|
func NewConn(client *Client, dstID []byte, dstStringID string, readerFn func(b []byte) (n int, err error)) *Conn {
|
||||||
c := &Conn{
|
c := &Conn{
|
||||||
client: client,
|
client: client,
|
||||||
dstID: dstID,
|
dstID: dstID,
|
||||||
readerFn: readerFn,
|
dstStringID: dstStringID,
|
||||||
|
readerFn: readerFn,
|
||||||
}
|
}
|
||||||
|
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Write(p []byte) (n int, err error) {
|
func (c *Conn) Write(p []byte) (n int, err error) {
|
||||||
return c.client.writeTo(c.dstID, p)
|
return c.client.writeTo(c.dstStringID, c.dstID, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||||
@ -30,7 +32,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Close() error {
|
func (c *Conn) Close() error {
|
||||||
return nil
|
return c.client.closeConn(c.dstStringID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) LocalAddr() net.Addr {
|
func (c *Conn) LocalAddr() net.Addr {
|
||||||
|
@ -354,3 +354,51 @@ func TestBindReconnect(t *testing.T) {
|
|||||||
t.Errorf("failed to close client: %s", err)
|
t.Errorf("failed to close client: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCloseConn(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
addr := "localhost:1234"
|
||||||
|
srv := server.NewServer()
|
||||||
|
go func() {
|
||||||
|
err := srv.Listen(addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to bind server: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
log.Infof("closing server")
|
||||||
|
err := srv.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close server: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
clientAlice := client.NewClient(ctx, addr, "alice")
|
||||||
|
err := clientAlice.Connect()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := clientAlice.OpenConn("bob")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to bind channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("closing connection")
|
||||||
|
err = conn.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close connection: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = conn.Read(make([]byte, 1))
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("unexpected reading from closed connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = conn.Write([]byte("hello"))
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("unexpected writing from closed connection")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user