Fix writing/reading to a closed conn

This commit is contained in:
Zoltán Papp 2024-05-27 10:25:08 +02:00
parent b4aa7e50f9
commit 645a1f31a7
3 changed files with 95 additions and 23 deletions

View File

@ -39,6 +39,7 @@ type Client struct {
relayConn net.Conn relayConn net.Conn
relayConnState bool relayConnState bool
wgRelayConn sync.WaitGroup
mu sync.Mutex mu sync.Mutex
} }
@ -81,6 +82,9 @@ func (c *Client) Connect() error {
c.relayConnState = true c.relayConnState = true
c.mu.Unlock() c.mu.Unlock()
c.wgRelayConn.Add(1)
go c.readLoop()
go func() { go func() {
<-c.ctx.Done() <-c.ctx.Done()
cErr := c.close() cErr := c.close()
@ -88,18 +92,6 @@ func (c *Client) Connect() error {
log.Errorf("failed to close relay connection: %s", cErr) log.Errorf("failed to close relay connection: %s", cErr)
} }
}() }()
// blocking function
c.readLoop()
c.mu.Lock()
// close all Conn types
for _, container := range c.conns {
close(container.messages)
}
c.conns = make(map[string]*connContainer)
c.mu.Unlock()
return nil return nil
} }
@ -114,7 +106,7 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
hashedID, hashedStringID := messages.HashID(dstPeerID) hashedID, hashedStringID := messages.HashID(dstPeerID)
log.Infof("open connection to peer: %s", hashedStringID) log.Infof("open connection to peer: %s", hashedStringID)
messageBuffer := make(chan Msg, 2) messageBuffer := make(chan Msg, 2)
conn := NewConn(c, hashedID, c.generateConnReaderFN(messageBuffer)) conn := NewConn(c, hashedID, hashedStringID, c.generateConnReaderFN(messageBuffer))
c.conns[hashedStringID] = &connContainer{ c.conns[hashedStringID] = &connContainer{
conn, conn,
@ -140,6 +132,14 @@ func (c *Client) close() error {
err := c.relayConn.Close() err := c.relayConn.Close()
c.wgRelayConn.Wait()
// close all Conn types
for _, container := range c.conns {
close(container.messages)
}
c.conns = make(map[string]*connContainer)
return err return err
} }
@ -191,6 +191,7 @@ func (c *Client) handShake() error {
func (c *Client) readLoop() { func (c *Client) readLoop() {
defer func() { defer func() {
c.log.Tracef("exit from read loop") c.log.Tracef("exit from read loop")
c.wgRelayConn.Done()
}() }()
var errExit error var errExit error
var n int var n int
@ -237,7 +238,14 @@ func (c *Client) readLoop() {
} }
} }
func (c *Client) writeTo(dstID []byte, payload []byte) (int, error) { func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) {
c.mu.Lock()
_, ok := c.conns[id]
if !ok {
c.mu.Unlock()
return 0, io.EOF
}
c.mu.Unlock()
msg := messages.MarshalTransportMsg(dstID, payload) msg := messages.MarshalTransportMsg(dstID, payload)
n, err := c.relayConn.Write(msg) n, err := c.relayConn.Write(msg)
if err != nil { if err != nil {
@ -262,3 +270,17 @@ func (c *Client) generateConnReaderFN(msgChannel chan Msg) func(b []byte) (n int
return n, nil return n, nil
} }
} }
func (c *Client) closeConn(id string) error {
c.mu.Lock()
defer c.mu.Unlock()
conn, ok := c.conns[id]
if !ok {
return fmt.Errorf("connection already closed")
}
close(conn.messages)
delete(c.conns, id)
return nil
}

View File

@ -6,23 +6,25 @@ import (
) )
type Conn struct { type Conn struct {
client *Client client *Client
dstID []byte dstID []byte
readerFn func(b []byte) (n int, err error) dstStringID string
readerFn func(b []byte) (n int, err error)
} }
func NewConn(client *Client, dstID []byte, readerFn func(b []byte) (n int, err error)) *Conn { func NewConn(client *Client, dstID []byte, dstStringID string, readerFn func(b []byte) (n int, err error)) *Conn {
c := &Conn{ c := &Conn{
client: client, client: client,
dstID: dstID, dstID: dstID,
readerFn: readerFn, dstStringID: dstStringID,
readerFn: readerFn,
} }
return c return c
} }
func (c *Conn) Write(p []byte) (n int, err error) { func (c *Conn) Write(p []byte) (n int, err error) {
return c.client.writeTo(c.dstID, p) return c.client.writeTo(c.dstStringID, c.dstID, p)
} }
func (c *Conn) Read(b []byte) (n int, err error) { func (c *Conn) Read(b []byte) (n int, err error) {
@ -30,7 +32,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
} }
func (c *Conn) Close() error { func (c *Conn) Close() error {
return nil return c.client.closeConn(c.dstStringID)
} }
func (c *Conn) LocalAddr() net.Addr { func (c *Conn) LocalAddr() net.Addr {

View File

@ -354,3 +354,51 @@ func TestBindReconnect(t *testing.T) {
t.Errorf("failed to close client: %s", err) t.Errorf("failed to close client: %s", err)
} }
} }
func TestCloseConn(t *testing.T) {
ctx := context.Background()
addr := "localhost:1234"
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
if err != nil {
t.Errorf("failed to bind server: %s", err)
}
}()
defer func() {
log.Infof("closing server")
err := srv.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
clientAlice := client.NewClient(ctx, addr, "alice")
err := clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
conn, err := clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
log.Infof("closing connection")
err = conn.Close()
if err != nil {
t.Errorf("failed to close connection: %s", err)
}
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Errorf("unexpected reading from closed connection")
}
_, err = conn.Write([]byte("hello"))
if err == nil {
t.Errorf("unexpected writing from closed connection")
}
}