From 4f49458af0cd69c203c2185bb10d9b3183bd93b7 Mon Sep 17 00:00:00 2001 From: fatedier Date: Wed, 20 Jul 2016 16:00:35 +0800 Subject: [PATCH 1/2] frp/models/msg: fix a bug if local service write to socket immediately every time accept one user connection, fix #20 --- src/frp/models/msg/process.go | 22 +++++++++------------- src/frp/models/server/server.go | 5 ++--- src/frp/utils/conn/conn.go | 11 ++++++++++- 3 files changed, 21 insertions(+), 17 deletions(-) diff --git a/src/frp/models/msg/process.go b/src/frp/models/msg/process.go index 4c7783bd..43e34ab6 100644 --- a/src/frp/models/msg/process.go +++ b/src/frp/models/msg/process.go @@ -15,12 +15,10 @@ package msg import ( - "bufio" "bytes" "encoding/binary" "fmt" "io" - "net" "sync" "frp/models/config" @@ -61,7 +59,7 @@ func JoinMore(c1 *conn.Conn, c2 *conn.Conn, conf config.BaseConf, needRecord boo defer wait.Done() // we don't care about errors here - pipeEncrypt(from.TcpConn, to.TcpConn, conf, needRecord) + pipeEncrypt(from, to, conf, needRecord) } decryptPipe := func(to *conn.Conn, from *conn.Conn) { @@ -70,7 +68,7 @@ func JoinMore(c1 *conn.Conn, c2 *conn.Conn, conf config.BaseConf, needRecord boo defer wait.Done() // we don't care about errors here - pipeDecrypt(to.TcpConn, from.TcpConn, conf, needRecord) + pipeDecrypt(to, from, conf, needRecord) } wait.Add(2) @@ -106,7 +104,7 @@ func unpkgMsg(data []byte) (int, []byte, []byte) { } // decrypt msg from reader, then write into writer -func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) (err error) { +func pipeDecrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bool) (err error) { laes := new(pcrypto.Pcrypto) key := conf.AuthToken if conf.PrivilegeMode { @@ -119,7 +117,7 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) buf := make([]byte, 5*1024+4) var left, res []byte - var cnt int + var cnt int = -1 // record var flowBytes int64 = 0 @@ -129,13 +127,12 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) }() } - nreader := bufio.NewReader(r) for { // there may be more than 1 package in variable // and we read more bytes if unpkgMsg returns an error var newBuf []byte if cnt < 0 { - n, err := nreader.Read(buf) + n, err := r.Read(buf) if err != nil { return err } @@ -165,7 +162,7 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) } } - _, err = w.Write(res) + _, err = w.WriteBytes(res) if err != nil { return err } @@ -182,7 +179,7 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) } // recvive msg from reader, then encrypt msg into writer -func pipeEncrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) (err error) { +func pipeEncrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bool) (err error) { laes := new(pcrypto.Pcrypto) key := conf.AuthToken if conf.PrivilegeMode { @@ -201,10 +198,9 @@ func pipeEncrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) }() } - nreader := bufio.NewReader(r) buf := make([]byte, 5*1024) for { - n, err := nreader.Read(buf) + n, err := r.Read(buf) if err != nil { return err } @@ -235,7 +231,7 @@ func pipeEncrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) } res = pkgMsg(res) - _, err = w.Write(res) + _, err = w.WriteBytes(res) if err != nil { return err } diff --git a/src/frp/models/server/server.go b/src/frp/models/server/server.go index 139c9899..e69a2793 100644 --- a/src/frp/models/server/server.go +++ b/src/frp/models/server/server.go @@ -154,13 +154,12 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) { } // start another goroutine for join two conns from frpc and user - go func() { + go func(userConn *conn.Conn) { workConn, err := p.getWorkConn() if err != nil { return } - userConn := c // msg will transfer to another without modifying // l means local, r means remote log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", workConn.GetLocalAddr(), workConn.GetRemoteAddr(), @@ -169,7 +168,7 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) { needRecord := true go msg.JoinMore(userConn, workConn, p.BaseConf, needRecord) metric.OpenConnection(p.Name) - }() + }(c) } }(listener) } diff --git a/src/frp/utils/conn/conn.go b/src/frp/utils/conn/conn.go index ed330f68..a3981104 100644 --- a/src/frp/utils/conn/conn.go +++ b/src/frp/utils/conn/conn.go @@ -125,6 +125,11 @@ func (c *Conn) GetLocalAddr() (addr string) { return c.TcpConn.LocalAddr().String() } +func (c *Conn) Read(p []byte) (n int, err error) { + n, err = c.Reader.Read(p) + return +} + func (c *Conn) ReadLine() (buff string, err error) { buff, err = c.Reader.ReadString('\n') if err != nil { @@ -138,10 +143,14 @@ func (c *Conn) ReadLine() (buff string, err error) { return buff, err } +func (c *Conn) WriteBytes(content []byte) (n int, err error) { + n, err = c.TcpConn.Write(content) + return +} + func (c *Conn) Write(content string) (err error) { _, err = c.TcpConn.Write([]byte(content)) return err - } func (c *Conn) SetDeadline(t time.Time) error { From 926d0b74a9a99bed17bf0deaa7e7eb2cc1059acf Mon Sep 17 00:00:00 2001 From: fatedier Date: Wed, 20 Jul 2016 16:33:42 +0800 Subject: [PATCH 2/2] utils/vhost: update TcpConn with bufio.Reader --- src/frp/utils/conn/conn.go | 5 +++++ src/frp/utils/vhost/vhost.go | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/frp/utils/conn/conn.go b/src/frp/utils/conn/conn.go index a3981104..a5a7a335 100644 --- a/src/frp/utils/conn/conn.go +++ b/src/frp/utils/conn/conn.go @@ -117,6 +117,11 @@ func ConnectServer(host string, port int64) (c *Conn, err error) { return c, nil } +func (c *Conn) SetTcpConn(tcpConn net.Conn) { + c.TcpConn = tcpConn + c.Reader = bufio.NewReader(c.TcpConn) +} + func (c *Conn) GetRemoteAddr() (addr string) { return c.TcpConn.RemoteAddr().String() } diff --git a/src/frp/utils/vhost/vhost.go b/src/frp/utils/vhost/vhost.go index ecf080d1..ae672097 100644 --- a/src/frp/utils/vhost/vhost.go +++ b/src/frp/utils/vhost/vhost.go @@ -105,7 +105,7 @@ func (v *VhostMuxer) handle(c *conn.Conn) { if err = sConn.SetDeadline(time.Time{}); err != nil { return } - c.TcpConn = sConn + c.SetTcpConn(sConn) l.accept <- c }