diff --git a/src/models/server/server.go b/src/models/server/server.go index 8167f513..374707cc 100644 --- a/src/models/server/server.go +++ b/src/models/server/server.go @@ -384,6 +384,7 @@ func (p *ProxyServer) getWorkConn() (workConn *conn.Conn, err error) { err = fmt.Errorf("ProxyName [%s], no work connections available, control is closing", p.Name) return } + log.Debug("ProxyName [%s], get work connection from pool", p.Name) default: // no work connections available in the poll, send message to frpc to get more p.ctlMsgChan <- 1 diff --git a/src/utils/conn/conn.go b/src/utils/conn/conn.go index 4b6eb157..ee80f33c 100644 --- a/src/utils/conn/conn.go +++ b/src/utils/conn/conn.go @@ -16,6 +16,7 @@ package conn import ( "bufio" + "bytes" "encoding/base64" "fmt" "io" @@ -25,6 +26,8 @@ import ( "strings" "sync" "time" + + "github.com/fatedier/frp/src/utils/pool" ) type Listener struct { @@ -61,11 +64,7 @@ func Listen(bindAddr string, bindPort int64) (l *Listener, err error) { continue } - c := &Conn{ - TcpConn: conn, - closeFlag: false, - } - c.Reader = bufio.NewReader(c.TcpConn) + c := NewConn(conn) l.accept <- c } }() @@ -95,20 +94,23 @@ func (l *Listener) Close() error { type Conn struct { TcpConn net.Conn Reader *bufio.Reader + buffer *bytes.Buffer closeFlag bool - mutex sync.RWMutex + + mutex sync.RWMutex } func NewConn(conn net.Conn) (c *Conn) { - c = &Conn{} - c.TcpConn = conn + c = &Conn{ + TcpConn: conn, + buffer: nil, + closeFlag: false, + } c.Reader = bufio.NewReader(c.TcpConn) - c.closeFlag = false - return c + return } func ConnectServer(addr string) (c *Conn, err error) { - c = &Conn{} servertAddr, err := net.ResolveTCPAddr("tcp", addr) if err != nil { return @@ -117,9 +119,7 @@ func ConnectServer(addr string) (c *Conn, err error) { if err != nil { return } - c.TcpConn = conn - c.Reader = bufio.NewReader(c.TcpConn) - c.closeFlag = false + c = NewConn(conn) return c, nil } @@ -185,7 +185,23 @@ func (c *Conn) GetLocalAddr() (addr string) { } func (c *Conn) Read(p []byte) (n int, err error) { - n, err = c.Reader.Read(p) + c.mutex.RLock() + if c.buffer == nil { + c.mutex.RUnlock() + return c.Reader.Read(p) + } + c.mutex.RUnlock() + + n, err = c.buffer.Read(p) + if err == io.EOF { + c.mutex.Lock() + c.buffer = nil + c.mutex.Unlock() + var n2 int + n2, err = c.Reader.Read(p[n:]) + + n += n2 + } return } @@ -212,6 +228,16 @@ func (c *Conn) WriteString(content string) (err error) { return err } +func (c *Conn) AppendReaderBuffer(content []byte) { + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.buffer == nil { + c.buffer = bytes.NewBuffer(make([]byte, 0, 2048)) + } + c.buffer.Write(content) +} + func (c *Conn) SetDeadline(t time.Time) error { return c.TcpConn.SetDeadline(t) } @@ -238,22 +264,36 @@ func (c *Conn) IsClosed() (closeFlag bool) { } // when you call this function, you should make sure that -// remote client won't send any bytes to this socket +// no bytes were read before func (c *Conn) CheckClosed() bool { c.mutex.RLock() if c.closeFlag { + c.mutex.RUnlock() return true } c.mutex.RUnlock() + tmp := pool.GetBuf(2048) + defer pool.PutBuf(tmp) err := c.TcpConn.SetReadDeadline(time.Now().Add(time.Millisecond)) if err != nil { c.Close() return true } - var tmp []byte = make([]byte, 1) - _, err = c.TcpConn.Read(tmp) + n, err := c.TcpConn.Read(tmp) + if err == io.EOF { + return true + } + + var tmp2 []byte = make([]byte, 1) + err = c.TcpConn.SetReadDeadline(time.Now().Add(time.Millisecond)) + if err != nil { + c.Close() + return true + } + + n2, err := c.TcpConn.Read(tmp2) if err == io.EOF { return true } @@ -263,5 +303,12 @@ func (c *Conn) CheckClosed() bool { c.Close() return true } + + if n > 0 { + c.AppendReaderBuffer(tmp[:n]) + } + if n2 > 0 { + c.AppendReaderBuffer(tmp2[:n2]) + } return false } diff --git a/src/utils/vhost/vhost.go b/src/utils/vhost/vhost.go index 4a8dc524..e118afbc 100644 --- a/src/utils/vhost/vhost.go +++ b/src/utils/vhost/vhost.go @@ -205,16 +205,18 @@ func (sc *sharedConn) Read(p []byte) (n int, err error) { sc.Unlock() return sc.Conn.Read(p) } + sc.Unlock() n, err = sc.buff.Read(p) if err == io.EOF { + sc.Lock() sc.buff = nil + sc.Unlock() var n2 int n2, err = sc.Conn.Read(p[n:]) n += n2 } - sc.Unlock() return }