diff --git a/relay/client/client.go b/relay/client/client.go index b2235ec46..130cb8592 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net" + "sync" "time" log "github.com/sirupsen/logrus" @@ -18,9 +19,14 @@ const ( serverResponseTimeout = 8 * time.Second ) +type bufMsg struct { + bufPtr *[]byte + buf []byte +} + type connContainer struct { conn *Conn - messages chan []byte + messages chan bufMsg } // Client Todo: @@ -31,9 +37,11 @@ type Client struct { channelsPending map[string]chan net.Conn // todo: protect map with mutex channels map[uint16]*connContainer + msgPool sync.Pool relayConn net.Conn relayConnState bool + mu sync.Mutex } func NewClient(serverAddress, peerID string) *Client { @@ -42,10 +50,18 @@ func NewClient(serverAddress, peerID string) *Client { peerID: peerID, channelsPending: make(map[string]chan net.Conn), channels: make(map[uint16]*connContainer), + msgPool: sync.Pool{ + New: func() any { + buf := make([]byte, bufferSize) + return &buf + }, + }, } } func (c *Client) Connect() error { + c.mu.Lock() + defer c.mu.Unlock() conn, err := udp.Dial(c.serverAddress) if err != nil { return err @@ -62,14 +78,23 @@ func (c *Client) Connect() error { return err } + err = c.relayConn.SetReadDeadline(time.Time{}) + if err != nil { + log.Errorf("failed to reset read deadline: %s", err) + return err + } + c.relayConnState = true go c.readLoop() return nil } func (c *Client) BindChannel(remotePeerID string) (net.Conn, error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.relayConn == nil { - return nil, fmt.Errorf("client not connected") + return nil, fmt.Errorf("client not connected to the relay server") } bindSuccessChan := make(chan net.Conn, 1) @@ -92,6 +117,9 @@ func (c *Client) BindChannel(remotePeerID string) (net.Conn, error) { } func (c *Client) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if !c.relayConnState { return nil } @@ -146,22 +174,20 @@ func (c *Client) readLoop() { log := log.WithField("client_id", c.peerID) var errExit error var n int - err := c.relayConn.SetReadDeadline(time.Time{}) - if err != nil { - log.Errorf("failed to set read deadline: %s", err) - return - } for { - buf := make([]byte, bufferSize) // todo optimise buffer size, use pool + bufPtr := c.msgPool.Get().(*[]byte) + buf := *bufPtr n, errExit = c.relayConn.Read(buf) if errExit != nil { log.Debugf("failed to read message from relay server: %s", errExit) + c.freeBuf(bufPtr) break } msgType, err := messages.DetermineServerMsgType(buf[:n]) if err != nil { log.Errorf("failed to determine message type: %s", err) + c.freeBuf(bufPtr) continue } @@ -173,21 +199,26 @@ func (c *Client) readLoop() { } else { c.handleBindResponse(channelId, peerId) } + c.freeBuf(bufPtr) continue case messages.MsgTypeTransport: channelId, err := messages.UnmarshalTransportID(buf[:n]) if err != nil { log.Errorf("failed to parse transport message: %v", err) + c.freeBuf(bufPtr) continue } container, ok := c.channels[channelId] if !ok { log.Errorf("unexpected transport message for channel: %d", channelId) + c.freeBuf(bufPtr) return } - container.messages <- buf[:n] - + container.messages <- bufMsg{ + bufPtr, + buf[:n], + } } } @@ -205,7 +236,7 @@ func (c *Client) handleBindResponse(channelId uint16, peerId string) { } delete(c.channelsPending, peerId) - messageBuffer := make(chan []byte, 2) + messageBuffer := make(chan bufMsg, 2) conn := NewConn(c, channelId, c.generateConnReaderFN(messageBuffer)) c.channels[channelId] = &connContainer{ @@ -226,21 +257,26 @@ func (c *Client) writeTo(channelID uint16, payload []byte) (int, error) { return n, err } -func (c *Client) generateConnReaderFN(messageBufferChan chan []byte) func(b []byte) (n int, err error) { +func (c *Client) generateConnReaderFN(messageBufferChan chan bufMsg) func(b []byte) (n int, err error) { return func(b []byte) (n int, err error) { select { - case msg, ok := <-messageBufferChan: + case bufMsg, ok := <-messageBufferChan: if !ok { return 0, io.EOF } - payload, err := messages.UnmarshalTransportPayload(msg) + payload, err := messages.UnmarshalTransportPayload(bufMsg.buf) if err != nil { return 0, err } n = copy(b, payload) + c.freeBuf(bufMsg.bufPtr) } return n, nil } } + +func (c *Client) freeBuf(ptr *[]byte) { + c.msgPool.Put(ptr) +}