diff --git a/relay/client/client.go b/relay/client/client.go index e7492afee..44325e1f9 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -39,6 +39,7 @@ type Client struct { relayConn net.Conn relayConnState bool + wgRelayConn sync.WaitGroup mu sync.Mutex } @@ -81,6 +82,9 @@ func (c *Client) Connect() error { c.relayConnState = true c.mu.Unlock() + c.wgRelayConn.Add(1) + go c.readLoop() + go func() { <-c.ctx.Done() cErr := c.close() @@ -88,18 +92,6 @@ func (c *Client) Connect() error { log.Errorf("failed to close relay connection: %s", cErr) } }() - // blocking function - c.readLoop() - - c.mu.Lock() - - // close all Conn types - for _, container := range c.conns { - close(container.messages) - } - c.conns = make(map[string]*connContainer) - - c.mu.Unlock() return nil } @@ -114,7 +106,7 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) { hashedID, hashedStringID := messages.HashID(dstPeerID) log.Infof("open connection to peer: %s", hashedStringID) messageBuffer := make(chan Msg, 2) - conn := NewConn(c, hashedID, c.generateConnReaderFN(messageBuffer)) + conn := NewConn(c, hashedID, hashedStringID, c.generateConnReaderFN(messageBuffer)) c.conns[hashedStringID] = &connContainer{ conn, @@ -140,6 +132,14 @@ func (c *Client) close() error { err := c.relayConn.Close() + c.wgRelayConn.Wait() + + // close all Conn types + for _, container := range c.conns { + close(container.messages) + } + c.conns = make(map[string]*connContainer) + return err } @@ -191,6 +191,7 @@ func (c *Client) handShake() error { func (c *Client) readLoop() { defer func() { c.log.Tracef("exit from read loop") + c.wgRelayConn.Done() }() var errExit error var n int @@ -237,7 +238,14 @@ func (c *Client) readLoop() { } } -func (c *Client) writeTo(dstID []byte, payload []byte) (int, error) { +func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) { + c.mu.Lock() + _, ok := c.conns[id] + if !ok { + c.mu.Unlock() + return 0, io.EOF + } + c.mu.Unlock() msg := messages.MarshalTransportMsg(dstID, payload) n, err := c.relayConn.Write(msg) if err != nil { @@ -262,3 +270,17 @@ func (c *Client) generateConnReaderFN(msgChannel chan Msg) func(b []byte) (n int return n, nil } } + +func (c *Client) closeConn(id string) error { + c.mu.Lock() + defer c.mu.Unlock() + + conn, ok := c.conns[id] + if !ok { + return fmt.Errorf("connection already closed") + } + close(conn.messages) + delete(c.conns, id) + + return nil +} diff --git a/relay/client/conn.go b/relay/client/conn.go index aea07ff9b..647b0fae4 100644 --- a/relay/client/conn.go +++ b/relay/client/conn.go @@ -6,23 +6,25 @@ import ( ) type Conn struct { - client *Client - dstID []byte - readerFn func(b []byte) (n int, err error) + client *Client + dstID []byte + dstStringID string + readerFn func(b []byte) (n int, err error) } -func NewConn(client *Client, dstID []byte, readerFn func(b []byte) (n int, err error)) *Conn { +func NewConn(client *Client, dstID []byte, dstStringID string, readerFn func(b []byte) (n int, err error)) *Conn { c := &Conn{ - client: client, - dstID: dstID, - readerFn: readerFn, + client: client, + dstID: dstID, + dstStringID: dstStringID, + readerFn: readerFn, } return c } func (c *Conn) Write(p []byte) (n int, err error) { - return c.client.writeTo(c.dstID, p) + return c.client.writeTo(c.dstStringID, c.dstID, p) } func (c *Conn) Read(b []byte) (n int, err error) { @@ -30,7 +32,7 @@ func (c *Conn) Read(b []byte) (n int, err error) { } func (c *Conn) Close() error { - return nil + return c.client.closeConn(c.dstStringID) } func (c *Conn) LocalAddr() net.Addr { diff --git a/relay/test/client_test.go b/relay/test/client_test.go index 1d3748abb..f70048d22 100644 --- a/relay/test/client_test.go +++ b/relay/test/client_test.go @@ -354,3 +354,51 @@ func TestBindReconnect(t *testing.T) { t.Errorf("failed to close client: %s", err) } } + +func TestCloseConn(t *testing.T) { + ctx := context.Background() + + addr := "localhost:1234" + srv := server.NewServer() + go func() { + err := srv.Listen(addr) + if err != nil { + t.Errorf("failed to bind server: %s", err) + } + }() + + defer func() { + log.Infof("closing server") + err := srv.Close() + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + clientAlice := client.NewClient(ctx, addr, "alice") + err := clientAlice.Connect() + if err != nil { + t.Errorf("failed to connect to server: %s", err) + } + + conn, err := clientAlice.OpenConn("bob") + if err != nil { + t.Errorf("failed to bind channel: %s", err) + } + + log.Infof("closing connection") + err = conn.Close() + if err != nil { + t.Errorf("failed to close connection: %s", err) + } + + _, err = conn.Read(make([]byte, 1)) + if err == nil { + t.Errorf("unexpected reading from closed connection") + } + + _, err = conn.Write([]byte("hello")) + if err == nil { + t.Errorf("unexpected writing from closed connection") + } +}