diff --git a/.travis.yml b/.travis.yml index 3ccb28ef..31d65002 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,8 +3,8 @@ language: go go: - 1.5.4 - - 1.6.3 - - 1.7 + - 1.6.4 + - 1.7.4 install: - make diff --git a/conf/frpc.ini b/conf/frpc.ini index 083f84f0..5a064008 100644 --- a/conf/frpc.ini +++ b/conf/frpc.ini @@ -29,6 +29,13 @@ use_gzip = false # connections will be established in advance, default value is zero pool_count = 10 +[dns] +type = udp +local_ip = 127.0.0.1 +local_port = 53 +use_encryption = true +use_gzip = true + # Resolve your domain names to [server_addr] so you can use http://web01.yourdomain.com to browse web01 and http://web02.yourdomain.com to browse web02, the domains are set in frps.ini [web01] type = http diff --git a/conf/frps.ini b/conf/frps.ini index 6163bc98..23079765 100644 --- a/conf/frps.ini +++ b/conf/frps.ini @@ -34,6 +34,12 @@ auth_token = 123 bind_addr = 0.0.0.0 listen_port = 6000 +[dns] +type = udp +auth_token = 123 +bind_addr = 0.0.0.0 +listen_port = 53 + [web01] # if type equals http, vhost_http_port must be set type = http diff --git a/src/cmd/frpc/control.go b/src/cmd/frpc/control.go index dcb8b0d7..98a70d7c 100644 --- a/src/cmd/frpc/control.go +++ b/src/cmd/frpc/control.go @@ -120,7 +120,7 @@ func msgSender(cli *client.ProxyClient, c *conn.Conn, msgSendChan chan interface } buf, _ := json.Marshal(msg) - err := c.Write(string(buf) + "\n") + err := c.WriteString(string(buf) + "\n") if err != nil { log.Warn("ProxyName [%s], write to server error, proxy exit", cli.Name) c.Close() @@ -165,7 +165,7 @@ func loginToServer(cli *client.ProxyClient) (c *conn.Conn, err error) { } buf, _ := json.Marshal(req) - err = c.Write(string(buf) + "\n") + err = c.WriteString(string(buf) + "\n") if err != nil { log.Error("ProxyName [%s], write to server error, %v", cli.Name, err) return @@ -190,6 +190,12 @@ func loginToServer(cli *client.ProxyClient) (c *conn.Conn, err error) { } log.Info("ProxyName [%s], connect to server [%s:%d] success!", cli.Name, client.ServerAddr, client.ServerPort) + + if cli.Type == "udp" { + // we only need one udp work connection + // all udp messages will be forwarded throngh this connection + go cli.StartUdpTunnelOnce(client.ServerAddr, client.ServerPort) + } return } diff --git a/src/cmd/frps/control.go b/src/cmd/frps/control.go index 81e449cd..08ff9814 100644 --- a/src/cmd/frps/control.go +++ b/src/cmd/frps/control.go @@ -71,14 +71,14 @@ func controlWorker(c *conn.Conn) { // login when type is NewCtlConn or NewWorkConn ret, info := doLogin(cliReq, c) // if login type is NewWorkConn, nothing will be send to frpc - if cliReq.Type != consts.NewWorkConn { + if cliReq.Type == consts.NewCtlConn { cliRes := &msg.ControlRes{ Type: consts.NewCtlConnRes, Code: ret, Msg: info, } byteBuf, _ := json.Marshal(cliRes) - err = c.Write(string(byteBuf) + "\n") + err = c.WriteString(string(byteBuf) + "\n") if err != nil { log.Warn("ProxyName [%s], write to client error, proxy exit", cliReq.ProxyName) return @@ -144,9 +144,11 @@ func msgReader(s *server.ProxyServer, c *conn.Conn, msgSendChan chan interface{} if err != nil { if err == io.EOF { log.Warn("ProxyName [%s], client is dead!", s.Name) + s.Close() return err } else if c == nil || c.IsClosed() { log.Warn("ProxyName [%s], client connection is closed", s.Name) + s.Close() return err } log.Warn("ProxyName [%s], read error: %v", s.Name, err) @@ -183,7 +185,7 @@ func msgSender(s *server.ProxyServer, c *conn.Conn, msgSendChan chan interface{} } buf, _ := json.Marshal(msg) - err := c.Write(string(buf) + "\n") + err := c.WriteString(string(buf) + "\n") if err != nil { log.Warn("ProxyName [%s], write to client error, proxy exit", s.Name) s.Close() @@ -193,6 +195,9 @@ func msgSender(s *server.ProxyServer, c *conn.Conn, msgSendChan chan interface{} } // if success, ret equals 0, otherwise greater than 0 +// NewCtlConn +// NewWorkConn +// NewWorkConnUdp func doLogin(req *msg.ControlReq, c *conn.Conn) (ret int64, info string) { ret = 1 // check if PrivilegeMode is enabled @@ -325,6 +330,13 @@ func doLogin(req *msg.ControlReq, c *conn.Conn) (ret int64, info string) { } // the connection will close after join over s.RegisterNewWorkConn(c) + } else if req.Type == consts.NewWorkConnUdp { + // work conn for udp + if s.Status != consts.Working { + log.Warn("ProxyName [%s], is not working when it gets one new work connnection for udp", req.ProxyName) + return + } + s.RegisterNewWorkConnUdp(c) } else { info = fmt.Sprintf("Unsupport login message type [%d]", req.Type) log.Warn("Unsupport login message type [%d]", req.Type) diff --git a/src/models/client/client.go b/src/models/client/client.go index 2548b237..a108bebb 100644 --- a/src/models/client/client.go +++ b/src/models/client/client.go @@ -17,6 +17,7 @@ package client import ( "encoding/json" "fmt" + "sync" "time" "github.com/fatedier/frp/src/models/config" @@ -34,19 +35,71 @@ type ProxyClient struct { RemotePort int64 CustomDomains []string + + udpTunnel *conn.Conn + once sync.Once } -func (p *ProxyClient) GetLocalConn() (c *conn.Conn, err error) { - c, err = conn.ConnectServer(fmt.Sprintf("%s:%d", p.LocalIp, p.LocalPort)) +// if proxy type is udp, keep a tcp connection for transferring udp packages +func (pc *ProxyClient) StartUdpTunnelOnce(addr string, port int64) { + pc.once.Do(func() { + var err error + var c *conn.Conn + udpProcessor := NewUdpProcesser(nil, pc.LocalIp, pc.LocalPort) + for { + if pc.udpTunnel == nil || pc.udpTunnel.IsClosed() { + if HttpProxy == "" { + c, err = conn.ConnectServer(fmt.Sprintf("%s:%d", addr, port)) + } else { + c, err = conn.ConnectServerByHttpProxy(HttpProxy, fmt.Sprintf("%s:%d", addr, port)) + } + if err != nil { + log.Error("ProxyName [%s], udp tunnel connect to server [%s:%d] error, %v", pc.Name, addr, port, err) + time.Sleep(5 * time.Second) + continue + } + + nowTime := time.Now().Unix() + req := &msg.ControlReq{ + Type: consts.NewWorkConnUdp, + ProxyName: pc.Name, + PrivilegeMode: pc.PrivilegeMode, + Timestamp: nowTime, + } + if pc.PrivilegeMode == true { + req.PrivilegeKey = pcrypto.GetAuthKey(pc.Name + PrivilegeToken + fmt.Sprintf("%d", nowTime)) + } else { + req.AuthKey = pcrypto.GetAuthKey(pc.Name + pc.AuthToken + fmt.Sprintf("%d", nowTime)) + } + + buf, _ := json.Marshal(req) + err = c.WriteString(string(buf) + "\n") + if err != nil { + log.Error("ProxyName [%s], udp tunnel write to server error, %v", pc.Name, err) + c.Close() + time.Sleep(1 * time.Second) + continue + } + pc.udpTunnel = c + udpProcessor.UpdateTcpConn(pc.udpTunnel) + udpProcessor.Run() + } + time.Sleep(1 * time.Second) + } + }) +} + +func (pc *ProxyClient) GetLocalConn() (c *conn.Conn, err error) { + c, err = conn.ConnectServer(fmt.Sprintf("%s:%d", pc.LocalIp, pc.LocalPort)) if err != nil { - log.Error("ProxyName [%s], connect to local port error, %v", p.Name, err) + log.Error("ProxyName [%s], connect to local port error, %v", pc.Name, err) } return } -func (p *ProxyClient) GetRemoteConn(addr string, port int64) (c *conn.Conn, err error) { +func (pc *ProxyClient) GetRemoteConn(addr string, port int64) (c *conn.Conn, err error) { defer func() { - if err != nil { + if err != nil && c != nil { c.Close() } }() @@ -57,29 +110,27 @@ func (p *ProxyClient) GetRemoteConn(addr string, port int64) (c *conn.Conn, err c, err = conn.ConnectServerByHttpProxy(HttpProxy, fmt.Sprintf("%s:%d", addr, port)) } if err != nil { - log.Error("ProxyName [%s], connect to server [%s:%d] error, %v", p.Name, addr, port, err) + log.Error("ProxyName [%s], connect to server [%s:%d] error, %v", pc.Name, addr, port, err) return } nowTime := time.Now().Unix() req := &msg.ControlReq{ Type: consts.NewWorkConn, - ProxyName: p.Name, - PrivilegeMode: p.PrivilegeMode, + ProxyName: pc.Name, + PrivilegeMode: pc.PrivilegeMode, Timestamp: nowTime, } - if p.PrivilegeMode == true { - privilegeKey := pcrypto.GetAuthKey(p.Name + PrivilegeToken + fmt.Sprintf("%d", nowTime)) - req.PrivilegeKey = privilegeKey + if pc.PrivilegeMode == true { + req.PrivilegeKey = pcrypto.GetAuthKey(pc.Name + PrivilegeToken + fmt.Sprintf("%d", nowTime)) } else { - authKey := pcrypto.GetAuthKey(p.Name + p.AuthToken + fmt.Sprintf("%d", nowTime)) - req.AuthKey = authKey + req.AuthKey = pcrypto.GetAuthKey(pc.Name + pc.AuthToken + fmt.Sprintf("%d", nowTime)) } buf, _ := json.Marshal(req) - err = c.Write(string(buf) + "\n") + err = c.WriteString(string(buf) + "\n") if err != nil { - log.Error("ProxyName [%s], write to server error, %v", p.Name, err) + log.Error("ProxyName [%s], write to server error, %v", pc.Name, err) return } @@ -87,12 +138,12 @@ func (p *ProxyClient) GetRemoteConn(addr string, port int64) (c *conn.Conn, err return } -func (p *ProxyClient) StartTunnel(serverAddr string, serverPort int64) (err error) { - localConn, err := p.GetLocalConn() +func (pc *ProxyClient) StartTunnel(serverAddr string, serverPort int64) (err error) { + localConn, err := pc.GetLocalConn() if err != nil { return } - remoteConn, err := p.GetRemoteConn(serverAddr, serverPort) + remoteConn, err := pc.GetRemoteConn(serverAddr, serverPort) if err != nil { return } @@ -101,7 +152,7 @@ func (p *ProxyClient) StartTunnel(serverAddr string, serverPort int64) (err erro log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", localConn.GetLocalAddr(), localConn.GetRemoteAddr(), remoteConn.GetLocalAddr(), remoteConn.GetRemoteAddr()) needRecord := false - go msg.JoinMore(localConn, remoteConn, p.BaseConf, needRecord) + go msg.JoinMore(localConn, remoteConn, pc.BaseConf, needRecord) return nil } diff --git a/src/models/client/config.go b/src/models/client/config.go index 417ea3b2..49b81224 100644 --- a/src/models/client/config.go +++ b/src/models/client/config.go @@ -130,7 +130,7 @@ func LoadConf(confFile string) (err error) { proxyClient.Type = "tcp" tmpStr, ok = section["type"] if ok { - if tmpStr != "tcp" && tmpStr != "http" && tmpStr != "https" { + if tmpStr != "tcp" && tmpStr != "http" && tmpStr != "https" && tmpStr != "udp" { return fmt.Errorf("Parse conf error: proxy [%s] type error", proxyClient.Name) } proxyClient.Type = tmpStr diff --git a/src/models/client/process_udp.go b/src/models/client/process_udp.go new file mode 100644 index 00000000..d1dbcc21 --- /dev/null +++ b/src/models/client/process_udp.go @@ -0,0 +1,153 @@ +// Copyright 2016 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "fmt" + "io" + "net" + "sync" + "time" + + "github.com/fatedier/frp/src/models/msg" + "github.com/fatedier/frp/src/utils/conn" + "github.com/fatedier/frp/src/utils/pool" +) + +type UdpProcesser struct { + tcpConn *conn.Conn + closeCh chan struct{} + + localAddr string + + // cache local udp connections + // key is remoteAddr + localUdpConns map[string]*net.UDPConn + mutex sync.RWMutex + tcpConnMutex sync.RWMutex +} + +func NewUdpProcesser(c *conn.Conn, localIp string, localPort int64) *UdpProcesser { + return &UdpProcesser{ + tcpConn: c, + closeCh: make(chan struct{}), + localAddr: fmt.Sprintf("%s:%d", localIp, localPort), + localUdpConns: make(map[string]*net.UDPConn), + } +} + +func (up *UdpProcesser) UpdateTcpConn(c *conn.Conn) { + up.tcpConnMutex.Lock() + defer up.tcpConnMutex.Unlock() + up.tcpConn = c +} + +func (up *UdpProcesser) Run() { + go up.ReadLoop() +} + +func (up *UdpProcesser) ReadLoop() { + var ( + buf string + err error + ) + for { + udpPacket := &msg.UdpPacket{} + + // read udp package from frps + buf, err = up.tcpConn.ReadLine() + if err != nil { + if err == io.EOF { + return + } else { + continue + } + } + err = udpPacket.UnPack([]byte(buf)) + if err != nil { + continue + } + + // write to local udp port + sendConn, ok := up.GetUdpConn(udpPacket.SrcStr) + if !ok { + dstAddr, err := net.ResolveUDPAddr("udp", up.localAddr) + if err != nil { + continue + } + sendConn, err = net.DialUDP("udp", nil, dstAddr) + if err != nil { + continue + } + + up.SetUdpConn(udpPacket.SrcStr, sendConn) + } + + _, err = sendConn.Write(udpPacket.Content) + if err != nil { + sendConn.Close() + continue + } + + if !ok { + go up.Forward(udpPacket, sendConn) + } + } +} + +func (up *UdpProcesser) Forward(udpPacket *msg.UdpPacket, singleConn *net.UDPConn) { + addr := udpPacket.SrcStr + defer up.RemoveUdpConn(addr) + + buf := pool.GetBuf(2048) + for { + singleConn.SetReadDeadline(time.Now().Add(120 * time.Second)) + n, remoteAddr, err := singleConn.ReadFromUDP(buf) + if err != nil { + return + } + + // forward to frps + forwardPacket := msg.NewUdpPacket(buf[0:n], remoteAddr, udpPacket.Src) + up.tcpConnMutex.RLock() + err = up.tcpConn.WriteString(string(forwardPacket.Pack()) + "\n") + up.tcpConnMutex.RUnlock() + if err != nil { + return + } + } +} + +func (up *UdpProcesser) GetUdpConn(addr string) (singleConn *net.UDPConn, ok bool) { + up.mutex.RLock() + defer up.mutex.RUnlock() + singleConn, ok = up.localUdpConns[addr] + return +} + +func (up *UdpProcesser) SetUdpConn(addr string, conn *net.UDPConn) { + up.mutex.Lock() + defer up.mutex.Unlock() + up.localUdpConns[addr] = conn +} + +func (up *UdpProcesser) RemoveUdpConn(addr string) { + up.mutex.Lock() + defer up.mutex.Unlock() + if c, ok := up.localUdpConns[addr]; ok { + c.Close() + } + delete(up.localUdpConns, addr) +} diff --git a/src/models/consts/consts.go b/src/models/consts/consts.go index 90ca9036..f41e1836 100644 --- a/src/models/consts/consts.go +++ b/src/models/consts/consts.go @@ -37,4 +37,5 @@ const ( NewCtlConnRes HeartbeatReq HeartbeatRes + NewWorkConnUdp ) diff --git a/src/models/msg/process.go b/src/models/msg/process.go index a1b84c23..bfc31303 100644 --- a/src/models/msg/process.go +++ b/src/models/msg/process.go @@ -53,9 +53,9 @@ func Join(c1 *conn.Conn, c2 *conn.Conn) { } // join two connections and do some operations -func JoinMore(c1 *conn.Conn, c2 *conn.Conn, conf config.BaseConf, needRecord bool) { +func JoinMore(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser, conf config.BaseConf, needRecord bool) { var wait sync.WaitGroup - encryptPipe := func(from *conn.Conn, to *conn.Conn) { + encryptPipe := func(from io.ReadCloser, to io.WriteCloser) { defer from.Close() defer to.Close() defer wait.Done() @@ -64,7 +64,7 @@ func JoinMore(c1 *conn.Conn, c2 *conn.Conn, conf config.BaseConf, needRecord boo pipeEncrypt(from, to, conf, needRecord) } - decryptPipe := func(to *conn.Conn, from *conn.Conn) { + decryptPipe := func(to io.ReadCloser, from io.WriteCloser) { defer from.Close() defer to.Close() defer wait.Done() @@ -109,7 +109,7 @@ func unpkgMsg(data []byte) (int, []byte, []byte) { } // decrypt msg from reader, then write into writer -func pipeDecrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bool) (err error) { +func pipeDecrypt(r io.Reader, w io.Writer, conf config.BaseConf, needRecord bool) (err error) { laes := new(pcrypto.Pcrypto) key := conf.AuthToken if conf.PrivilegeMode { @@ -175,7 +175,7 @@ func pipeDecrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bo } } - _, err = w.WriteBytes(res) + _, err = w.Write(res) if err != nil { return err } @@ -192,7 +192,7 @@ func pipeDecrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bo } // recvive msg from reader, then encrypt msg into writer -func pipeEncrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bool) (err error) { +func pipeEncrypt(r io.Reader, w io.Writer, conf config.BaseConf, needRecord bool) (err error) { laes := new(pcrypto.Pcrypto) key := conf.AuthToken if conf.PrivilegeMode { @@ -247,7 +247,7 @@ func pipeEncrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bo } res = pkgMsg(res) - _, err = w.WriteBytes(res) + _, err = w.Write(res) if err != nil { return err } diff --git a/src/models/msg/udp.go b/src/models/msg/udp.go new file mode 100644 index 00000000..5cc01f64 --- /dev/null +++ b/src/models/msg/udp.go @@ -0,0 +1,72 @@ +// Copyright 2016 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package msg + +import ( + "encoding/base64" + "encoding/json" + "net" +) + +type UdpPacket struct { + Content []byte `json:"-"` + Src *net.UDPAddr `json:"-"` + Dst *net.UDPAddr `json:"-"` + + EncodeContent string `json:"content"` + SrcStr string `json:"src"` + DstStr string `json:"dst"` +} + +func NewUdpPacket(content []byte, src, dst *net.UDPAddr) *UdpPacket { + up := &UdpPacket{ + Src: src, + Dst: dst, + EncodeContent: base64.StdEncoding.EncodeToString(content), + SrcStr: src.String(), + DstStr: dst.String(), + } + return up +} + +// parse one udp packet struct to bytes +func (up *UdpPacket) Pack() []byte { + b, _ := json.Marshal(up) + return b +} + +// parse from bytes to UdpPacket struct +func (up *UdpPacket) UnPack(packet []byte) error { + err := json.Unmarshal(packet, &up) + if err != nil { + return err + } + + up.Content, err = base64.StdEncoding.DecodeString(up.EncodeContent) + if err != nil { + return err + } + + up.Src, err = net.ResolveUDPAddr("udp", up.SrcStr) + if err != nil { + return err + } + + up.Dst, err = net.ResolveUDPAddr("udp", up.DstStr) + if err != nil { + return err + } + return nil +} diff --git a/src/models/msg/udp_test.go b/src/models/msg/udp_test.go new file mode 100644 index 00000000..614b124e --- /dev/null +++ b/src/models/msg/udp_test.go @@ -0,0 +1,50 @@ +// Copyright 2016 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package msg + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +var ( + content string = "udp packet test" + src string = "1.1.1.1:1000" + dst string = "2.2.2.2:2000" + + udpMsg *UdpPacket +) + +func init() { + srcAddr, _ := net.ResolveUDPAddr("udp", src) + dstAddr, _ := net.ResolveUDPAddr("udp", dst) + udpMsg = NewUdpPacket([]byte(content), srcAddr, dstAddr) +} + +func TestPack(t *testing.T) { + assert := assert.New(t) + msg := udpMsg.Pack() + assert.Equal(string(msg), `{"content":"dWRwIHBhY2tldCB0ZXN0","src":"1.1.1.1:1000","dst":"2.2.2.2:2000"}`) +} + +func TestUnpack(t *testing.T) { + assert := assert.New(t) + udpMsg.UnPack([]byte(`{"content":"dWRwIHBhY2tldCB0ZXN0","src":"1.1.1.1:1000","dst":"2.2.2.2:2000"}`)) + assert.Equal(content, string(udpMsg.Content)) + assert.Equal(src, udpMsg.Src.String()) + assert.Equal(dst, udpMsg.Dst.String()) +} diff --git a/src/models/server/config.go b/src/models/server/config.go index 58ae0da3..bc880edd 100644 --- a/src/models/server/config.go +++ b/src/models/server/config.go @@ -240,7 +240,7 @@ func loadProxyConf(confFile string) (proxyServers map[string]*ProxyServer, err e proxyServer.Type, ok = section["type"] if ok { - if proxyServer.Type != "tcp" && proxyServer.Type != "http" && proxyServer.Type != "https" { + if proxyServer.Type != "tcp" && proxyServer.Type != "http" && proxyServer.Type != "https" && proxyServer.Type != "udp" { return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] type error", proxyServer.Name) } } else { @@ -252,8 +252,8 @@ func loadProxyConf(confFile string) (proxyServers map[string]*ProxyServer, err e return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] no auth_token found", proxyServer.Name) } - // for tcp - if proxyServer.Type == "tcp" { + // for tcp and udp + if proxyServer.Type == "tcp" || proxyServer.Type == "udp" { proxyServer.BindAddr, ok = section["bind_addr"] if !ok { proxyServer.BindAddr = "0.0.0.0" diff --git a/src/models/server/server.go b/src/models/server/server.go index 2e2618d5..33c278bb 100644 --- a/src/models/server/server.go +++ b/src/models/server/server.go @@ -16,6 +16,7 @@ package server import ( "fmt" + "net" "sync" "time" @@ -25,6 +26,7 @@ import ( "github.com/fatedier/frp/src/models/msg" "github.com/fatedier/frp/src/utils/conn" "github.com/fatedier/frp/src/utils/log" + "github.com/fatedier/frp/src/utils/pool" ) type Listener interface { @@ -38,13 +40,17 @@ type ProxyServer struct { ListenPort int64 CustomDomains []string - Status int64 - CtlConn *conn.Conn // control connection with frpc - listeners []Listener // accept new connection from remote users - ctlMsgChan chan int64 // every time accept a new user conn, put "1" to the channel - workConnChan chan *conn.Conn // get new work conns from control goroutine - mutex sync.RWMutex - closeChan chan struct{} // for notify other goroutines that the proxy is closed by close this channel + Status int64 + CtlConn *conn.Conn // control connection with frpc + WorkConnUdp *conn.Conn // work connection for udp + + udpConn *net.UDPConn + listeners []Listener // accept new connection from remote users + ctlMsgChan chan int64 // every time accept a new user conn, put "1" to the channel + workConnChan chan *conn.Conn // get new work conns from control goroutine + udpSenderChan chan *msg.UdpPacket + mutex sync.RWMutex + closeChan chan struct{} // close this channel for notifying other goroutines that the proxy is closed } func NewProxyServer() (p *ProxyServer) { @@ -83,6 +89,7 @@ func (p *ProxyServer) Init() { metric.SetStatus(p.Name, p.Status) p.workConnChan = make(chan *conn.Conn, p.PoolCount+10) p.ctlMsgChan = make(chan int64, p.PoolCount+10) + p.udpSenderChan = make(chan *msg.UdpPacket, 1024) p.listeners = make([]Listener, 0) p.closeChan = make(chan struct{}) p.Unlock() @@ -150,41 +157,68 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) { go p.connectionPoolManager(p.closeChan) } - // start a goroutine for every listener to accept user connection - for _, listener := range p.listeners { - go func(l Listener) { + if p.Type == "udp" { + // udp is special + p.udpConn, err = conn.ListenUDP(p.BindAddr, p.ListenPort) + if err != nil { + log.Warn("ProxyName [%s], listen udp port error: %v", p.Name, err) + return err + } + go func() { for { - // block - // if listener is closed, err returned - c, err := l.Accept() + buf := pool.GetBuf(2048) + n, remoteAddr, err := p.udpConn.ReadFromUDP(buf) if err != nil { - log.Info("ProxyName [%s], listener is closed", p.Name) + log.Info("ProxyName [%s], udp listener is closed", p.Name) return } - log.Debug("ProxyName [%s], get one new user conn [%s]", p.Name, c.GetRemoteAddr()) - - if p.Status != consts.Working { - log.Debug("ProxyName [%s] is not working, new user conn close", p.Name) - c.Close() - return + localAddr, _ := net.ResolveUDPAddr("udp", p.udpConn.LocalAddr().String()) + udpPacket := msg.NewUdpPacket(buf[0:n], remoteAddr, localAddr) + select { + case p.udpSenderChan <- udpPacket: + default: + log.Warn("ProxyName [%s], udp sender channel is full", p.Name) } - - go func(userConn *conn.Conn) { - workConn, err := p.getWorkConn() + pool.PutBuf(buf) + } + }() + } else { + // start a goroutine for every listener to accept user connection + for _, listener := range p.listeners { + go func(l Listener) { + for { + // block + // if listener is closed, err returned + c, err := l.Accept() if err != nil { + log.Info("ProxyName [%s], listener is closed", p.Name) + return + } + log.Debug("ProxyName [%s], get one new user conn [%s]", p.Name, c.GetRemoteAddr()) + + if p.Status != consts.Working { + log.Debug("ProxyName [%s] is not working, new user conn close", p.Name) + c.Close() return } - // message will be transferred to another without modifying - // l means local, r means remote - log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", workConn.GetLocalAddr(), workConn.GetRemoteAddr(), - userConn.GetLocalAddr(), userConn.GetRemoteAddr()) + go func(userConn *conn.Conn) { + workConn, err := p.getWorkConn() + if err != nil { + return + } - needRecord := true - go msg.JoinMore(userConn, workConn, p.BaseConf, needRecord) - }(c) - } - }(listener) + // message will be transferred to another without modifying + // l means local, r means remote + log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", workConn.GetLocalAddr(), workConn.GetRemoteAddr(), + userConn.GetLocalAddr(), userConn.GetRemoteAddr()) + + needRecord := true + go msg.JoinMore(userConn, workConn, p.BaseConf, needRecord) + }(c) + } + }(listener) + } } return nil } @@ -200,10 +234,18 @@ func (p *ProxyServer) Close() { } close(p.ctlMsgChan) close(p.workConnChan) + close(p.udpSenderChan) close(p.closeChan) if p.CtlConn != nil { p.CtlConn.Close() } + if p.WorkConnUdp != nil { + p.WorkConnUdp.Close() + } + if p.udpConn != nil { + p.udpConn.Close() + p.udpConn = nil + } } metric.SetStatus(p.Name, p.Status) // if the proxy created by PrivilegeMode, delete it when closed @@ -228,9 +270,60 @@ func (p *ProxyServer) RegisterNewWorkConn(c *conn.Conn) { case p.workConnChan <- c: default: log.Debug("ProxyName [%s], workConnChan is full, so close this work connection", p.Name) + c.Close() } } +// create a tcp connection for forwarding udp packages +func (p *ProxyServer) RegisterNewWorkConnUdp(c *conn.Conn) { + if p.WorkConnUdp != nil && !p.WorkConnUdp.IsClosed() { + p.WorkConnUdp.Close() + } + p.WorkConnUdp = c + + // read + go func() { + var ( + buf string + err error + ) + for { + buf, err = c.ReadLine() + if err != nil { + log.Warn("ProxyName [%s], work connection for udp closed", p.Name) + return + } + udpPacket := &msg.UdpPacket{} + err = udpPacket.UnPack([]byte(buf)) + if err != nil { + log.Warn("ProxyName [%s], unpack udp packet error: %v", p.Name, err) + continue + } + + // send to user + _, err = p.udpConn.WriteToUDP(udpPacket.Content, udpPacket.Dst) + if err != nil { + continue + } + } + }() + + // write + go func() { + for { + udpPacket, ok := <-p.udpSenderChan + if !ok { + return + } + err := c.WriteString(string(udpPacket.Pack()) + "\n") + if err != nil { + log.Debug("ProxyName [%s], write to work connection for udp error: %v", p.Name, err) + return + } + } + }() +} + // When frps get one user connection, we get one work connection from the pool and return it. // If no workConn available in the pool, send message to frpc to get one or more // and wait until it is available. diff --git a/src/utils/conn/conn.go b/src/utils/conn/conn.go index 2434a089..4b6eb157 100644 --- a/src/utils/conn/conn.go +++ b/src/utils/conn/conn.go @@ -202,12 +202,12 @@ func (c *Conn) ReadLine() (buff string, err error) { return buff, err } -func (c *Conn) WriteBytes(content []byte) (n int, err error) { +func (c *Conn) Write(content []byte) (n int, err error) { n, err = c.TcpConn.Write(content) return } -func (c *Conn) Write(content string) (err error) { +func (c *Conn) WriteString(content string) (err error) { _, err = c.TcpConn.Write([]byte(content)) return err } @@ -220,13 +220,14 @@ func (c *Conn) SetReadDeadline(t time.Time) error { return c.TcpConn.SetReadDeadline(t) } -func (c *Conn) Close() { +func (c *Conn) Close() error { c.mutex.Lock() + defer c.mutex.Unlock() if c.TcpConn != nil && c.closeFlag == false { c.closeFlag = true c.TcpConn.Close() } - c.mutex.Unlock() + return nil } func (c *Conn) IsClosed() (closeFlag bool) { @@ -245,7 +246,6 @@ func (c *Conn) CheckClosed() bool { } c.mutex.RUnlock() - // err := c.TcpConn.SetReadDeadline(time.Now().Add(100 * time.Microsecond)) err := c.TcpConn.SetReadDeadline(time.Now().Add(time.Millisecond)) if err != nil { c.Close() diff --git a/src/utils/conn/udp_conn.go b/src/utils/conn/udp_conn.go new file mode 100644 index 00000000..c9bf8f39 --- /dev/null +++ b/src/utils/conn/udp_conn.go @@ -0,0 +1,29 @@ +// Copyright 2016 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package conn + +import ( + "fmt" + "net" +) + +func ListenUDP(bindAddr string, bindPort int64) (conn *net.UDPConn, err error) { + udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", bindAddr, bindPort)) + if err != nil { + return conn, err + } + conn, err = net.ListenUDP("udp", udpAddr) + return +} diff --git a/test/echo_server.go b/test/echo_server.go index 6b766a49..67dbba26 100644 --- a/test/echo_server.go +++ b/test/echo_server.go @@ -40,6 +40,6 @@ func echoWorker(c *conn.Conn) { return } - c.Write(buff) + c.WriteString(buff) } } diff --git a/test/func_test.go b/test/func_test.go index 3683c516..26a16b86 100644 --- a/test/func_test.go +++ b/test/func_test.go @@ -26,7 +26,7 @@ func TestEchoServer(t *testing.T) { timer := time.Now().Add(time.Duration(5) * time.Second) c.SetDeadline(timer) - c.Write(ECHO_TEST_STR) + c.WriteString(ECHO_TEST_STR) buff, err := c.ReadLine() if err != nil {