mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-14 17:28:56 +02:00
Fix writing/reading to a closed conn
This commit is contained in:
@ -39,6 +39,7 @@ type Client struct {
|
||||
|
||||
relayConn net.Conn
|
||||
relayConnState bool
|
||||
wgRelayConn sync.WaitGroup
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
@ -81,6 +82,9 @@ func (c *Client) Connect() error {
|
||||
c.relayConnState = true
|
||||
c.mu.Unlock()
|
||||
|
||||
c.wgRelayConn.Add(1)
|
||||
go c.readLoop()
|
||||
|
||||
go func() {
|
||||
<-c.ctx.Done()
|
||||
cErr := c.close()
|
||||
@ -88,18 +92,6 @@ func (c *Client) Connect() error {
|
||||
log.Errorf("failed to close relay connection: %s", cErr)
|
||||
}
|
||||
}()
|
||||
// blocking function
|
||||
c.readLoop()
|
||||
|
||||
c.mu.Lock()
|
||||
|
||||
// close all Conn types
|
||||
for _, container := range c.conns {
|
||||
close(container.messages)
|
||||
}
|
||||
c.conns = make(map[string]*connContainer)
|
||||
|
||||
c.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -114,7 +106,7 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
|
||||
hashedID, hashedStringID := messages.HashID(dstPeerID)
|
||||
log.Infof("open connection to peer: %s", hashedStringID)
|
||||
messageBuffer := make(chan Msg, 2)
|
||||
conn := NewConn(c, hashedID, c.generateConnReaderFN(messageBuffer))
|
||||
conn := NewConn(c, hashedID, hashedStringID, c.generateConnReaderFN(messageBuffer))
|
||||
|
||||
c.conns[hashedStringID] = &connContainer{
|
||||
conn,
|
||||
@ -140,6 +132,14 @@ func (c *Client) close() error {
|
||||
|
||||
err := c.relayConn.Close()
|
||||
|
||||
c.wgRelayConn.Wait()
|
||||
|
||||
// close all Conn types
|
||||
for _, container := range c.conns {
|
||||
close(container.messages)
|
||||
}
|
||||
c.conns = make(map[string]*connContainer)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@ -191,6 +191,7 @@ func (c *Client) handShake() error {
|
||||
func (c *Client) readLoop() {
|
||||
defer func() {
|
||||
c.log.Tracef("exit from read loop")
|
||||
c.wgRelayConn.Done()
|
||||
}()
|
||||
var errExit error
|
||||
var n int
|
||||
@ -237,7 +238,14 @@ func (c *Client) readLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) writeTo(dstID []byte, payload []byte) (int, error) {
|
||||
func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) {
|
||||
c.mu.Lock()
|
||||
_, ok := c.conns[id]
|
||||
if !ok {
|
||||
c.mu.Unlock()
|
||||
return 0, io.EOF
|
||||
}
|
||||
c.mu.Unlock()
|
||||
msg := messages.MarshalTransportMsg(dstID, payload)
|
||||
n, err := c.relayConn.Write(msg)
|
||||
if err != nil {
|
||||
@ -262,3 +270,17 @@ func (c *Client) generateConnReaderFN(msgChannel chan Msg) func(b []byte) (n int
|
||||
return n, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) closeConn(id string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
conn, ok := c.conns[id]
|
||||
if !ok {
|
||||
return fmt.Errorf("connection already closed")
|
||||
}
|
||||
close(conn.messages)
|
||||
delete(c.conns, id)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user