support udp type

This commit is contained in:
fatedier 2016-12-19 01:22:21 +08:00
parent adcb2c1ea5
commit f2999e3317
18 changed files with 556 additions and 76 deletions

View File

@ -3,8 +3,8 @@ language: go
go: go:
- 1.5.4 - 1.5.4
- 1.6.3 - 1.6.4
- 1.7 - 1.7.4
install: install:
- make - make

View File

@ -29,6 +29,13 @@ use_gzip = false
# connections will be established in advance, default value is zero # connections will be established in advance, default value is zero
pool_count = 10 pool_count = 10
[dns]
type = udp
local_ip = 127.0.0.1
local_port = 53
use_encryption = true
use_gzip = true
# Resolve your domain names to [server_addr] so you can use http://web01.yourdomain.com to browse web01 and http://web02.yourdomain.com to browse web02, the domains are set in frps.ini # Resolve your domain names to [server_addr] so you can use http://web01.yourdomain.com to browse web01 and http://web02.yourdomain.com to browse web02, the domains are set in frps.ini
[web01] [web01]
type = http type = http

View File

@ -34,6 +34,12 @@ auth_token = 123
bind_addr = 0.0.0.0 bind_addr = 0.0.0.0
listen_port = 6000 listen_port = 6000
[dns]
type = udp
auth_token = 123
bind_addr = 0.0.0.0
listen_port = 53
[web01] [web01]
# if type equals http, vhost_http_port must be set # if type equals http, vhost_http_port must be set
type = http type = http

View File

@ -120,7 +120,7 @@ func msgSender(cli *client.ProxyClient, c *conn.Conn, msgSendChan chan interface
} }
buf, _ := json.Marshal(msg) buf, _ := json.Marshal(msg)
err := c.Write(string(buf) + "\n") err := c.WriteString(string(buf) + "\n")
if err != nil { if err != nil {
log.Warn("ProxyName [%s], write to server error, proxy exit", cli.Name) log.Warn("ProxyName [%s], write to server error, proxy exit", cli.Name)
c.Close() c.Close()
@ -165,7 +165,7 @@ func loginToServer(cli *client.ProxyClient) (c *conn.Conn, err error) {
} }
buf, _ := json.Marshal(req) buf, _ := json.Marshal(req)
err = c.Write(string(buf) + "\n") err = c.WriteString(string(buf) + "\n")
if err != nil { if err != nil {
log.Error("ProxyName [%s], write to server error, %v", cli.Name, err) log.Error("ProxyName [%s], write to server error, %v", cli.Name, err)
return return
@ -190,6 +190,12 @@ func loginToServer(cli *client.ProxyClient) (c *conn.Conn, err error) {
} }
log.Info("ProxyName [%s], connect to server [%s:%d] success!", cli.Name, client.ServerAddr, client.ServerPort) log.Info("ProxyName [%s], connect to server [%s:%d] success!", cli.Name, client.ServerAddr, client.ServerPort)
if cli.Type == "udp" {
// we only need one udp work connection
// all udp messages will be forwarded throngh this connection
go cli.StartUdpTunnelOnce(client.ServerAddr, client.ServerPort)
}
return return
} }

View File

