diff --git a/src/cmd/frpc/control.go b/src/cmd/frpc/control.go index 55d87984..01eb8a00 100644 --- a/src/cmd/frpc/control.go +++ b/src/cmd/frpc/control.go @@ -55,15 +55,24 @@ func msgReader(cli *client.ProxyClient, c *conn.Conn, msgSendChan chan interface var heartbeatTimeout bool = false timer := time.AfterFunc(time.Duration(client.HeartBeatTimeout)*time.Second, func() { heartbeatTimeout = true - c.Close() + if c != nil { + c.Close() + } + if cli != nil { + // if it's not udp type, nothing will happen + cli.CloseUdpTunnel() + cli.SetCloseFlag(true) + } log.Error("ProxyName [%s], heartbeatRes from frps timeout", cli.Name) }) defer timer.Stop() for { buf, err := c.ReadLine() - if err == io.EOF || c == nil || c.IsClosed() { + if err == io.EOF || c.IsClosed() { + timer.Stop() c.Close() + cli.SetCloseFlag(true) log.Warn("ProxyName [%s], frps close this control conn!", cli.Name) var delayTime time.Duration = 1 @@ -76,11 +85,14 @@ func msgReader(cli *client.ProxyClient, c *conn.Conn, msgSendChan chan interface msgSendChan = make(chan interface{}, 1024) go heartbeatSender(c, msgSendChan) go msgSender(cli, c, msgSendChan) + cli.SetCloseFlag(false) break } - if delayTime < 60 { + if delayTime < 30 { delayTime = delayTime * 2 + } else { + delayTime = 30 } time.Sleep(delayTime * time.Second) } diff --git a/src/cmd/frps/control.go b/src/cmd/frps/control.go index f6de7aa5..a6cd6a3f 100644 --- a/src/cmd/frps/control.go +++ b/src/cmd/frps/control.go @@ -85,7 +85,9 @@ func controlWorker(c *conn.Conn) { return } } else { - closeFlag = false + if ret == 0 { + closeFlag = false + } return } diff --git a/src/models/client/client.go b/src/models/client/client.go index 2e9dbf92..5c5b2b20 100644 --- a/src/models/client/client.go +++ b/src/models/client/client.go @@ -39,6 +39,9 @@ type ProxyClient struct { udpTunnel *conn.Conn once sync.Once + closeFlag bool + + mutex sync.RWMutex } // if proxy type is udp, keep a tcp connection for transferring udp packages @@ -48,7 +51,7 @@ func (pc *ProxyClient) StartUdpTunnelOnce(addr string, port int64) { var c *conn.Conn udpProcessor := NewUdpProcesser(nil, pc.LocalIp, pc.LocalPort) for { - if pc.udpTunnel == nil || pc.udpTunnel.IsClosed() { + if !pc.IsClosed() && (pc.udpTunnel == nil || pc.udpTunnel.IsClosed()) { if HttpProxy == "" { c, err = conn.ConnectServer(fmt.Sprintf("%s:%d", addr, port)) } else { @@ -82,8 +85,11 @@ func (pc *ProxyClient) StartUdpTunnelOnce(addr string, port int64) { time.Sleep(1 * time.Second) continue } + pc.mutex.Lock() pc.udpTunnel = c udpProcessor.UpdateTcpConn(pc.udpTunnel) + pc.mutex.Unlock() + udpProcessor.Run() } time.Sleep(1 * time.Second) @@ -91,6 +97,14 @@ func (pc *ProxyClient) StartUdpTunnelOnce(addr string, port int64) { }) } +func (pc *ProxyClient) CloseUdpTunnel() { + pc.mutex.RLock() + defer pc.mutex.RUnlock() + if pc.udpTunnel != nil { + pc.udpTunnel.Close() + } +} + func (pc *ProxyClient) GetLocalConn() (c *conn.Conn, err error) { c, err = conn.ConnectServer(fmt.Sprintf("%s:%d", pc.LocalIp, pc.LocalPort)) if err != nil { @@ -158,3 +172,15 @@ func (pc *ProxyClient) StartTunnel(serverAddr string, serverPort int64) (err err return nil } + +func (pc *ProxyClient) SetCloseFlag(closeFlag bool) { + pc.mutex.Lock() + defer pc.mutex.Unlock() + pc.closeFlag = closeFlag +} + +func (pc *ProxyClient) IsClosed() bool { + pc.mutex.RLock() + defer pc.mutex.RUnlock() + return pc.closeFlag +}