diff --git a/relay/client/client.go b/relay/client/client.go index d972207ea..3f689e6dd 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -250,13 +250,6 @@ func (c *Client) connect() error { } func (c *Client) handShake() error { - defer func() { - err := c.relayConn.SetReadDeadline(time.Time{}) - if err != nil { - log.Errorf("failed to reset read deadline: %s", err) - } - }() - msg, err := messages.MarshalHelloMsg(c.hashedID) if err != nil { log.Errorf("failed to marshal hello message: %s", err) @@ -267,15 +260,8 @@ func (c *Client) handShake() error { log.Errorf("failed to send hello message: %s", err) return err } - - err = c.relayConn.SetReadDeadline(time.Now().Add(serverResponseTimeout)) - if err != nil { - log.Errorf("failed to set read deadline: %s", err) - return err - } - - buf := make([]byte, 1500) // todo: optimise buffer size - n, err := c.relayConn.Read(buf) + buf := make([]byte, messages.MaxHandshakeSize) + n, err := c.readWithTimeout(buf) if err != nil { log.Errorf("failed to read hello response: %s", err) return err @@ -391,6 +377,29 @@ func (c *Client) closeAllConns() { c.conns = make(map[string]*connContainer) } +func (c *Client) readWithTimeout(buf []byte) (int, error) { + ctx, cancel := context.WithTimeout(c.parentCtx, serverResponseTimeout) + defer cancel() + + readDone := make(chan struct{}) + var ( + n int + err error + ) + + go func() { + n, err = c.relayConn.Read(buf) + close(readDone) + }() + + select { + case <-ctx.Done(): + return 0, fmt.Errorf("read operation timed out") + case <-readDone: + return n, err + } +} + // todo check by reference too, the id is not enought because the id come from the outer conn func (c *Client) closeConn(id string) error { c.mu.Lock() diff --git a/relay/client/dialer/wsnhooyr/client_conn.go b/relay/client/dialer/wsnhooyr/client_conn.go index 0aa995286..219425460 100644 --- a/relay/client/dialer/wsnhooyr/client_conn.go +++ b/relay/client/dialer/wsnhooyr/client_conn.go @@ -51,27 +51,14 @@ func (c *Conn) LocalAddr() net.Addr { } func (c *Conn) SetReadDeadline(t time.Time) error { - // todo: implement me return nil } func (c *Conn) SetWriteDeadline(t time.Time) error { - // todo: implement me return nil } func (c *Conn) SetDeadline(t time.Time) error { - // todo: implement me - errR := c.SetReadDeadline(t) - errW := c.SetWriteDeadline(t) - - if errR != nil { - return errR - } - - if errW != nil { - return errW - } return nil } diff --git a/relay/client/manager.go b/relay/client/manager.go index 42cbcd280..ae44be4b9 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -194,7 +194,6 @@ func (m *Manager) isForeignServer(address string) (bool, error) { if err != nil { return false, fmt.Errorf("relay client not connected") } - log.Debugf("check if foreign server: %s != %s", rAddr.String(), address) return rAddr.String() != address, nil } diff --git a/relay/messages/message.go b/relay/messages/message.go index 92b8bced7..d2a6d46d7 100644 --- a/relay/messages/message.go +++ b/relay/messages/message.go @@ -15,6 +15,8 @@ const ( headerSizeTransport = 1 + IDSize // 1 byte for msg type, IDSize for peerID headerSizeHello = 1 + 4 + IDSize // 1 byte for msg type, 4 byte for magic header, IDSize for peerID + + MaxHandshakeSize = 90 ) var ( diff --git a/relay/server/listener/wsnhooyr/conn.go b/relay/server/listener/wsnhooyr/conn.go index b52e5d082..7a96c6529 100644 --- a/relay/server/listener/wsnhooyr/conn.go +++ b/relay/server/listener/wsnhooyr/conn.go @@ -13,6 +13,10 @@ import ( "nhooyr.io/websocket" ) +const ( + writeTimeout = 10 * time.Second +) + type Conn struct { *websocket.Conn lAddr *net.TCPAddr @@ -50,8 +54,14 @@ func (c *Conn) Read(b []byte) (n int, err error) { return n, err } +// Write writes a binary message with the given payload. +// It does not block until fill the internal buffer. +// If the buffer filled up, wait until the buffer is drained or timeout. func (c *Conn) Write(b []byte) (int, error) { - err := c.Conn.Write(c.ctx, websocket.MessageBinary, b) + ctx, ctxCancel := context.WithTimeout(c.ctx, writeTimeout) + defer ctxCancel() + + err := c.Conn.Write(ctx, websocket.MessageBinary, b) return len(b), err } diff --git a/relay/server/listener/wsnhooyr/listener.go b/relay/server/listener/wsnhooyr/listener.go index e47a60b47..b6bcd12c0 100644 --- a/relay/server/listener/wsnhooyr/listener.go +++ b/relay/server/listener/wsnhooyr/listener.go @@ -29,7 +29,6 @@ func NewListener(address string) listener.Listener { } } -// Listen todo: prevent multiple call func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { l.acceptFn = acceptFn mux := http.NewServeMux() diff --git a/relay/server/server.go b/relay/server/server.go index c7085c3d1..f52f7eab8 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -16,8 +16,7 @@ import ( ) const ( - bufferSize = 8820 - maxHandshakeSize = 90 + bufferSize = 8820 ) type Server struct { @@ -135,7 +134,6 @@ func (r *Server) accept(conn net.Conn) { peer.Log.Errorf("failed to update transport message: %s", err) continue } - peer.Log.Infof("write transport msg!!!!") _, err = dp.conn.Write(msg) if err != nil { peer.Log.Errorf("failed to write transport message to: %s", dp.String()) @@ -168,7 +166,7 @@ func (r *Server) sendCloseMsgs() { } func handShake(conn net.Conn) (*Peer, error) { - buf := make([]byte, maxHandshakeSize) + buf := make([]byte, messages.MaxHandshakeSize) n, err := conn.Read(buf) if err != nil { log.Errorf("failed to read message: %s", err)