Fix reference check

This commit is contained in:
Zoltán Papp 2024-07-18 13:16:50 +02:00
parent f3282bea80
commit 894d68adf2
2 changed files with 14 additions and 13 deletions

View File

@ -384,20 +384,18 @@ func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppe
return true
}
// todo check by reference too, the id is not enought because the id come from the outer conn
func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) {
func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload []byte) (int, error) {
c.mu.Lock()
// conn, ok := c.conns[id]
_, ok := c.conns[id]
conn, ok := c.conns[id]
c.mu.Unlock()
if !ok {
return 0, io.EOF
}
/*
if conn != clientRef {
return 0, io.EOF
}
*/
if conn.conn != connReference {
return 0, io.EOF
}
// todo: use buffer pool instead of create new transport msg.
msg, err := messages.MarshalTransportMsg(dstID, payload)
if err != nil {
@ -439,8 +437,7 @@ func (c *Client) closeAllConns() {
c.conns = make(map[string]*connContainer)
}
// todo check by reference too, the id is not enought because the id come from the outer conn
func (c *Client) closeConn(id string) error {
func (c *Client) closeConn(connReference *Conn, id string) error {
c.mu.Lock()
defer c.mu.Unlock()
@ -448,6 +445,10 @@ func (c *Client) closeConn(id string) error {
if !ok {
return fmt.Errorf("connection already closed")
}
if container.conn != connReference {
return fmt.Errorf("conn reference mismatch")
}
container.close()
delete(c.conns, id)

View File

@ -27,7 +27,7 @@ func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan
}
func (c *Conn) Write(p []byte) (n int, err error) {
return c.client.writeTo(c.dstStringID, c.dstID, p)
return c.client.writeTo(c, c.dstStringID, c.dstID, p)
}
func (c *Conn) Read(b []byte) (n int, err error) {
@ -42,7 +42,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
}
func (c *Conn) Close() error {
return c.client.closeConn(c.dstStringID)
return c.client.closeConn(c, c.dstStringID)
}
func (c *Conn) LocalAddr() net.Addr {