From b2c846664d0a627aa16f164bf7b5d1e19212b225 Mon Sep 17 00:00:00 2001 From: fatedier Date: Wed, 17 Jan 2018 21:49:37 +0800 Subject: [PATCH] new feature: assign a random port if remote_port is 0 in type tcp and udp --- .travis.yml | 2 +- client/admin.go | 2 +- client/admin_api.go | 12 ++- client/visitor.go | 8 +- cmd/frpc/main.go | 4 +- cmd/frps/main.go | 2 +- models/config/client_common.go | 26 +++-- models/config/proxy.go | 13 +-- models/config/server_common.go | 155 +++++++++++++++++----------- models/msg/msg.go | 4 +- server/dashboard.go | 2 +- server/dashboard_api.go | 4 +- server/ports.go | 180 +++++++++++++++++++++++++++++++++ server/proxy.go | 36 ++++++- server/service.go | 10 +- tests/func_test.go | 22 ++-- utils/net/kcp.go | 2 +- utils/net/tcp.go | 2 +- utils/net/udp.go | 2 +- utils/util/util.go | 65 ------------ utils/util/util_test.go | 64 ------------ 21 files changed, 379 insertions(+), 238 deletions(-) create mode 100644 server/ports.go diff --git a/.travis.yml b/.travis.yml index 303e1a21..51c2421c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,7 +3,7 @@ language: go go: - 1.8.x - - 1.x + - 1.9.x install: - make diff --git a/client/admin.go b/client/admin.go index 37cdf4c1..e34f44d2 100644 --- a/client/admin.go +++ b/client/admin.go @@ -31,7 +31,7 @@ var ( httpServerWriteTimeout = 10 * time.Second ) -func (svr *Service) RunAdminServer(addr string, port int64) (err error) { +func (svr *Service) RunAdminServer(addr string, port int) (err error) { // url router router := httprouter.New() diff --git a/client/admin_api.go b/client/admin_api.go index 3c64b917..1a947521 100644 --- a/client/admin_api.go +++ b/client/admin_api.go @@ -124,12 +124,20 @@ func NewProxyStatusResp(status *ProxyStatus) ProxyStatusResp { psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort) } psr.Plugin = cfg.Plugin - psr.RemoteAddr = config.ClientCommonCfg.ServerAddr + status.RemoteAddr + if status.Err != "" { + psr.RemoteAddr = fmt.Sprintf("%s:%d", config.ClientCommonCfg.ServerAddr, cfg.RemotePort) + } else { + psr.RemoteAddr = config.ClientCommonCfg.ServerAddr + status.RemoteAddr + } case *config.UdpProxyConf: if cfg.LocalPort != 0 { psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort) } - psr.RemoteAddr = config.ClientCommonCfg.ServerAddr + status.RemoteAddr + if status.Err != "" { + psr.RemoteAddr = fmt.Sprintf("%s:%d", config.ClientCommonCfg.ServerAddr, cfg.RemotePort) + } else { + psr.RemoteAddr = config.ClientCommonCfg.ServerAddr + status.RemoteAddr + } case *config.HttpProxyConf: if cfg.LocalPort != 0 { psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort) diff --git a/client/visitor.go b/client/visitor.go index e7a22d80..fd182255 100644 --- a/client/visitor.go +++ b/client/visitor.go @@ -77,7 +77,7 @@ type StcpVisitor struct { } func (sv *StcpVisitor) Run() (err error) { - sv.l, err = frpNet.ListenTcp(sv.cfg.BindAddr, int64(sv.cfg.BindPort)) + sv.l, err = frpNet.ListenTcp(sv.cfg.BindAddr, sv.cfg.BindPort) if err != nil { return } @@ -164,7 +164,7 @@ type XtcpVisitor struct { } func (sv *XtcpVisitor) Run() (err error) { - sv.l, err = frpNet.ListenTcp(sv.cfg.BindAddr, int64(sv.cfg.BindPort)) + sv.l, err = frpNet.ListenTcp(sv.cfg.BindAddr, sv.cfg.BindPort) if err != nil { return } @@ -255,7 +255,7 @@ func (sv *XtcpVisitor) handleConn(userConn frpNet.Conn) { sv.Error("get natHoleResp client address error: %s", natHoleRespMsg.ClientAddr) return } - sv.sendDetectMsg(array[0], int64(port), laddr, []byte(natHoleRespMsg.Sid)) + sv.sendDetectMsg(array[0], int(port), laddr, []byte(natHoleRespMsg.Sid)) sv.Trace("send all detect msg done") // Listen for visitorConn's address and wait for client connection. @@ -302,7 +302,7 @@ func (sv *XtcpVisitor) handleConn(userConn frpNet.Conn) { sv.Debug("join connections closed") } -func (sv *XtcpVisitor) sendDetectMsg(addr string, port int64, laddr *net.UDPAddr, content []byte) (err error) { +func (sv *XtcpVisitor) sendDetectMsg(addr string, port int, laddr *net.UDPAddr, content []byte) (err error) { daddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addr, port)) if err != nil { return err diff --git a/cmd/frpc/main.go b/cmd/frpc/main.go index f1836db2..234d9e79 100644 --- a/cmd/frpc/main.go +++ b/cmd/frpc/main.go @@ -99,7 +99,7 @@ func main() { if args["status"] != nil { if args["status"].(bool) { if err = CmdStatus(); err != nil { - fmt.Println("frps get status error: %v\n", err) + fmt.Printf("frps get status error: %v\n", err) os.Exit(1) } else { os.Exit(0) @@ -132,7 +132,7 @@ func main() { os.Exit(1) } config.ClientCommonCfg.ServerAddr = addr[0] - config.ClientCommonCfg.ServerPort = serverPort + config.ClientCommonCfg.ServerPort = int(serverPort) } if args["-v"] != nil { diff --git a/cmd/frps/main.go b/cmd/frps/main.go index fc5d6436..c3b495ad 100644 --- a/cmd/frps/main.go +++ b/cmd/frps/main.go @@ -91,7 +91,7 @@ func main() { os.Exit(1) } config.ServerCommonCfg.BindAddr = addr[0] - config.ServerCommonCfg.BindPort = bindPort + config.ServerCommonCfg.BindPort = int(bindPort) } if args["-v"] != nil { diff --git a/models/config/client_common.go b/models/config/client_common.go index f98169e7..4b9ede72 100644 --- a/models/config/client_common.go +++ b/models/config/client_common.go @@ -29,8 +29,8 @@ var ClientCommonCfg *ClientCommonConf type ClientCommonConf struct { ConfigFile string ServerAddr string - ServerPort int64 - ServerUdpPort int64 // this is specified by login response message from frps + ServerPort int + ServerUdpPort int // this is specified by login response message from frps HttpProxy string LogFile string LogWay string @@ -38,7 +38,7 @@ type ClientCommonConf struct { LogMaxDays int64 PrivilegeToken string AdminAddr string - AdminPort int64 + AdminPort int AdminUser string AdminPwd string PoolCount int @@ -93,7 +93,12 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) { tmpStr, ok = conf.Get("common", "server_port") if ok { - cfg.ServerPort, _ = strconv.ParseInt(tmpStr, 10, 64) + v, err = strconv.ParseInt(tmpStr, 10, 64) + if err != nil { + err = fmt.Errorf("Parse conf error: invalid server_port") + return + } + cfg.ServerPort = int(v) } tmpStr, ok = conf.Get("common", "http_proxy") @@ -139,7 +144,10 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) { tmpStr, ok = conf.Get("common", "admin_port") if ok { if v, err = strconv.ParseInt(tmpStr, 10, 64); err == nil { - cfg.AdminPort = v + cfg.AdminPort = int(v) + } else { + err = fmt.Errorf("Parse conf error: invalid admin_port") + return } } @@ -203,7 +211,7 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) { if ok { v, err = strconv.ParseInt(tmpStr, 10, 64) if err != nil { - err = fmt.Errorf("Parse conf error: heartbeat_timeout is incorrect") + err = fmt.Errorf("Parse conf error: invalid heartbeat_timeout") return } else { cfg.HeartBeatTimeout = v @@ -214,7 +222,7 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) { if ok { v, err = strconv.ParseInt(tmpStr, 10, 64) if err != nil { - err = fmt.Errorf("Parse conf error: heartbeat_interval is incorrect") + err = fmt.Errorf("Parse conf error: invalid heartbeat_interval") return } else { cfg.HeartBeatInterval = v @@ -222,12 +230,12 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) { } if cfg.HeartBeatInterval <= 0 { - err = fmt.Errorf("Parse conf error: heartbeat_interval is incorrect") + err = fmt.Errorf("Parse conf error: invalid heartbeat_interval") return } if cfg.HeartBeatTimeout < cfg.HeartBeatInterval { - err = fmt.Errorf("Parse conf error: heartbeat_timeout is incorrect, heartbeat_timeout is less than heartbeat_interval") + err = fmt.Errorf("Parse conf error: invalid heartbeat_timeout, heartbeat_timeout is less than heartbeat_interval") return } return diff --git a/models/config/proxy.go b/models/config/proxy.go index e87b7eca..022e64f4 100644 --- a/models/config/proxy.go +++ b/models/config/proxy.go @@ -23,7 +23,6 @@ import ( "github.com/fatedier/frp/models/consts" "github.com/fatedier/frp/models/msg" - "github.com/fatedier/frp/utils/util" ini "github.com/vaughan0/go-ini" ) @@ -163,7 +162,7 @@ func (cfg *BaseProxyConf) UnMarshalToMsg(pMsg *msg.NewProxy) { // Bind info type BindInfoConf struct { BindAddr string `json:"bind_addr"` - RemotePort int64 `json:"remote_port"` + RemotePort int `json:"remote_port"` } func (cfg *BindInfoConf) compare(cmp *BindInfoConf) bool { @@ -183,10 +182,13 @@ func (cfg *BindInfoConf) LoadFromFile(name string, section ini.Section) (err err var ( tmpStr string ok bool + v int64 ) if tmpStr, ok = section["remote_port"]; ok { - if cfg.RemotePort, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { return fmt.Errorf("Parse conf error: proxy [%s] remote_port error", name) + } else { + cfg.RemotePort = int(v) } } else { return fmt.Errorf("Parse conf error: proxy [%s] remote_port not found", name) @@ -199,11 +201,6 @@ func (cfg *BindInfoConf) UnMarshalToMsg(pMsg *msg.NewProxy) { } func (cfg *BindInfoConf) check() (err error) { - if len(ServerCommonCfg.PrivilegeAllowPorts) != 0 { - if ok := util.ContainsPort(ServerCommonCfg.PrivilegeAllowPorts, cfg.RemotePort); !ok { - return fmt.Errorf("remote port [%d] isn't allowed", cfg.RemotePort) - } - } return nil } diff --git a/models/config/server_common.go b/models/config/server_common.go index 4d177665..37892b4e 100644 --- a/models/config/server_common.go +++ b/models/config/server_common.go @@ -19,7 +19,6 @@ import ( "strconv" "strings" - "github.com/fatedier/frp/utils/util" ini "github.com/vaughan0/go-ini" ) @@ -29,20 +28,20 @@ var ServerCommonCfg *ServerCommonConf type ServerCommonConf struct { ConfigFile string BindAddr string - BindPort int64 - BindUdpPort int64 - KcpBindPort int64 + BindPort int + BindUdpPort int + KcpBindPort int ProxyBindAddr string // If VhostHttpPort equals 0, don't listen a public port for http protocol. - VhostHttpPort int64 + VhostHttpPort int // if VhostHttpsPort equals 0, don't listen a public port for https protocol - VhostHttpsPort int64 + VhostHttpsPort int DashboardAddr string // if DashboardPort equals 0, dashboard is not available - DashboardPort int64 + DashboardPort int DashboardUser string DashboardPwd string AssetsDir string @@ -56,8 +55,7 @@ type ServerCommonConf struct { SubDomainHost string TcpMux bool - // if PrivilegeAllowPorts is not nil, tcp proxies which remote port exist in this map can be connected - PrivilegeAllowPorts [][2]int64 + PrivilegeAllowPorts map[int]struct{} MaxPoolCount int64 HeartBeatTimeout int64 UserConnTimeout int64 @@ -65,31 +63,32 @@ type ServerCommonConf struct { func GetDefaultServerCommonConf() *ServerCommonConf { return &ServerCommonConf{ - ConfigFile: "./frps.ini", - BindAddr: "0.0.0.0", - BindPort: 7000, - BindUdpPort: 0, - KcpBindPort: 0, - ProxyBindAddr: "0.0.0.0", - VhostHttpPort: 0, - VhostHttpsPort: 0, - DashboardAddr: "0.0.0.0", - DashboardPort: 0, - DashboardUser: "admin", - DashboardPwd: "admin", - AssetsDir: "", - LogFile: "console", - LogWay: "console", - LogLevel: "info", - LogMaxDays: 3, - PrivilegeMode: true, - PrivilegeToken: "", - AuthTimeout: 900, - SubDomainHost: "", - TcpMux: true, - MaxPoolCount: 5, - HeartBeatTimeout: 90, - UserConnTimeout: 10, + ConfigFile: "./frps.ini", + BindAddr: "0.0.0.0", + BindPort: 7000, + BindUdpPort: 0, + KcpBindPort: 0, + ProxyBindAddr: "0.0.0.0", + VhostHttpPort: 0, + VhostHttpsPort: 0, + DashboardAddr: "0.0.0.0", + DashboardPort: 0, + DashboardUser: "admin", + DashboardPwd: "admin", + AssetsDir: "", + LogFile: "console", + LogWay: "console", + LogLevel: "info", + LogMaxDays: 3, + PrivilegeMode: true, + PrivilegeToken: "", + AuthTimeout: 900, + SubDomainHost: "", + TcpMux: true, + PrivilegeAllowPorts: make(map[int]struct{}), + MaxPoolCount: 5, + HeartBeatTimeout: 90, + UserConnTimeout: 10, } } @@ -109,25 +108,31 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) { tmpStr, ok = conf.Get("common", "bind_port") if ok { - v, err = strconv.ParseInt(tmpStr, 10, 64) - if err == nil { - cfg.BindPort = v + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + err = fmt.Errorf("Parse conf error: invalid bind_port") + return + } else { + cfg.BindPort = int(v) } } tmpStr, ok = conf.Get("common", "bind_udp_port") if ok { - v, err = strconv.ParseInt(tmpStr, 10, 64) - if err == nil { - cfg.BindUdpPort = v + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + err = fmt.Errorf("Parse conf error: invalid bind_udp_port") + return + } else { + cfg.BindUdpPort = int(v) } } tmpStr, ok = conf.Get("common", "kcp_bind_port") if ok { - v, err = strconv.ParseInt(tmpStr, 10, 64) - if err == nil && v > 0 { - cfg.KcpBindPort = v + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + err = fmt.Errorf("Parse conf error: invalid kcp_bind_port") + return + } else { + cfg.KcpBindPort = int(v) } } @@ -140,10 +145,11 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) { tmpStr, ok = conf.Get("common", "vhost_http_port") if ok { - cfg.VhostHttpPort, err = strconv.ParseInt(tmpStr, 10, 64) - if err != nil { - err = fmt.Errorf("Parse conf error: vhost_http_port is incorrect") + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + err = fmt.Errorf("Parse conf error: invalid vhost_http_port") return + } else { + cfg.VhostHttpPort = int(v) } } else { cfg.VhostHttpPort = 0 @@ -151,10 +157,11 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) { tmpStr, ok = conf.Get("common", "vhost_https_port") if ok { - cfg.VhostHttpsPort, err = strconv.ParseInt(tmpStr, 10, 64) - if err != nil { - err = fmt.Errorf("Parse conf error: vhost_https_port is incorrect") + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + err = fmt.Errorf("Parse conf error: invalid vhost_https_port") return + } else { + cfg.VhostHttpsPort = int(v) } } else { cfg.VhostHttpsPort = 0 @@ -169,10 +176,11 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) { tmpStr, ok = conf.Get("common", "dashboard_port") if ok { - cfg.DashboardPort, err = strconv.ParseInt(tmpStr, 10, 64) - if err != nil { - err = fmt.Errorf("Parse conf error: dashboard_port is incorrect") + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + err = fmt.Errorf("Parse conf error: invalid dashboard_port") return + } else { + cfg.DashboardPort = int(v) } } else { cfg.DashboardPort = 0 @@ -228,12 +236,45 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) { cfg.PrivilegeToken, _ = conf.Get("common", "privilege_token") allowPortsStr, ok := conf.Get("common", "privilege_allow_ports") - // TODO: check if conflicts exist in port ranges if ok { - cfg.PrivilegeAllowPorts, err = util.GetPortRanges(allowPortsStr) - if err != nil { - err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", err) - return + // e.g. 1000-2000,2001,2002,3000-4000 + portRanges := strings.Split(allowPortsStr, ",") + for _, portRangeStr := range portRanges { + // 1000-2000 or 2001 + portArray := strings.Split(portRangeStr, "-") + // length: only 1 or 2 is correct + rangeType := len(portArray) + if rangeType == 1 { + // single port + singlePort, errRet := strconv.ParseInt(portArray[0], 10, 64) + if errRet != nil { + err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet) + return + } + cfg.PrivilegeAllowPorts[int(singlePort)] = struct{}{} + } else if rangeType == 2 { + // range ports + min, errRet := strconv.ParseInt(portArray[0], 10, 64) + if errRet != nil { + err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet) + return + } + max, errRet := strconv.ParseInt(portArray[1], 10, 64) + if errRet != nil { + err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet) + return + } + if max < min { + err = fmt.Errorf("Parse conf error: privilege_allow_ports range incorrect") + return + } + for i := min; i <= max; i++ { + cfg.PrivilegeAllowPorts[int(i)] = struct{}{} + } + } else { + err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect") + return + } } } } diff --git a/models/msg/msg.go b/models/msg/msg.go index dd0dde71..0cdb3f47 100644 --- a/models/msg/msg.go +++ b/models/msg/msg.go @@ -92,7 +92,7 @@ type Login struct { type LoginResp struct { Version string `json:"version"` RunId string `json:"run_id"` - ServerUdpPort int64 `json:"server_udp_port"` + ServerUdpPort int `json:"server_udp_port"` Error string `json:"error"` } @@ -104,7 +104,7 @@ type NewProxy struct { UseCompression bool `json:"use_compression"` // tcp and udp only - RemotePort int64 `json:"remote_port"` + RemotePort int `json:"remote_port"` // http and https only CustomDomains []string `json:"custom_domains"` diff --git a/server/dashboard.go b/server/dashboard.go index 01f71591..3c77875c 100644 --- a/server/dashboard.go +++ b/server/dashboard.go @@ -32,7 +32,7 @@ var ( httpServerWriteTimeout = 10 * time.Second ) -func RunDashboardServer(addr string, port int64) (err error) { +func RunDashboardServer(addr string, port int) (err error) { // url router router := httprouter.New() diff --git a/server/dashboard_api.go b/server/dashboard_api.go index 89d285c5..3f9acd0f 100644 --- a/server/dashboard_api.go +++ b/server/dashboard_api.go @@ -36,8 +36,8 @@ type ServerInfoResp struct { GeneralResponse Version string `json:"version"` - VhostHttpPort int64 `json:"vhost_http_port"` - VhostHttpsPort int64 `json:"vhost_https_port"` + VhostHttpPort int `json:"vhost_http_port"` + VhostHttpsPort int `json:"vhost_https_port"` AuthTimeout int64 `json:"auth_timeout"` SubdomainHost string `json:"subdomain_host"` MaxPoolCount int64 `json:"max_pool_count"` diff --git a/server/ports.go b/server/ports.go new file mode 100644 index 00000000..b9cc4c16 --- /dev/null +++ b/server/ports.go @@ -0,0 +1,180 @@ +package server + +import ( + "errors" + "fmt" + "net" + "sync" + "time" +) + +const ( + MinPort = 1025 + MaxPort = 65535 + MaxPortReservedDuration = time.Duration(24) * time.Hour + CleanReservedPortsInterval = time.Hour +) + +var ( + ErrPortAlreadyUsed = errors.New("port already used") + ErrPortNotAllowed = errors.New("port not allowed") + ErrPortUnAvailable = errors.New("port unavailable") + ErrNoAvailablePort = errors.New("no available port") +) + +type PortCtx struct { + ProxyName string + Port int + Closed bool + UpdateTime time.Time +} + +type PortManager struct { + reservedPorts map[string]*PortCtx + usedPorts map[int]*PortCtx + freePorts map[int]struct{} + + bindAddr string + netType string + mu sync.Mutex +} + +func NewPortManager(netType string, bindAddr string, allowPorts map[int]struct{}) *PortManager { + pm := &PortManager{ + reservedPorts: make(map[string]*PortCtx), + usedPorts: make(map[int]*PortCtx), + freePorts: make(map[int]struct{}), + bindAddr: bindAddr, + netType: netType, + } + if len(allowPorts) > 0 { + for port, _ := range allowPorts { + pm.freePorts[port] = struct{}{} + } + } else { + for i := MinPort; i <= MaxPort; i++ { + pm.freePorts[i] = struct{}{} + } + } + go pm.cleanReservedPortsWorker() + return pm +} + +func (pm *PortManager) Acquire(name string, port int) (realPort int, err error) { + portCtx := &PortCtx{ + ProxyName: name, + Closed: false, + UpdateTime: time.Now(), + } + + var ok bool + + pm.mu.Lock() + defer func() { + if err == nil { + portCtx.Port = realPort + } + pm.mu.Unlock() + }() + + // check reserved ports first + if port == 0 { + if ctx, ok := pm.reservedPorts[name]; ok { + if pm.isPortAvailable(ctx.Port) { + realPort = ctx.Port + pm.usedPorts[realPort] = portCtx + pm.reservedPorts[name] = portCtx + delete(pm.freePorts, realPort) + return + } + } + } + + if port == 0 { + // get random port + count := 0 + maxTryTimes := 5 + for k, _ := range pm.freePorts { + count++ + if count > maxTryTimes { + break + } + if pm.isPortAvailable(k) { + realPort = k + pm.usedPorts[realPort] = portCtx + pm.reservedPorts[name] = portCtx + delete(pm.freePorts, realPort) + break + } + } + if realPort == 0 { + err = ErrNoAvailablePort + } + } else { + // specified port + if _, ok = pm.freePorts[port]; ok { + if pm.isPortAvailable(port) { + realPort = port + pm.usedPorts[realPort] = portCtx + pm.reservedPorts[name] = portCtx + delete(pm.freePorts, realPort) + } else { + err = ErrPortUnAvailable + } + } else { + if _, ok = pm.usedPorts[port]; ok { + err = ErrPortAlreadyUsed + } else { + err = ErrPortNotAllowed + } + } + } + return +} + +func (pm *PortManager) isPortAvailable(port int) bool { + if pm.netType == "udp" { + addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pm.bindAddr, port)) + if err != nil { + return false + } + l, err := net.ListenUDP("udp", addr) + if err != nil { + return false + } + l.Close() + return true + } else { + l, err := net.Listen(pm.netType, fmt.Sprintf("%s:%d", pm.bindAddr, port)) + if err != nil { + return false + } + l.Close() + return true + } +} + +func (pm *PortManager) Release(port int) { + pm.mu.Lock() + defer pm.mu.Unlock() + if ctx, ok := pm.usedPorts[port]; ok { + pm.freePorts[port] = struct{}{} + delete(pm.usedPorts, port) + ctx.Closed = true + ctx.UpdateTime = time.Now() + } +} + +// Release reserved port if it isn't used in last 24 hours. +func (pm *PortManager) cleanReservedPortsWorker() { + for { + time.Sleep(CleanReservedPortsInterval) + pm.mu.Lock() + for name, ctx := range pm.reservedPorts { + if ctx.Closed && time.Since(ctx.UpdateTime) > MaxPortReservedDuration { + delete(pm.reservedPorts, name) + } + } + pm.mu.Unlock() + } +} diff --git a/server/proxy.go b/server/proxy.go index f744b8ba..bfb9793a 100644 --- a/server/proxy.go +++ b/server/proxy.go @@ -165,11 +165,24 @@ func NewProxy(ctl *Control, pxyConf config.ProxyConf) (pxy Proxy, err error) { type TcpProxy struct { BaseProxy cfg *config.TcpProxyConf + + realPort int } func (pxy *TcpProxy) Run() (remoteAddr string, err error) { - remoteAddr = fmt.Sprintf(":%d", pxy.cfg.RemotePort) - listener, errRet := frpNet.ListenTcp(config.ServerCommonCfg.ProxyBindAddr, pxy.cfg.RemotePort) + pxy.realPort, err = pxy.ctl.svr.tcpPortManager.Acquire(pxy.name, pxy.cfg.RemotePort) + if err != nil { + return + } + defer func() { + if err != nil { + pxy.ctl.svr.tcpPortManager.Release(pxy.realPort) + } + }() + + remoteAddr = fmt.Sprintf(":%d", pxy.realPort) + pxy.cfg.RemotePort = pxy.realPort + listener, errRet := frpNet.ListenTcp(config.ServerCommonCfg.ProxyBindAddr, pxy.realPort) if errRet != nil { err = errRet return @@ -188,6 +201,7 @@ func (pxy *TcpProxy) GetConf() config.ProxyConf { func (pxy *TcpProxy) Close() { pxy.BaseProxy.Close() + pxy.ctl.svr.tcpPortManager.Release(pxy.realPort) } type HttpProxy struct { @@ -412,6 +426,8 @@ type UdpProxy struct { BaseProxy cfg *config.UdpProxyConf + realPort int + // udpConn is the listener of udp packages udpConn *net.UDPConn @@ -432,8 +448,19 @@ type UdpProxy struct { } func (pxy *UdpProxy) Run() (remoteAddr string, err error) { - remoteAddr = fmt.Sprintf(":%d", pxy.cfg.RemotePort) - addr, errRet := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", config.ServerCommonCfg.ProxyBindAddr, pxy.cfg.RemotePort)) + pxy.realPort, err = pxy.ctl.svr.udpPortManager.Acquire(pxy.name, pxy.cfg.RemotePort) + if err != nil { + return + } + defer func() { + if err != nil { + pxy.ctl.svr.udpPortManager.Release(pxy.realPort) + } + }() + + remoteAddr = fmt.Sprintf(":%d", pxy.realPort) + pxy.cfg.RemotePort = pxy.realPort + addr, errRet := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", config.ServerCommonCfg.ProxyBindAddr, pxy.realPort)) if errRet != nil { err = errRet return @@ -581,6 +608,7 @@ func (pxy *UdpProxy) Close() { close(pxy.readCh) close(pxy.sendCh) } + pxy.ctl.svr.udpPortManager.Release(pxy.realPort) } // HandleUserTcpConnection is used for incoming tcp user connections. diff --git a/server/service.go b/server/service.go index a510b179..e976658a 100644 --- a/server/service.go +++ b/server/service.go @@ -60,17 +60,25 @@ type Service struct { // Manage all visitor listeners. visitorManager *VisitorManager + // Manage all tcp ports. + tcpPortManager *PortManager + + // Manage all udp ports. + udpPortManager *PortManager + // Controller for nat hole connections. natHoleController *NatHoleController } func NewService() (svr *Service, err error) { + cfg := config.ServerCommonCfg svr = &Service{ ctlManager: NewControlManager(), pxyManager: NewProxyManager(), visitorManager: NewVisitorManager(), + tcpPortManager: NewPortManager("tcp", cfg.ProxyBindAddr, cfg.PrivilegeAllowPorts), + udpPortManager: NewPortManager("udp", cfg.ProxyBindAddr, cfg.PrivilegeAllowPorts), } - cfg := config.ServerCommonCfg // Init assets. err = assets.Load(cfg.AssetsDir) diff --git a/tests/func_test.go b/tests/func_test.go index 238046fb..1f154089 100644 --- a/tests/func_test.go +++ b/tests/func_test.go @@ -10,28 +10,28 @@ import ( var ( TEST_STR = "frp is a fast reverse proxy to help you expose a local server behind a NAT or firewall to the internet." - TEST_TCP_PORT int64 = 10701 - TEST_TCP_FRP_PORT int64 = 10801 - TEST_TCP_EC_FRP_PORT int64 = 10901 + TEST_TCP_PORT int = 10701 + TEST_TCP_FRP_PORT int = 10801 + TEST_TCP_EC_FRP_PORT int = 10901 TEST_TCP_ECHO_STR string = "tcp type:" + TEST_STR - TEST_UDP_PORT int64 = 10702 - TEST_UDP_FRP_PORT int64 = 10802 - TEST_UDP_EC_FRP_PORT int64 = 10902 + TEST_UDP_PORT int = 10702 + TEST_UDP_FRP_PORT int = 10802 + TEST_UDP_EC_FRP_PORT int = 10902 TEST_UDP_ECHO_STR string = "udp type:" + TEST_STR TEST_UNIX_DOMAIN_ADDR string = "/tmp/frp_echo_server.sock" - TEST_UNIX_DOMAIN_FRP_PORT int64 = 10803 + TEST_UNIX_DOMAIN_FRP_PORT int = 10803 TEST_UNIX_DOMAIN_STR string = "unix domain type:" + TEST_STR - TEST_HTTP_PORT int64 = 10704 - TEST_HTTP_FRP_PORT int64 = 10804 + TEST_HTTP_PORT int = 10704 + TEST_HTTP_FRP_PORT int = 10804 TEST_HTTP_NORMAL_STR string = "http normal string: " + TEST_STR TEST_HTTP_FOO_STR string = "http foo string: " + TEST_STR TEST_HTTP_BAR_STR string = "http bar string: " + TEST_STR - TEST_STCP_FRP_PORT int64 = 10805 - TEST_STCP_EC_FRP_PORT int64 = 10905 + TEST_STCP_FRP_PORT int = 10805 + TEST_STCP_EC_FRP_PORT int = 10905 TEST_STCP_ECHO_STR string = "stcp type:" + TEST_STR ) diff --git a/utils/net/kcp.go b/utils/net/kcp.go index 18979c12..3d080fdd 100644 --- a/utils/net/kcp.go +++ b/utils/net/kcp.go @@ -31,7 +31,7 @@ type KcpListener struct { log.Logger } -func ListenKcp(bindAddr string, bindPort int64) (l *KcpListener, err error) { +func ListenKcp(bindAddr string, bindPort int) (l *KcpListener, err error) { listener, err := kcp.ListenWithOptions(fmt.Sprintf("%s:%d", bindAddr, bindPort), nil, 10, 3) if err != nil { return l, err diff --git a/utils/net/tcp.go b/utils/net/tcp.go index ca71de0a..b2c5a2b6 100644 --- a/utils/net/tcp.go +++ b/utils/net/tcp.go @@ -33,7 +33,7 @@ type TcpListener struct { log.Logger } -func ListenTcp(bindAddr string, bindPort int64) (l *TcpListener, err error) { +func ListenTcp(bindAddr string, bindPort int) (l *TcpListener, err error) { tcpAddr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort)) if err != nil { return l, err diff --git a/utils/net/udp.go b/utils/net/udp.go index ec2fb261..f2e9a797 100644 --- a/utils/net/udp.go +++ b/utils/net/udp.go @@ -167,7 +167,7 @@ type UdpListener struct { log.Logger } -func ListenUDP(bindAddr string, bindPort int64) (l *UdpListener, err error) { +func ListenUDP(bindAddr string, bindPort int) (l *UdpListener, err error) { udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", bindAddr, bindPort)) if err != nil { return l, err diff --git a/utils/util/util.go b/utils/util/util.go index 88180e35..4439f1aa 100644 --- a/utils/util/util.go +++ b/utils/util/util.go @@ -19,8 +19,6 @@ import ( "crypto/rand" "encoding/hex" "fmt" - "strconv" - "strings" ) // RandId return a rand string used in frp. @@ -48,69 +46,6 @@ func GetAuthKey(token string, timestamp int64) (key string) { return hex.EncodeToString(data) } -// for example: rangeStr is "1000-2000,2001,2002,3000-4000", return an array as port ranges. -func GetPortRanges(rangeStr string) (portRanges [][2]int64, err error) { - // for example: 1000-2000,2001,2002,3000-4000 - rangeArray := strings.Split(rangeStr, ",") - for _, portRangeStr := range rangeArray { - // 1000-2000 or 2001 - portArray := strings.Split(portRangeStr, "-") - // length: only 1 or 2 is correct - rangeType := len(portArray) - if rangeType == 1 { - singlePort, err := strconv.ParseInt(portArray[0], 10, 64) - if err != nil { - return [][2]int64{}, err - } - portRanges = append(portRanges, [2]int64{singlePort, singlePort}) - } else if rangeType == 2 { - min, err := strconv.ParseInt(portArray[0], 10, 64) - if err != nil { - return [][2]int64{}, err - } - max, err := strconv.ParseInt(portArray[1], 10, 64) - if err != nil { - return [][2]int64{}, err - } - if max < min { - return [][2]int64{}, fmt.Errorf("range incorrect") - } - portRanges = append(portRanges, [2]int64{min, max}) - } else { - return [][2]int64{}, fmt.Errorf("format error") - } - } - return portRanges, nil -} - -func ContainsPort(portRanges [][2]int64, port int64) bool { - for _, pr := range portRanges { - if port >= pr[0] && port <= pr[1] { - return true - } - } - return false -} - -func PortRangesCut(portRanges [][2]int64, port int64) [][2]int64 { - var tmpRanges [][2]int64 - for _, pr := range portRanges { - if port >= pr[0] && port <= pr[1] { - leftRange := [2]int64{pr[0], port - 1} - rightRange := [2]int64{port + 1, pr[1]} - if leftRange[0] <= leftRange[1] { - tmpRanges = append(tmpRanges, leftRange) - } - if rightRange[0] <= rightRange[1] { - tmpRanges = append(tmpRanges, rightRange) - } - } else { - tmpRanges = append(tmpRanges, pr) - } - } - return tmpRanges -} - func CanonicalAddr(host string, port int) (addr string) { if port == 80 || port == 443 { addr = host diff --git a/utils/util/util_test.go b/utils/util/util_test.go index 17d77547..8210c613 100644 --- a/utils/util/util_test.go +++ b/utils/util/util_test.go @@ -20,67 +20,3 @@ func TestGetAuthKey(t *testing.T) { t.Log(key) assert.Equal("6df41a43725f0c770fd56379e12acf8c", key) } - -func TestGetPortRanges(t *testing.T) { - assert := assert.New(t) - - rangesStr := "2000-3000,3001,4000-50000" - expect := [][2]int64{ - [2]int64{2000, 3000}, - [2]int64{3001, 3001}, - [2]int64{4000, 50000}, - } - actual, err := GetPortRanges(rangesStr) - assert.Nil(err) - t.Log(actual) - assert.Equal(expect, actual) -} - -func TestContainsPort(t *testing.T) { - assert := assert.New(t) - - rangesStr := "2000-3000,3001,4000-50000" - portRanges, err := GetPortRanges(rangesStr) - assert.Nil(err) - - type Case struct { - Port int64 - Answer bool - } - cases := []Case{ - Case{ - Port: 3001, - Answer: true, - }, - Case{ - Port: 3002, - Answer: false, - }, - Case{ - Port: 44444, - Answer: true, - }, - } - for _, elem := range cases { - ok := ContainsPort(portRanges, elem.Port) - assert.Equal(elem.Answer, ok) - } -} - -func TestPortRangesCut(t *testing.T) { - assert := assert.New(t) - - rangesStr := "2000-3000,3001,4000-50000" - portRanges, err := GetPortRanges(rangesStr) - assert.Nil(err) - - expect := [][2]int64{ - [2]int64{2000, 3000}, - [2]int64{3001, 3001}, - [2]int64{4000, 44443}, - [2]int64{44445, 50000}, - } - actual := PortRangesCut(portRanges, 44444) - t.Log(actual) - assert.Equal(expect, actual) -}