@ -71,14 +71,14 @@ func controlWorker(c *conn.Conn) {
// login when type is NewCtlConn or NewWorkConn // login when type is NewCtlConn or NewWorkConn
ret, info := doLogin(cliReq, c) ret, info := doLogin(cliReq, c)
// if login type is NewWorkConn, nothing will be send to frpc // if login type is NewWorkConn, nothing will be send to frpc
if cliReq.Type != consts.NewWorkConn { if cliReq.Type == consts.NewCtlConn {
cliRes := &msg.ControlRes{ cliRes := &msg.ControlRes{
Type: consts.NewCtlConnRes, Type: consts.NewCtlConnRes,
Code: ret, Code: ret,
Msg: info, Msg: info,
} }
byteBuf, _ := json.Marshal(cliRes) byteBuf, _ := json.Marshal(cliRes)
err = c.Write(string(byteBuf) + "\n") err = c.WriteString(string(byteBuf) + "\n")
if err != nil { if err != nil {
log.Warn("ProxyName [%s], write to client error, proxy exit", cliReq.ProxyName) log.Warn("ProxyName [%s], write to client error, proxy exit", cliReq.ProxyName)
return return
@ -144,9 +144,11 @@ func msgReader(s *server.ProxyServer, c *conn.Conn, msgSendChan chan interface{}
if err != nil { if err != nil {
if err == io.EOF { if err == io.EOF {
log.Warn("ProxyName [%s], client is dead!", s.Name) log.Warn("ProxyName [%s], client is dead!", s.Name)
s.Close()
return err return err
} else if c == nil || c.IsClosed() { } else if c == nil || c.IsClosed() {
log.Warn("ProxyName [%s], client connection is closed", s.Name) log.Warn("ProxyName [%s], client connection is closed", s.Name)
s.Close()
return err return err
} }
log.Warn("ProxyName [%s], read error: %v", s.Name, err) log.Warn("ProxyName [%s], read error: %v", s.Name, err)
@ -183,7 +185,7 @@ func msgSender(s *server.ProxyServer, c *conn.Conn, msgSendChan chan interface{}
} }
buf, _ := json.Marshal(msg) buf, _ := json.Marshal(msg)
err := c.Write(string(buf) + "\n") err := c.WriteString(string(buf) + "\n")
if err != nil { if err != nil {
log.Warn("ProxyName [%s], write to client error, proxy exit", s.Name) log.Warn("ProxyName [%s], write to client error, proxy exit", s.Name)
s.Close() s.Close()
@ -193,6 +195,9 @@ func msgSender(s *server.ProxyServer, c *conn.Conn, msgSendChan chan interface{}
} }
// if success, ret equals 0, otherwise greater than 0 // if success, ret equals 0, otherwise greater than 0
// NewCtlConn
// NewWorkConn
// NewWorkConnUdp
func doLogin(req *msg.ControlReq, c *conn.Conn) (ret int64, info string) { func doLogin(req *msg.ControlReq, c *conn.Conn) (ret int64, info string) {
ret = 1 ret = 1
// check if PrivilegeMode is enabled // check if PrivilegeMode is enabled
@ -325,6 +330,13 @@ func doLogin(req *msg.ControlReq, c *conn.Conn) (ret int64, info string) {
} }
// the connection will close after join over // the connection will close after join over
s.RegisterNewWorkConn(c) s.RegisterNewWorkConn(c)
} else if req.Type == consts.NewWorkConnUdp {
// work conn for udp
if s.Status != consts.Working {
log.Warn("ProxyName [%s], is not working when it gets one new work connnection for udp", req.ProxyName)
return
}
s.RegisterNewWorkConnUdp(c)
} else { } else {
info = fmt.Sprintf("Unsupport login message type [%d]", req.Type) info = fmt.Sprintf("Unsupport login message type [%d]", req.Type)
log.Warn("Unsupport login message type [%d]", req.Type) log.Warn("Unsupport login message type [%d]", req.Type)

View File

@ -17,6 +17,7 @@ package client
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"sync"
"time" "time"
"github.com/fatedier/frp/src/models/config" "github.com/fatedier/frp/src/models/config"
@ -34,19 +35,71 @@ type ProxyClient struct {
RemotePort int64 RemotePort int64
CustomDomains []string CustomDomains []string
udpTunnel *conn.Conn
once sync.Once
} }
func (p *ProxyClient) GetLocalConn() (c *conn.Conn, err error) { // if proxy type is udp, keep a tcp connection for transferring udp packages
c, err = conn.ConnectServer(fmt.Sprintf("%s:%d", p.LocalIp, p.LocalPort)) func (pc *ProxyClient) StartUdpTunnelOnce(addr string, port int64) {
pc.once.Do(func() {
var err error
var c *conn.Conn
udpProcessor := NewUdpProcesser(nil, pc.LocalIp, pc.LocalPort)
for {
if pc.udpTunnel == nil || pc.udpTunnel.IsClosed() {
if HttpProxy == "" {
c, err = conn.ConnectServer(fmt.Sprintf("%s:%d", addr, port))
} else {
c, err = conn.ConnectServerByHttpProxy(HttpProxy, fmt.Sprintf("%s:%d", addr, port))
}
if err != nil {
log.Error("ProxyName [%s], udp tunnel connect to server [%s:%d] error, %v", pc.Name, addr, port, err)
time.Sleep(5 * time.Second)
continue
}
nowTime := time.Now().Unix()
req := &msg.ControlReq{
Type: consts.NewWorkConnUdp,
ProxyName: pc.Name,
PrivilegeMode: pc.PrivilegeMode,
Timestamp: nowTime,
}
if pc.PrivilegeMode == true {
req.PrivilegeKey = pcrypto.GetAuthKey(pc.Name + PrivilegeToken + fmt.Sprintf("%d", nowTime))
} else {
req.AuthKey = pcrypto.GetAuthKey(pc.Name + pc.AuthToken + fmt.Sprintf("%d", nowTime))
}
buf, _ := json.Marshal(req)
err = c.WriteString(string(buf) + "\n")
if err != nil {
log.Error("ProxyName [%s], udp tunnel write to server error, %v", pc.Name, err)
c.Close()
time.Sleep(1 * time.Second)
continue
}
pc.udpTunnel = c
udpProcessor.UpdateTcpConn(pc.udpTunnel)
udpProcessor.Run()
}
time.Sleep(1 * time.Second)
}
})
}
func (pc *ProxyClient) GetLocalConn() (c *conn.Conn, err error) {
c, err = conn.ConnectServer(fmt.Sprintf("%s:%d", pc.LocalIp, pc.LocalPort))
if err != nil { if err != nil {
log.Error("ProxyName [%s], connect to local port error, %v", p.Name, err) log.Error("ProxyName [%s], connect to local port error, %v", pc.Name, err)
} }
return return
} }
func (p *ProxyClient) GetRemoteConn(addr string, port int64) (c *conn.Conn, err error) { func (pc *ProxyClient) GetRemoteConn(addr string, port int64) (c *conn.Conn, err error) {
defer func() { defer func() {
if err != nil { if err != nil && c != nil {
c.Close() c.Close()
} }
}() }()
@ -57,29 +110,27 @@ func (p *ProxyClient) GetRemoteConn(addr string, port int64) (c *conn.Conn, err
c, err = conn.ConnectServerByHttpProxy(HttpProxy, fmt.Sprintf("%s:%d", addr, port)) c, err = conn.ConnectServerByHttpProxy(HttpProxy, fmt.Sprintf("%s:%d", addr, port))
} }
if err != nil { if err != nil {
log.Error("ProxyName [%s], connect to server [%s:%d] error, %v", p.Name, addr, port, err) log.Error("ProxyName [%s], connect to server [%s:%d] error, %v", pc.Name, addr, port, err)
return return
} }
nowTime := time.Now().Unix() nowTime := time.Now().Unix()
req := &msg.ControlReq{ req := &msg.ControlReq{
Type: consts.NewWorkConn, Type: consts.NewWorkConn,
ProxyName: p.Name, ProxyName: pc.Name,
PrivilegeMode: p.PrivilegeMode, PrivilegeMode: pc.PrivilegeMode,
Timestamp: nowTime, Timestamp: nowTime,
} }
if p.PrivilegeMode == true { if pc.PrivilegeMode == true {
privilegeKey := pcrypto.GetAuthKey(p.Name + PrivilegeToken + fmt.Sprintf("%d", nowTime)) req.PrivilegeKey = pcrypto.GetAuthKey(pc.Name + PrivilegeToken + fmt.Sprintf("%d", nowTime))
req.PrivilegeKey = privilegeKey
} else { } else {
authKey := pcrypto.GetAuthKey(p.Name + p.AuthToken + fmt.Sprintf("%d", nowTime)) req.AuthKey = pcrypto.GetAuthKey(pc.Name + pc.AuthToken + fmt.Sprintf("%d", nowTime))
req.AuthKey = authKey
} }
buf, _ := json.Marshal(req) buf, _ := json.Marshal(req)
err = c.Write(string(buf) + "\n") err = c.WriteString(string(buf) + "\n")
if err != nil { if err != nil {
log.Error("ProxyName [%s], write to server error, %v", p.Name, err) log.Error("ProxyName [%s], write to server error, %v", pc.Name, err)
return return
} }
@ -87,12 +138,12 @@ func (p *ProxyClient) GetRemoteConn(addr string, port int64) (c *conn.Conn, err
return return
} }
func (p *ProxyClient) StartTunnel(serverAddr string, serverPort int64) (err error) { func (pc *ProxyClient) StartTunnel(serverAddr string, serverPort int64) (err error) {
localConn, err := p.GetLocalConn() localConn, err := pc.GetLocalConn()
if err != nil { if err != nil {
return return
} }
remoteConn, err := p.GetRemoteConn(serverAddr, serverPort) remoteConn, err := pc.GetRemoteConn(serverAddr, serverPort)
if err != nil { if err != nil {
return return
} }
@ -101,7 +152,7 @@ func (p *ProxyClient) StartTunnel(serverAddr string, serverPort int64) (err erro
log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", localConn.GetLocalAddr(), localConn.GetRemoteAddr(), log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", localConn.GetLocalAddr(), localConn.GetRemoteAddr(),
remoteConn.GetLocalAddr(), remoteConn.GetRemoteAddr()) remoteConn.GetLocalAddr(), remoteConn.GetRemoteAddr())
needRecord := false needRecord := false
go msg.JoinMore(localConn, remoteConn, p.BaseConf, needRecord) go msg.JoinMore(localConn, remoteConn, pc.BaseConf, needRecord)
return nil return nil
} }

View File

@ -130,7 +130,7 @@ func LoadConf(confFile string) (err error) {
proxyClient.Type = "tcp" proxyClient.Type = "tcp"
tmpStr, ok = section["type"] tmpStr, ok = section["type"]
if ok { if ok {
if tmpStr != "tcp" && tmpStr != "http" && tmpStr != "https" { if tmpStr != "tcp" && tmpStr != "http" && tmpStr != "https" && tmpStr != "udp" {
return fmt.Errorf("Parse conf error: proxy [%s] type error", proxyClient.Name) return fmt.Errorf("Parse conf error: proxy [%s] type error", proxyClient.Name)
} }
proxyClient.Type = tmpStr proxyClient.Type = tmpStr

View File

@ -0,0 +1,153 @@
// Copyright 2016 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"fmt"
"io"
"net"
"sync"
"time"
"github.com/fatedier/frp/src/models/msg"
"github.com/fatedier/frp/src/utils/conn"
"github.com/fatedier/frp/src/utils/pool"
)
type UdpProcesser struct {
tcpConn *conn.Conn
closeCh chan struct{}
localAddr string
// cache local udp connections
// key is remoteAddr
localUdpConns map[string]*net.UDPConn
mutex sync.RWMutex
tcpConnMutex sync.RWMutex
}
func NewUdpProcesser(c *conn.Conn, localIp string, localPort int64) *UdpProcesser {
return &UdpProcesser{
tcpConn: c,
closeCh: make(chan struct{}),
localAddr: fmt.Sprintf("%s:%d", localIp, localPort),
localUdpConns: make(map[string]*net.UDPConn),
}
}
func (up *UdpProcesser) UpdateTcpConn(c *conn.Conn) {
up.tcpConnMutex.Lock()
defer up.tcpConnMutex.Unlock()
up.tcpConn = c
}
func (up *UdpProcesser) Run() {
go up.ReadLoop()
}
func (up *UdpProcesser) ReadLoop() {
var (
buf string
err error
)
for {
udpPacket := &msg.UdpPacket{}
// read udp package from frps
buf, err = up.tcpConn.ReadLine()
if err != nil {
if err == io.EOF {
return
} else {
continue
}
}
err = udpPacket.UnPack([]byte(buf))
if err != nil {
continue
}
// write to local udp port
sendConn, ok := up.GetUdpConn(udpPacket.SrcStr)
if !ok {
dstAddr, err := net.ResolveUDPAddr("udp", up.localAddr)
if err != nil {
continue
}
sendConn, err = net.DialUDP("udp", nil, dstAddr)
if err != nil {
continue
}
up.SetUdpConn(udpPacket.SrcStr, sendConn)
}
_, err = sendConn.Write(udpPacket.Content)
if err != nil {
sendConn.Close()
continue
}
if !ok {
go up.Forward(udpPacket, sendConn)
}
}
}
func (up *UdpProcesser) Forward(udpPacket *msg.UdpPacket, singleConn *net.UDPConn) {
addr := udpPacket.SrcStr
defer up.RemoveUdpConn(addr)
buf := pool.GetBuf(2048)
for {
singleConn.SetReadDeadline(time.Now().Add(120 * time.Second))
n, remoteAddr, err := singleConn.ReadFromUDP(buf)
if err != nil {
return
}
// forward to frps
forwardPacket := msg.NewUdpPacket(buf[0:n], remoteAddr, udpPacket.Src)
up.tcpConnMutex.RLock()
err = up.tcpConn.WriteString(string(forwardPacket.Pack()) + "\n")
up.tcpConnMutex.RUnlock()
if err != nil {
return
}
}
}
func (up *UdpProcesser) GetUdpConn(addr string) (singleConn *net.UDPConn, ok bool) {
up.mutex.RLock()
defer up.mutex.RUnlock()
singleConn, ok = up.localUdpConns[addr]
return
}
func (up *UdpProcesser) SetUdpConn(addr string, conn *net.UDPConn) {
up.mutex.Lock()
defer up.mutex.Unlock()
up.localUdpConns[addr] = conn
}
func (up *UdpProcesser) RemoveUdpConn(addr string) {
up.mutex.Lock()
defer up.mutex.Unlock()
if c, ok := up.localUdpConns[addr]; ok {
c.Close()
}
delete(up.localUdpConns, addr)
}

View File

@ -37,4 +37,5 @@ const (
NewCtlConnRes NewCtlConnRes
HeartbeatReq HeartbeatReq
HeartbeatRes HeartbeatRes
NewWorkConnUdp
) )

View File

@ -53,9 +53,9 @@ func Join(c1 *conn.Conn, c2 *conn.Conn) {
} }
// join two connections and do some operations // join two connections and do some operations
func JoinMore(c1 *conn.Conn, c2 *conn.Conn, conf config.BaseConf, needRecord bool) { func JoinMore(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser, conf config.BaseConf, needRecord bool) {
var wait sync.WaitGroup var wait sync.WaitGroup
encryptPipe := func(from *conn.Conn, to *conn.Conn) { encryptPipe := func(from io.ReadCloser, to io.WriteCloser) {
defer from.Close() defer from.Close()
defer to.Close() defer to.Close()
defer wait.Done() defer wait.Done()
@ -64,7 +64,7 @@ func JoinMore(c1 *conn.Conn, c2 *conn.Conn, conf config.BaseConf, needRecord boo
pipeEncrypt(from, to, conf, needRecord) pipeEncrypt(from, to, conf, needRecord)
} }
decryptPipe := func(to *conn.Conn, from *conn.Conn) { decryptPipe := func(to io.ReadCloser, from io.WriteCloser) {
defer from.Close() defer from.Close()
defer to.Close() defer to.Close()
defer wait.Done() defer wait.Done()
@ -109,7 +109,7 @@ func unpkgMsg(data []byte) (int, []byte, []byte) {
} }
// decrypt msg from reader, then write into writer // decrypt msg from reader, then write into writer
func pipeDecrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bool) (err error) { func pipeDecrypt(r io.Reader, w io.Writer, conf config.BaseConf, needRecord bool) (err error) {
laes := new(pcrypto.Pcrypto) laes := new(pcrypto.Pcrypto)
key := conf.AuthToken key := conf.AuthToken
if conf.PrivilegeMode { if conf.PrivilegeMode {
@ -175,7 +175,7 @@ func pipeDecrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bo
} }
} }
_, err = w.WriteBytes(res) _, err = w.Write(res)
if err != nil { if err != nil {
return err return err
} }
@ -192,7 +192,7 @@ func pipeDecrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bo
} }
// recvive msg from reader, then encrypt msg into writer // recvive msg from reader, then encrypt msg into writer
func pipeEncrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bool) (err error) { func pipeEncrypt(r io.Reader, w io.Writer, conf config.BaseConf, needRecord bool) (err error) {
laes := new(pcrypto.Pcrypto) laes := new(pcrypto.Pcrypto)
key := conf.AuthToken key := conf.AuthToken
if conf.PrivilegeMode { if conf.PrivilegeMode {
@ -247,7 +247,7 @@ func pipeEncrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bo
} }
res = pkgMsg(res) res = pkgMsg(res)
_, err = w.WriteBytes(res) _, err = w.Write(res)
if err != nil { if err != nil {
return err return err
} }

72
src/models/msg/udp.go Normal file
View File

@ -0,0 +1,72 @@
// Copyright 2016 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msg
import (
"encoding/base64"
"encoding/json"
"net"
)
type UdpPacket struct {
Content []byte `json:"-"`
Src *net.UDPAddr `json:"-"`
Dst *net.UDPAddr `json:"-"`
EncodeContent string `json:"content"`
SrcStr string `json:"src"`
DstStr string `json:"dst"`
}
func NewUdpPacket(content []byte, src, dst *net.UDPAddr) *UdpPacket {
up := &UdpPacket{
Src: src,
Dst: dst,
EncodeContent: base64.StdEncoding.EncodeToString(content),
SrcStr: src.String(),
DstStr: dst.String(),
}
return up
}
// parse one udp packet struct to bytes
func (up *UdpPacket) Pack() []byte {
b, _ := json.Marshal(up)
return b
}
// parse from bytes to UdpPacket struct
func (up *UdpPacket) UnPack(packet []byte) error {
err := json.Unmarshal(packet, &up)
if err != nil {
return err
}
up.Content, err = base64.StdEncoding.DecodeString(up.EncodeContent)
if err != nil {
return err
}
up.Src, err = net.ResolveUDPAddr("udp", up.SrcStr)
if err != nil {
return err
}
up.Dst, err = net.ResolveUDPAddr("udp", up.DstStr)
if err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,50 @@
// Copyright 2016 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msg
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
var (
content string = "udp packet test"
src string = "1.1.1.1:1000"
dst string = "2.2.2.2:2000"
udpMsg *UdpPacket
)
func init() {
srcAddr, _ := net.ResolveUDPAddr("udp", src)
dstAddr, _ := net.ResolveUDPAddr("udp", dst)
udpMsg = NewUdpPacket([]byte(content), srcAddr, dstAddr)
}
func TestPack(t *testing.T) {
assert := assert.New(t)
msg := udpMsg.Pack()
assert.Equal(string(msg), `{"content":"dWRwIHBhY2tldCB0ZXN0","src":"1.1.1.1:1000","dst":"2.2.2.2:2000"}`)
}
func TestUnpack(t *testing.T) {
assert := assert.New(t)
udpMsg.UnPack([]byte(`{"content":"dWRwIHBhY2tldCB0ZXN0","src":"1.1.1.1:1000","dst":"2.2.2.2:2000"}`))
assert.Equal(content, string(udpMsg.Content))
assert.Equal(src, udpMsg.Src.String())
assert.Equal(dst, udpMsg.Dst.String())
}

View File

@ -240,7 +240,7 @@ func loadProxyConf(confFile string) (proxyServers map[string]*ProxyServer, err e
proxyServer.Type, ok = section["type"] proxyServer.Type, ok = section["type"]
if ok { if ok {
if proxyServer.Type != "tcp" && proxyServer.Type != "http" && proxyServer.Type != "https" { if proxyServer.Type != "tcp" && proxyServer.Type != "http" && proxyServer.Type != "https" && proxyServer.Type != "udp" {
return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] type error", proxyServer.Name) return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] type error", proxyServer.Name)
} }
} else { } else {
@ -252,8 +252,8 @@ func loadProxyConf(confFile string) (proxyServers map[string]*ProxyServer, err e
return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] no auth_token found", proxyServer.Name) return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] no auth_token found", proxyServer.Name)
} }
// for tcp // for tcp and udp
if proxyServer.Type == "tcp" { if proxyServer.Type == "tcp" || proxyServer.Type == "udp" {
proxyServer.BindAddr, ok = section["bind_addr"] proxyServer.BindAddr, ok = section["bind_addr"]
if !ok { if !ok {
proxyServer.BindAddr = "0.0.0.0" proxyServer.BindAddr = "0.0.0.0"

View File

@ -16,6 +16,7 @@ package server
import ( import (
"fmt" "fmt"
"net"
"sync" "sync"
"time" "time"
@ -25,6 +26,7 @@ import (
"github.com/fatedier/frp/src/models/msg" "github.com/fatedier/frp/src/models/msg"
"github.com/fatedier/frp/src/utils/conn" "github.com/fatedier/frp/src/utils/conn"
"github.com/fatedier/frp/src/utils/log" "github.com/fatedier/frp/src/utils/log"
"github.com/fatedier/frp/src/utils/pool"
) )
type Listener interface { type Listener interface {
@ -38,13 +40,17 @@ type ProxyServer struct {
ListenPort int64 ListenPort int64
CustomDomains []string CustomDomains []string
Status int64 Status int64
CtlConn *conn.Conn // control connection with frpc CtlConn *conn.Conn // control connection with frpc
listeners []Listener // accept new connection from remote users WorkConnUdp *conn.Conn // work connection for udp
ctlMsgChan chan int64 // every time accept a new user conn, put "1" to the channel
workConnChan chan *conn.Conn // get new work conns from control goroutine udpConn *net.UDPConn
mutex sync.RWMutex listeners []Listener // accept new connection from remote users
closeChan chan struct{} // for notify other goroutines that the proxy is closed by close this channel ctlMsgChan chan int64 // every time accept a new user conn, put "1" to the channel
workConnChan chan *conn.Conn // get new work conns from control goroutine
udpSenderChan chan *msg.UdpPacket
mutex sync.RWMutex
closeChan chan struct{} // close this channel for notifying other goroutines that the proxy is closed
} }
func NewProxyServer() (p *ProxyServer) { func NewProxyServer() (p *ProxyServer) {
@ -83,6 +89,7 @@ func (p *ProxyServer) Init() {
metric.SetStatus(p.Name, p.Status) metric.SetStatus(p.Name, p.Status)
p.workConnChan = make(chan *conn.Conn, p.PoolCount+10) p.workConnChan = make(chan *conn.Conn, p.PoolCount+10)
p.ctlMsgChan = make(chan int64, p.PoolCount+10) p.ctlMsgChan = make(chan int64, p.PoolCount+10)
p.udpSenderChan = make(chan *msg.UdpPacket, 1024)
p.listeners = make([]Listener, 0) p.listeners = make([]Listener, 0)
p.closeChan = make(chan struct{}) p.closeChan = make(chan struct{})
p.Unlock() p.Unlock()
@ -150,41 +157,68 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
go p.connectionPoolManager(p.closeChan) go p.connectionPoolManager(p.closeChan)
} }
// start a goroutine for every listener to accept user connection if p.Type == "udp" {
for _, listener := range p.listeners { // udp is special
go func(l Listener) { p.udpConn, err = conn.ListenUDP(p.BindAddr, p.ListenPort)
if err != nil {
log.Warn("ProxyName [%s], listen udp port error: %v", p.Name, err)
return err
}
go func() {
for { for {
// block buf := pool.GetBuf(2048)
// if listener is closed, err returned n, remoteAddr, err := p.udpConn.ReadFromUDP(buf)
c, err := l.Accept()
if err != nil { if err != nil {
log.Info("ProxyName [%s], listener is closed", p.Name) log.Info("ProxyName [%s], udp listener is closed", p.Name)
return return
} }
log.Debug("ProxyName [%s], get one new user conn [%s]", p.Name, c.GetRemoteAddr()) localAddr, _ := net.ResolveUDPAddr("udp", p.udpConn.LocalAddr().String())
udpPacket := msg.NewUdpPacket(buf[0:n], remoteAddr, localAddr)
if p.Status != consts.Working { select {
log.Debug("ProxyName [%s] is not working, new user conn close", p.Name) case p.udpSenderChan <- udpPacket:
c.Close() default:
return log.Warn("ProxyName [%s], udp sender channel is full", p.Name)
} }
pool.PutBuf(buf)
go func(userConn *conn.Conn) { }
workConn, err := p.getWorkConn() }()
} else {
// start a goroutine for every listener to accept user connection
for _, listener := range p.listeners {
go func(l Listener) {
for {
// block
// if listener is closed, err returned
c, err := l.Accept()
if err != nil { if err != nil {
log.Info("ProxyName [%s], listener is closed", p.Name)
return
}
log.Debug("ProxyName [%s], get one new user conn [%s]", p.Name, c.GetRemoteAddr())
if p.Status != consts.Working {
log.Debug("ProxyName [%s] is not working, new user conn close", p.Name)
c.Close()
return return
} }
// message will be transferred to another without modifying go func(userConn *conn.Conn) {
// l means local, r means remote workConn, err := p.getWorkConn()
log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", workConn.GetLocalAddr(), workConn.GetRemoteAddr(), if err != nil {
userConn.GetLocalAddr(), userConn.GetRemoteAddr()) return
}
needRecord := true // message will be transferred to another without modifying
go msg.JoinMore(userConn, workConn, p.BaseConf, needRecord) // l means local, r means remote
}(c) log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", workConn.GetLocalAddr(), workConn.GetRemoteAddr(),
} userConn.GetLocalAddr(), userConn.GetRemoteAddr())
}(listener)
needRecord := true
go msg.JoinMore(userConn, workConn, p.BaseConf, needRecord)
}(c)
}
}(listener)
}
} }
return nil return nil
} }
@ -200,10 +234,18 @@ func (p *ProxyServer) Close() {
} }
close(p.ctlMsgChan) close(p.ctlMsgChan)
close(p.workConnChan) close(p.workConnChan)
close(p.udpSenderChan)
close(p.closeChan) close(p.closeChan)
if p.CtlConn != nil { if p.CtlConn != nil {
p.CtlConn.Close() p.CtlConn.Close()
} }
if p.WorkConnUdp != nil {
p.WorkConnUdp.Close()
}
if p.udpConn != nil {
p.udpConn.Close()
p.udpConn = nil
}
} }
metric.SetStatus(p.Name, p.Status) metric.SetStatus(p.Name, p.Status)
// if the proxy created by PrivilegeMode, delete it when closed // if the proxy created by PrivilegeMode, delete it when closed
@ -228,9 +270,60 @@ func (p *ProxyServer) RegisterNewWorkConn(c *conn.Conn) {
case p.workConnChan <- c: case p.workConnChan <- c:
default: default:
log.Debug("ProxyName [%s], workConnChan is full, so close this work connection", p.Name) log.Debug("ProxyName [%s], workConnChan is full, so close this work connection", p.Name)
c.Close()
} }
} }
// create a tcp connection for forwarding udp packages
func (p *ProxyServer) RegisterNewWorkConnUdp(c *conn.Conn) {
if p.WorkConnUdp != nil && !p.WorkConnUdp.IsClosed() {
p.WorkConnUdp.Close()
}
p.WorkConnUdp = c
// read
go func() {
var (
buf string
err error
)
for {
buf, err = c.ReadLine()
if err != nil {
log.Warn("ProxyName [%s], work connection for udp closed", p.Name)
return
}
udpPacket := &msg.UdpPacket{}
err = udpPacket.UnPack([]byte(buf))
if err != nil {
log.Warn("ProxyName [%s], unpack udp packet error: %v", p.Name, err)
continue
}
// send to user
_, err = p.udpConn.WriteToUDP(udpPacket.Content, udpPacket.Dst)
if err != nil {
continue
}
}
}()
// write
go func() {
for {
udpPacket, ok := <-p.udpSenderChan
if !ok {
return
}
err := c.WriteString(string(udpPacket.Pack()) + "\n")
if err != nil {
log.Debug("ProxyName [%s], write to work connection for udp error: %v", p.Name, err)
return
}
}
}()
}
// When frps get one user connection, we get one work connection from the pool and return it. // When frps get one user connection, we get one work connection from the pool and return it.
// If no workConn available in the pool, send message to frpc to get one or more // If no workConn available in the pool, send message to frpc to get one or more
// and wait until it is available. // and wait until it is available.

View File

@ -202,12 +202,12 @@ func (c *Conn) ReadLine() (buff string, err error) {
return buff, err return buff, err
} }
func (c *Conn) WriteBytes(content []byte) (n int, err error) { func (c *Conn) Write(content []byte) (n int, err error) {
n, err = c.TcpConn.Write(content) n, err = c.TcpConn.Write(content)
return return
} }
func (c *Conn) Write(content string) (err error) { func (c *Conn) WriteString(content string) (err error) {
_, err = c.TcpConn.Write([]byte(content)) _, err = c.TcpConn.Write([]byte(content))
return err return err
} }
@ -220,13 +220,14 @@ func (c *Conn) SetReadDeadline(t time.Time) error {
return c.TcpConn.SetReadDeadline(t) return c.TcpConn.SetReadDeadline(t)
} }
func (c *Conn) Close() { func (c *Conn) Close() error {
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock()
if c.TcpConn != nil && c.closeFlag == false { if c.TcpConn != nil && c.closeFlag == false {
c.closeFlag = true c.closeFlag = true
c.TcpConn.Close() c.TcpConn.Close()
} }
c.mutex.Unlock() return nil
} }
func (c *Conn) IsClosed() (closeFlag bool) { func (c *Conn) IsClosed() (closeFlag bool) {
@ -245,7 +246,6 @@ func (c *Conn) CheckClosed() bool {
} }
c.mutex.RUnlock() c.mutex.RUnlock()
// err := c.TcpConn.SetReadDeadline(time.Now().Add(100 * time.Microsecond))
err := c.TcpConn.SetReadDeadline(time.Now().Add(time.Millisecond)) err := c.TcpConn.SetReadDeadline(time.Now().Add(time.Millisecond))
if err != nil { if err != nil {
c.Close() c.Close()

View File

@ -0,0 +1,29 @@
// Copyright 2016 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package conn
import (
"fmt"
"net"
)
func ListenUDP(bindAddr string, bindPort int64) (conn *net.UDPConn, err error) {
udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
if err != nil {
return conn, err
}
conn, err = net.ListenUDP("udp", udpAddr)
return
}

View File

@ -40,6 +40,6 @@ func echoWorker(c *conn.Conn) {
return return
} }
c.Write(buff) c.WriteString(buff)
} }
} }

View File

@ -26,7 +26,7 @@ func TestEchoServer(t *testing.T) {
timer := time.Now().Add(time.Duration(5) * time.Second) timer := time.Now().Add(time.Duration(5) * time.Second)
c.SetDeadline(timer) c.SetDeadline(timer)
c.Write(ECHO_TEST_STR) c.WriteString(ECHO_TEST_STR)
buff, err := c.ReadLine() buff, err := c.ReadLine()
if err != nil { if err != nil {