new feature: assign a random port if remote_port is 0 in type tcp and

udp
This commit is contained in:
fatedier 2018-01-17 21:49:37 +08:00
parent 3f6799c06a
commit b2c846664d
21 changed files with 379 additions and 238 deletions

View File

@ -3,7 +3,7 @@ language: go
go:
- 1.8.x
- 1.x
- 1.9.x
install:
- make

View File

@ -31,7 +31,7 @@ var (
httpServerWriteTimeout = 10 * time.Second
)
func (svr *Service) RunAdminServer(addr string, port int64) (err error) {
func (svr *Service) RunAdminServer(addr string, port int) (err error) {
// url router
router := httprouter.New()

View File

@ -124,12 +124,20 @@ func NewProxyStatusResp(status *ProxyStatus) ProxyStatusResp {
psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort)
}
psr.Plugin = cfg.Plugin
psr.RemoteAddr = config.ClientCommonCfg.ServerAddr + status.RemoteAddr
if status.Err != "" {
psr.RemoteAddr = fmt.Sprintf("%s:%d", config.ClientCommonCfg.ServerAddr, cfg.RemotePort)
} else {
psr.RemoteAddr = config.ClientCommonCfg.ServerAddr + status.RemoteAddr
}
case *config.UdpProxyConf:
if cfg.LocalPort != 0 {
psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort)
}
psr.RemoteAddr = config.ClientCommonCfg.ServerAddr + status.RemoteAddr
if status.Err != "" {
psr.RemoteAddr = fmt.Sprintf("%s:%d", config.ClientCommonCfg.ServerAddr, cfg.RemotePort)
} else {
psr.RemoteAddr = config.ClientCommonCfg.ServerAddr + status.RemoteAddr
}
case *config.HttpProxyConf:
if cfg.LocalPort != 0 {
psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort)

View File

