From 894d68adf22a5405aaae091552f9855bcf40a031 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zolt=C3=A1n=20Papp?= Date: Thu, 18 Jul 2024 13:16:50 +0200 Subject: [PATCH] Fix reference check --- relay/client/client.go | 23 ++++++++++++----------- relay/client/conn.go | 4 ++-- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/relay/client/client.go b/relay/client/client.go index 1590cfc72..a0d5124dc 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -384,20 +384,18 @@ func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppe return true } -// todo check by reference too, the id is not enought because the id come from the outer conn -func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) { +func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload []byte) (int, error) { c.mu.Lock() - // conn, ok := c.conns[id] - _, ok := c.conns[id] + conn, ok := c.conns[id] c.mu.Unlock() if !ok { return 0, io.EOF } - /* - if conn != clientRef { - return 0, io.EOF - } - */ + + if conn.conn != connReference { + return 0, io.EOF + } + // todo: use buffer pool instead of create new transport msg. msg, err := messages.MarshalTransportMsg(dstID, payload) if err != nil { @@ -439,8 +437,7 @@ func (c *Client) closeAllConns() { c.conns = make(map[string]*connContainer) } -// todo check by reference too, the id is not enought because the id come from the outer conn -func (c *Client) closeConn(id string) error { +func (c *Client) closeConn(connReference *Conn, id string) error { c.mu.Lock() defer c.mu.Unlock() @@ -448,6 +445,10 @@ func (c *Client) closeConn(id string) error { if !ok { return fmt.Errorf("connection already closed") } + + if container.conn != connReference { + return fmt.Errorf("conn reference mismatch") + } container.close() delete(c.conns, id) diff --git a/relay/client/conn.go b/relay/client/conn.go index a3a2fdabb..783b6a660 100644 --- a/relay/client/conn.go +++ b/relay/client/conn.go @@ -27,7 +27,7 @@ func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan } func (c *Conn) Write(p []byte) (n int, err error) { - return c.client.writeTo(c.dstStringID, c.dstID, p) + return c.client.writeTo(c, c.dstStringID, c.dstID, p) } func (c *Conn) Read(b []byte) (n int, err error) { @@ -42,7 +42,7 @@ func (c *Conn) Read(b []byte) (n int, err error) { } func (c *Conn) Close() error { - return c.client.closeConn(c.dstStringID) + return c.client.closeConn(c, c.dstStringID) } func (c *Conn) LocalAddr() net.Addr {