@ -77,7 +77,7 @@ type StcpVisitor struct {
}
func (sv *StcpVisitor) Run() (err error) {
sv.l, err = frpNet.ListenTcp(sv.cfg.BindAddr, int64(sv.cfg.BindPort))
sv.l, err = frpNet.ListenTcp(sv.cfg.BindAddr, sv.cfg.BindPort)
if err != nil {
return
}
@ -164,7 +164,7 @@ type XtcpVisitor struct {
}
func (sv *XtcpVisitor) Run() (err error) {
sv.l, err = frpNet.ListenTcp(sv.cfg.BindAddr, int64(sv.cfg.BindPort))
sv.l, err = frpNet.ListenTcp(sv.cfg.BindAddr, sv.cfg.BindPort)
if err != nil {
return
}
@ -255,7 +255,7 @@ func (sv *XtcpVisitor) handleConn(userConn frpNet.Conn) {
sv.Error("get natHoleResp client address error: %s", natHoleRespMsg.ClientAddr)
return
}
sv.sendDetectMsg(array[0], int64(port), laddr, []byte(natHoleRespMsg.Sid))
sv.sendDetectMsg(array[0], int(port), laddr, []byte(natHoleRespMsg.Sid))
sv.Trace("send all detect msg done")
// Listen for visitorConn's address and wait for client connection.
@ -302,7 +302,7 @@ func (sv *XtcpVisitor) handleConn(userConn frpNet.Conn) {
sv.Debug("join connections closed")
}
func (sv *XtcpVisitor) sendDetectMsg(addr string, port int64, laddr *net.UDPAddr, content []byte) (err error) {
func (sv *XtcpVisitor) sendDetectMsg(addr string, port int, laddr *net.UDPAddr, content []byte) (err error) {
daddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addr, port))
if err != nil {
return err

View File

@ -99,7 +99,7 @@ func main() {
if args["status"] != nil {
if args["status"].(bool) {
if err = CmdStatus(); err != nil {
fmt.Println("frps get status error: %v\n", err)
fmt.Printf("frps get status error: %v\n", err)
os.Exit(1)
} else {
os.Exit(0)
@ -132,7 +132,7 @@ func main() {
os.Exit(1)
}
config.ClientCommonCfg.ServerAddr = addr[0]
config.ClientCommonCfg.ServerPort = serverPort
config.ClientCommonCfg.ServerPort = int(serverPort)
}
if args["-v"] != nil {

View File

@ -91,7 +91,7 @@ func main() {
os.Exit(1)
}
config.ServerCommonCfg.BindAddr = addr[0]
config.ServerCommonCfg.BindPort = bindPort
config.ServerCommonCfg.BindPort = int(bindPort)
}
if args["-v"] != nil {

View File

@ -29,8 +29,8 @@ var ClientCommonCfg *ClientCommonConf
type ClientCommonConf struct {
ConfigFile string
ServerAddr string
ServerPort int64
ServerUdpPort int64 // this is specified by login response message from frps
ServerPort int
ServerUdpPort int // this is specified by login response message from frps
HttpProxy string
LogFile string
LogWay string
@ -38,7 +38,7 @@ type ClientCommonConf struct {
LogMaxDays int64
PrivilegeToken string
AdminAddr string
AdminPort int64
AdminPort int
AdminUser string
AdminPwd string
PoolCount int
@ -93,7 +93,12 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) {
tmpStr, ok = conf.Get("common", "server_port")
if ok {
cfg.ServerPort, _ = strconv.ParseInt(tmpStr, 10, 64)
v, err = strconv.ParseInt(tmpStr, 10, 64)
if err != nil {
err = fmt.Errorf("Parse conf error: invalid server_port")
return
}
cfg.ServerPort = int(v)
}
tmpStr, ok = conf.Get("common", "http_proxy")
@ -139,7 +144,10 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) {
tmpStr, ok = conf.Get("common", "admin_port")
if ok {
if v, err = strconv.ParseInt(tmpStr, 10, 64); err == nil {
cfg.AdminPort = v
cfg.AdminPort = int(v)
} else {
err = fmt.Errorf("Parse conf error: invalid admin_port")
return
}
}
@ -203,7 +211,7 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) {
if ok {
v, err = strconv.ParseInt(tmpStr, 10, 64)
if err != nil {
err = fmt.Errorf("Parse conf error: heartbeat_timeout is incorrect")
err = fmt.Errorf("Parse conf error: invalid heartbeat_timeout")
return
} else {
cfg.HeartBeatTimeout = v
@ -214,7 +222,7 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) {
if ok {
v, err = strconv.ParseInt(tmpStr, 10, 64)
if err != nil {
err = fmt.Errorf("Parse conf error: heartbeat_interval is incorrect")
err = fmt.Errorf("Parse conf error: invalid heartbeat_interval")
return
} else {
cfg.HeartBeatInterval = v
@ -222,12 +230,12 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) {
}
if cfg.HeartBeatInterval <= 0 {
err = fmt.Errorf("Parse conf error: heartbeat_interval is incorrect")
err = fmt.Errorf("Parse conf error: invalid heartbeat_interval")
return
}
if cfg.HeartBeatTimeout < cfg.HeartBeatInterval {
err = fmt.Errorf("Parse conf error: heartbeat_timeout is incorrect, heartbeat_timeout is less than heartbeat_interval")
err = fmt.Errorf("Parse conf error: invalid heartbeat_timeout, heartbeat_timeout is less than heartbeat_interval")
return
}
return

View File

@ -23,7 +23,6 @@ import (
"github.com/fatedier/frp/models/consts"
"github.com/fatedier/frp/models/msg"
"github.com/fatedier/frp/utils/util"
ini "github.com/vaughan0/go-ini"
)
@ -163,7 +162,7 @@ func (cfg *BaseProxyConf) UnMarshalToMsg(pMsg *msg.NewProxy) {
// Bind info
type BindInfoConf struct {
BindAddr string `json:"bind_addr"`
RemotePort int64 `json:"remote_port"`
RemotePort int `json:"remote_port"`
}
func (cfg *BindInfoConf) compare(cmp *BindInfoConf) bool {
@ -183,10 +182,13 @@ func (cfg *BindInfoConf) LoadFromFile(name string, section ini.Section) (err err
var (
tmpStr string
ok bool
v int64
)
if tmpStr, ok = section["remote_port"]; ok {
if cfg.RemotePort, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
return fmt.Errorf("Parse conf error: proxy [%s] remote_port error", name)
} else {
cfg.RemotePort = int(v)
}
} else {
return fmt.Errorf("Parse conf error: proxy [%s] remote_port not found", name)
@ -199,11 +201,6 @@ func (cfg *BindInfoConf) UnMarshalToMsg(pMsg *msg.NewProxy) {
}
func (cfg *BindInfoConf) check() (err error) {
if len(ServerCommonCfg.PrivilegeAllowPorts) != 0 {
if ok := util.ContainsPort(ServerCommonCfg.PrivilegeAllowPorts, cfg.RemotePort); !ok {
return fmt.Errorf("remote port [%d] isn't allowed", cfg.RemotePort)
}
}
return nil
}

View File

@ -19,7 +19,6 @@ import (
"strconv"
"strings"
"github.com/fatedier/frp/utils/util"
ini "github.com/vaughan0/go-ini"
)
@ -29,20 +28,20 @@ var ServerCommonCfg *ServerCommonConf
type ServerCommonConf struct {
ConfigFile string
BindAddr string
BindPort int64
BindUdpPort int64
KcpBindPort int64
BindPort int
BindUdpPort int
KcpBindPort int
ProxyBindAddr string
// If VhostHttpPort equals 0, don't listen a public port for http protocol.
VhostHttpPort int64
VhostHttpPort int
// if VhostHttpsPort equals 0, don't listen a public port for https protocol
VhostHttpsPort int64
VhostHttpsPort int
DashboardAddr string
// if DashboardPort equals 0, dashboard is not available
DashboardPort int64
DashboardPort int
DashboardUser string
DashboardPwd string
AssetsDir string
@ -56,8 +55,7 @@ type ServerCommonConf struct {
SubDomainHost string
TcpMux bool
// if PrivilegeAllowPorts is not nil, tcp proxies which remote port exist in this map can be connected
PrivilegeAllowPorts [][2]int64
PrivilegeAllowPorts map[int]struct{}
MaxPoolCount int64
HeartBeatTimeout int64
UserConnTimeout int64
@ -65,31 +63,32 @@ type ServerCommonConf struct {
func GetDefaultServerCommonConf() *ServerCommonConf {
return &ServerCommonConf{
ConfigFile: "./frps.ini",
BindAddr: "0.0.0.0",
BindPort: 7000,
BindUdpPort: 0,
KcpBindPort: 0,
ProxyBindAddr: "0.0.0.0",
VhostHttpPort: 0,
VhostHttpsPort: 0,
DashboardAddr: "0.0.0.0",
DashboardPort: 0,
DashboardUser: "admin",
DashboardPwd: "admin",
AssetsDir: "",
LogFile: "console",
LogWay: "console",
LogLevel: "info",
LogMaxDays: 3,
PrivilegeMode: true,
PrivilegeToken: "",
AuthTimeout: 900,
SubDomainHost: "",
TcpMux: true,
MaxPoolCount: 5,
HeartBeatTimeout: 90,
UserConnTimeout: 10,
ConfigFile: "./frps.ini",
BindAddr: "0.0.0.0",
BindPort: 7000,
BindUdpPort: 0,
KcpBindPort: 0,
ProxyBindAddr: "0.0.0.0",
VhostHttpPort: 0,
VhostHttpsPort: 0,
DashboardAddr: "0.0.0.0",
DashboardPort: 0,
DashboardUser: "admin",
DashboardPwd: "admin",
AssetsDir: "",
LogFile: "console",
LogWay: "console",
LogLevel: "info",
LogMaxDays: 3,
PrivilegeMode: true,
PrivilegeToken: "",
AuthTimeout: 900,
SubDomainHost: "",
TcpMux: true,
PrivilegeAllowPorts: make(map[int]struct{}),
MaxPoolCount: 5,
HeartBeatTimeout: 90,
UserConnTimeout: 10,
}
}
@ -109,25 +108,31 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) {
tmpStr, ok = conf.Get("common", "bind_port")
if ok {
v, err = strconv.ParseInt(tmpStr, 10, 64)
if err == nil {
cfg.BindPort = v
if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
err = fmt.Errorf("Parse conf error: invalid bind_port")
return
} else {
cfg.BindPort = int(v)
}
}
tmpStr, ok = conf.Get("common", "bind_udp_port")
if ok {
v, err = strconv.ParseInt(tmpStr, 10, 64)
if err == nil {
cfg.BindUdpPort = v
if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
err = fmt.Errorf("Parse conf error: invalid bind_udp_port")
return
} else {
cfg.BindUdpPort = int(v)
}
}
tmpStr, ok = conf.Get("common", "kcp_bind_port")
if ok {
v, err = strconv.ParseInt(tmpStr, 10, 64)
if err == nil && v > 0 {
cfg.KcpBindPort = v
if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
err = fmt.Errorf("Parse conf error: invalid kcp_bind_port")
return
} else {
cfg.KcpBindPort = int(v)
}
}
@ -140,10 +145,11 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) {
tmpStr, ok = conf.Get("common", "vhost_http_port")
if ok {
cfg.VhostHttpPort, err = strconv.ParseInt(tmpStr, 10, 64)
if err != nil {
err = fmt.Errorf("Parse conf error: vhost_http_port is incorrect")
if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
err = fmt.Errorf("Parse conf error: invalid vhost_http_port")
return
} else {
cfg.VhostHttpPort = int(v)
}
} else {
cfg.VhostHttpPort = 0
@ -151,10 +157,11 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) {
tmpStr, ok = conf.Get("common", "vhost_https_port")
if ok {
cfg.VhostHttpsPort, err = strconv.ParseInt(tmpStr, 10, 64)
if err != nil {
err = fmt.Errorf("Parse conf error: vhost_https_port is incorrect")
if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
err = fmt.Errorf("Parse conf error: invalid vhost_https_port")
return
} else {
cfg.VhostHttpsPort = int(v)
}
} else {
cfg.VhostHttpsPort = 0
@ -169,10 +176,11 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) {
tmpStr, ok = conf.Get("common", "dashboard_port")
if ok {
cfg.DashboardPort, err = strconv.ParseInt(tmpStr, 10, 64)
if err != nil {
err = fmt.Errorf("Parse conf error: dashboard_port is incorrect")
if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
err = fmt.Errorf("Parse conf error: invalid dashboard_port")
return
} else {
cfg.DashboardPort = int(v)
}
} else {
cfg.DashboardPort = 0
@ -228,12 +236,45 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) {
cfg.PrivilegeToken, _ = conf.Get("common", "privilege_token")
allowPortsStr, ok := conf.Get("common", "privilege_allow_ports")
// TODO: check if conflicts exist in port ranges
if ok {
cfg.PrivilegeAllowPorts, err = util.GetPortRanges(allowPortsStr)
if err != nil {
err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", err)
return
// e.g. 1000-2000,2001,2002,3000-4000
portRanges := strings.Split(allowPortsStr, ",")
for _, portRangeStr := range portRanges {
// 1000-2000 or 2001
portArray := strings.Split(portRangeStr, "-")
// length: only 1 or 2 is correct
rangeType := len(portArray)
if rangeType == 1 {
// single port
singlePort, errRet := strconv.ParseInt(portArray[0], 10, 64)
if errRet != nil {
err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet)
return
}
cfg.PrivilegeAllowPorts[int(singlePort)] = struct{}{}
} else if rangeType == 2 {
// range ports
min, errRet := strconv.ParseInt(portArray[0], 10, 64)
if errRet != nil {
err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet)
return
}
max, errRet := strconv.ParseInt(portArray[1], 10, 64)
if errRet != nil {
err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet)
return
}
if max < min {
err = fmt.Errorf("Parse conf error: privilege_allow_ports range incorrect")
return
}
for i := min; i <= max; i++ {
cfg.PrivilegeAllowPorts[int(i)] = struct{}{}
}
} else {
err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect")
return
}
}
}
}

View File

@ -92,7 +92,7 @@ type Login struct {
type LoginResp struct {
Version string `json:"version"`
RunId string `json:"run_id"`
ServerUdpPort int64 `json:"server_udp_port"`
ServerUdpPort int `json:"server_udp_port"`
Error string `json:"error"`
}
@ -104,7 +104,7 @@ type NewProxy struct {
UseCompression bool `json:"use_compression"`
// tcp and udp only
RemotePort int64 `json:"remote_port"`
RemotePort int `json:"remote_port"`
// http and https only
CustomDomains []string `json:"custom_domains"`

View File

@ -32,7 +32,7 @@ var (
httpServerWriteTimeout = 10 * time.Second
)
func RunDashboardServer(addr string, port int64) (err error) {
func RunDashboardServer(addr string, port int) (err error) {
// url router
router := httprouter.New()

View File

@ -36,8 +36,8 @@ type ServerInfoResp struct {
GeneralResponse
Version string `json:"version"`
VhostHttpPort int64 `json:"vhost_http_port"`
VhostHttpsPort int64 `json:"vhost_https_port"`
VhostHttpPort int `json:"vhost_http_port"`
VhostHttpsPort int `json:"vhost_https_port"`
AuthTimeout int64 `json:"auth_timeout"`
SubdomainHost string `json:"subdomain_host"`
MaxPoolCount int64 `json:"max_pool_count"`

180
server/ports.go Normal file
View File

@ -0,0 +1,180 @@
package server
import (
"errors"
"fmt"
"net"
"sync"
"time"
)
const (
MinPort = 1025
MaxPort = 65535
MaxPortReservedDuration = time.Duration(24) * time.Hour
CleanReservedPortsInterval = time.Hour
)
var (
ErrPortAlreadyUsed = errors.New("port already used")
ErrPortNotAllowed = errors.New("port not allowed")
ErrPortUnAvailable = errors.New("port unavailable")
ErrNoAvailablePort = errors.New("no available port")
)
type PortCtx struct {
ProxyName string
Port int
Closed bool
UpdateTime time.Time
}
type PortManager struct {
reservedPorts map[string]*PortCtx
usedPorts map[int]*PortCtx
freePorts map[int]struct{}
bindAddr string
netType string
mu sync.Mutex
}
func NewPortManager(netType string, bindAddr string, allowPorts map[int]struct{}) *PortManager {
pm := &PortManager{
reservedPorts: make(map[string]*PortCtx),
usedPorts: make(map[int]*PortCtx),
freePorts: make(map[int]struct{}),
bindAddr: bindAddr,
netType: netType,
}
if len(allowPorts) > 0 {
for port, _ := range allowPorts {
pm.freePorts[port] = struct{}{}
}
} else {
for i := MinPort; i <= MaxPort; i++ {
pm.freePorts[i] = struct{}{}
}
}
go pm.cleanReservedPortsWorker()
return pm
}
func (pm *PortManager) Acquire(name string, port int) (realPort int, err error) {
portCtx := &PortCtx{
ProxyName: name,
Closed: false,
UpdateTime: time.Now(),
}
var ok bool
pm.mu.Lock()
defer func() {
if err == nil {
portCtx.Port = realPort
}
pm.mu.Unlock()
}()
// check reserved ports first
if port == 0 {
if ctx, ok := pm.reservedPorts[name]; ok {
if pm.isPortAvailable(ctx.Port) {
realPort = ctx.Port
pm.usedPorts[realPort] = portCtx
pm.reservedPorts[name] = portCtx
delete(pm.freePorts, realPort)
return
}
}
}
if port == 0 {
// get random port
count := 0
maxTryTimes := 5
for k, _ := range pm.freePorts {
count++
if count > maxTryTimes {
break
}
if pm.isPortAvailable(k) {
realPort = k
pm.usedPorts[realPort] = portCtx
pm.reservedPorts[name] = portCtx
delete(pm.freePorts, realPort)
break
}
}
if realPort == 0 {
err = ErrNoAvailablePort
}
} else {
// specified port
if _, ok = pm.freePorts[port]; ok {
if pm.isPortAvailable(port) {
realPort = port
pm.usedPorts[realPort] = portCtx
pm.reservedPorts[name] = portCtx
delete(pm.freePorts, realPort)
} else {
err = ErrPortUnAvailable
}
} else {
if _, ok = pm.usedPorts[port]; ok {
err = ErrPortAlreadyUsed
} else {
err = ErrPortNotAllowed
}
}
}
return
}
func (pm *PortManager) isPortAvailable(port int) bool {
if pm.netType == "udp" {
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pm.bindAddr, port))
if err != nil {
return false
}
l, err := net.ListenUDP("udp", addr)
if err != nil {
return false
}
l.Close()
return true
} else {
l, err := net.Listen(pm.netType, fmt.Sprintf("%s:%d", pm.bindAddr, port))
if err != nil {
return false
}
l.Close()
return true
}
}
func (pm *PortManager) Release(port int) {
pm.mu.Lock()
defer pm.mu.Unlock()
if ctx, ok := pm.usedPorts[port]; ok {
pm.freePorts[port] = struct{}{}
delete(pm.usedPorts, port)
ctx.Closed = true
ctx.UpdateTime = time.Now()
}
}
// Release reserved port if it isn't used in last 24 hours.
func (pm *PortManager) cleanReservedPortsWorker() {
for {
time.Sleep(CleanReservedPortsInterval)
pm.mu.Lock()
for name, ctx := range pm.reservedPorts {
if ctx.Closed && time.Since(ctx.UpdateTime) > MaxPortReservedDuration {
delete(pm.reservedPorts, name)
}
}
pm.mu.Unlock()
}
}

View File

@ -165,11 +165,24 @@ func NewProxy(ctl *Control, pxyConf config.ProxyConf) (pxy Proxy, err error) {
type TcpProxy struct {
BaseProxy
cfg *config.TcpProxyConf
realPort int
}
func (pxy *TcpProxy) Run() (remoteAddr string, err error) {
remoteAddr = fmt.Sprintf(":%d", pxy.cfg.RemotePort)
listener, errRet := frpNet.ListenTcp(config.ServerCommonCfg.ProxyBindAddr, pxy.cfg.RemotePort)
pxy.realPort, err = pxy.ctl.svr.tcpPortManager.Acquire(pxy.name, pxy.cfg.RemotePort)
if err != nil {
return
}
defer func() {
if err != nil {
pxy.ctl.svr.tcpPortManager.Release(pxy.realPort)
}
}()
remoteAddr = fmt.Sprintf(":%d", pxy.realPort)
pxy.cfg.RemotePort = pxy.realPort
listener, errRet := frpNet.ListenTcp(config.ServerCommonCfg.ProxyBindAddr, pxy.realPort)
if errRet != nil {
err = errRet
return
@ -188,6 +201,7 @@ func (pxy *TcpProxy) GetConf() config.ProxyConf {
func (pxy *TcpProxy) Close() {
pxy.BaseProxy.Close()
pxy.ctl.svr.tcpPortManager.Release(pxy.realPort)
}
type HttpProxy struct {
@ -412,6 +426,8 @@ type UdpProxy struct {
BaseProxy
cfg *config.UdpProxyConf
realPort int
// udpConn is the listener of udp packages
udpConn *net.UDPConn
@ -432,8 +448,19 @@ type UdpProxy struct {
}
func (pxy *UdpProxy) Run() (remoteAddr string, err error) {
remoteAddr = fmt.Sprintf(":%d", pxy.cfg.RemotePort)
addr, errRet := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", config.ServerCommonCfg.ProxyBindAddr, pxy.cfg.RemotePort))
pxy.realPort, err = pxy.ctl.svr.udpPortManager.Acquire(pxy.name, pxy.cfg.RemotePort)
if err != nil {
return
}
defer func() {
if err != nil {
pxy.ctl.svr.udpPortManager.Release(pxy.realPort)
}
}()
remoteAddr = fmt.Sprintf(":%d", pxy.realPort)
pxy.cfg.RemotePort = pxy.realPort
addr, errRet := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", config.ServerCommonCfg.ProxyBindAddr, pxy.realPort))
if errRet != nil {
err = errRet
return
@ -581,6 +608,7 @@ func (pxy *UdpProxy) Close() {
close(pxy.readCh)
close(pxy.sendCh)
}
pxy.ctl.svr.udpPortManager.Release(pxy.realPort)
}
// HandleUserTcpConnection is used for incoming tcp user connections.

View File

@ -60,17 +60,25 @@ type Service struct {
// Manage all visitor listeners.
visitorManager *VisitorManager
// Manage all tcp ports.
tcpPortManager *PortManager
// Manage all udp ports.
udpPortManager *PortManager
// Controller for nat hole connections.
natHoleController *NatHoleController
}
func NewService() (svr *Service, err error) {
cfg := config.ServerCommonCfg
svr = &Service{
ctlManager: NewControlManager(),
pxyManager: NewProxyManager(),
visitorManager: NewVisitorManager(),
tcpPortManager: NewPortManager("tcp", cfg.ProxyBindAddr, cfg.PrivilegeAllowPorts),
udpPortManager: NewPortManager("udp", cfg.ProxyBindAddr, cfg.PrivilegeAllowPorts),
}
cfg := config.ServerCommonCfg
// Init assets.
err = assets.Load(cfg.AssetsDir)

View File

@ -10,28 +10,28 @@ import (
var (
TEST_STR = "frp is a fast reverse proxy to help you expose a local server behind a NAT or firewall to the internet."
TEST_TCP_PORT int64 = 10701
TEST_TCP_FRP_PORT int64 = 10801
TEST_TCP_EC_FRP_PORT int64 = 10901
TEST_TCP_PORT int = 10701
TEST_TCP_FRP_PORT int = 10801
TEST_TCP_EC_FRP_PORT int = 10901
TEST_TCP_ECHO_STR string = "tcp type:" + TEST_STR
TEST_UDP_PORT int64 = 10702
TEST_UDP_FRP_PORT int64 = 10802
TEST_UDP_EC_FRP_PORT int64 = 10902
TEST_UDP_PORT int = 10702
TEST_UDP_FRP_PORT int = 10802
TEST_UDP_EC_FRP_PORT int = 10902
TEST_UDP_ECHO_STR string = "udp type:" + TEST_STR
TEST_UNIX_DOMAIN_ADDR string = "/tmp/frp_echo_server.sock"
TEST_UNIX_DOMAIN_FRP_PORT int64 = 10803
TEST_UNIX_DOMAIN_FRP_PORT int = 10803
TEST_UNIX_DOMAIN_STR string = "unix domain type:" + TEST_STR
TEST_HTTP_PORT int64 = 10704
TEST_HTTP_FRP_PORT int64 = 10804
TEST_HTTP_PORT int = 10704
TEST_HTTP_FRP_PORT int = 10804
TEST_HTTP_NORMAL_STR string = "http normal string: " + TEST_STR
TEST_HTTP_FOO_STR string = "http foo string: " + TEST_STR
TEST_HTTP_BAR_STR string = "http bar string: " + TEST_STR
TEST_STCP_FRP_PORT int64 = 10805
TEST_STCP_EC_FRP_PORT int64 = 10905
TEST_STCP_FRP_PORT int = 10805
TEST_STCP_EC_FRP_PORT int = 10905
TEST_STCP_ECHO_STR string = "stcp type:" + TEST_STR
)

View File

@ -31,7 +31,7 @@ type KcpListener struct {
log.Logger
}
func ListenKcp(bindAddr string, bindPort int64) (l *KcpListener, err error) {
func ListenKcp(bindAddr string, bindPort int) (l *KcpListener, err error) {
listener, err := kcp.ListenWithOptions(fmt.Sprintf("%s:%d", bindAddr, bindPort), nil, 10, 3)
if err != nil {
return l, err

View File

@ -33,7 +33,7 @@ type TcpListener struct {
log.Logger
}
func ListenTcp(bindAddr string, bindPort int64) (l *TcpListener, err error) {
func ListenTcp(bindAddr string, bindPort int) (l *TcpListener, err error) {
tcpAddr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
if err != nil {
return l, err

View File

@ -167,7 +167,7 @@ type UdpListener struct {
log.Logger
}
func ListenUDP(bindAddr string, bindPort int64) (l *UdpListener, err error) {
func ListenUDP(bindAddr string, bindPort int) (l *UdpListener, err error) {
udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
if err != nil {
return l, err

View File

@ -19,8 +19,6 @@ import (
"crypto/rand"
"encoding/hex"
"fmt"
"strconv"
"strings"
)
// RandId return a rand string used in frp.
@ -48,69 +46,6 @@ func GetAuthKey(token string, timestamp int64) (key string) {
return hex.EncodeToString(data)
}
// for example: rangeStr is "1000-2000,2001,2002,3000-4000", return an array as port ranges.
func GetPortRanges(rangeStr string) (portRanges [][2]int64, err error) {
// for example: 1000-2000,2001,2002,3000-4000
rangeArray := strings.Split(rangeStr, ",")
for _, portRangeStr := range rangeArray {
// 1000-2000 or 2001
portArray := strings.Split(portRangeStr, "-")
// length: only 1 or 2 is correct
rangeType := len(portArray)
if rangeType == 1 {
singlePort, err := strconv.ParseInt(portArray[0], 10, 64)
if err != nil {
return [][2]int64{}, err
}
portRanges = append(portRanges, [2]int64{singlePort, singlePort})
} else if rangeType == 2 {
min, err := strconv.ParseInt(portArray[0], 10, 64)
if err != nil {
return [][2]int64{}, err
}
max, err := strconv.ParseInt(portArray[1], 10, 64)
if err != nil {
return [][2]int64{}, err
}
if max < min {
return [][2]int64{}, fmt.Errorf("range incorrect")
}
portRanges = append(portRanges, [2]int64{min, max})
} else {
return [][2]int64{}, fmt.Errorf("format error")
}
}
return portRanges, nil
}
func ContainsPort(portRanges [][2]int64, port int64) bool {
for _, pr := range portRanges {
if port >= pr[0] && port <= pr[1] {
return true
}
}
return false
}
func PortRangesCut(portRanges [][2]int64, port int64) [][2]int64 {
var tmpRanges [][2]int64
for _, pr := range portRanges {
if port >= pr[0] && port <= pr[1] {
leftRange := [2]int64{pr[0], port - 1}
rightRange := [2]int64{port + 1, pr[1]}
if leftRange[0] <= leftRange[1] {
tmpRanges = append(tmpRanges, leftRange)
}
if rightRange[0] <= rightRange[1] {
tmpRanges = append(tmpRanges, rightRange)
}
} else {
tmpRanges = append(tmpRanges, pr)
}
}
return tmpRanges
}
func CanonicalAddr(host string, port int) (addr string) {
if port == 80 || port == 443 {
addr = host

View File

@ -20,67 +20,3 @@ func TestGetAuthKey(t *testing.T) {
t.Log(key)
assert.Equal("6df41a43725f0c770fd56379e12acf8c", key)
}
func TestGetPortRanges(t *testing.T) {
assert := assert.New(t)
rangesStr := "2000-3000,3001,4000-50000"
expect := [][2]int64{
[2]int64{2000, 3000},
[2]int64{3001, 3001},
[2]int64{4000, 50000},
}
actual, err := GetPortRanges(rangesStr)
assert.Nil(err)
t.Log(actual)
assert.Equal(expect, actual)
}
func TestContainsPort(t *testing.T) {
assert := assert.New(t)
rangesStr := "2000-3000,3001,4000-50000"
portRanges, err := GetPortRanges(rangesStr)
assert.Nil(err)
type Case struct {
Port int64
Answer bool
}
cases := []Case{
Case{
Port: 3001,
Answer: true,
},
Case{
Port: 3002,
Answer: false,
},
Case{
Port: 44444,
Answer: true,
},
}
for _, elem := range cases {
ok := ContainsPort(portRanges, elem.Port)
assert.Equal(elem.Answer, ok)
}
}
func TestPortRangesCut(t *testing.T) {
assert := assert.New(t)
rangesStr := "2000-3000,3001,4000-50000"
portRanges, err := GetPortRanges(rangesStr)
assert.Nil(err)
expect := [][2]int64{
[2]int64{2000, 3000},
[2]int64{3001, 3001},
[2]int64{4000, 44443},
[2]int64{44445, 50000},
}
actual := PortRangesCut(portRanges, 44444)
t.Log(actual)
assert.Equal(expect, actual)
}