diff --git a/models/config/proxy.go b/models/config/proxy.go index 4718a8c6..84b9a8e4 100644 --- a/models/config/proxy.go +++ b/models/config/proxy.go @@ -23,6 +23,7 @@ import ( "github.com/fatedier/frp/models/consts" "github.com/fatedier/frp/models/msg" + "github.com/fatedier/frp/utils/util" ini "github.com/vaughan0/go-ini" ) @@ -173,7 +174,8 @@ func (cfg *BindInfoConf) UnMarshalToMsg(pMsg *msg.NewProxy) { func (cfg *BindInfoConf) check() (err error) { if len(ServerCommonCfg.PrivilegeAllowPorts) != 0 { - if _, ok := ServerCommonCfg.PrivilegeAllowPorts[cfg.RemotePort]; !ok { + // TODO: once linstenPort used, should remove the port from privilege ports + if ok := util.ContainsPort(ServerCommonCfg.PrivilegeAllowPorts, cfg.RemotePort); !ok { return fmt.Errorf("remote port [%d] isn't allowed", cfg.RemotePort) } } diff --git a/models/config/server_common.go b/models/config/server_common.go index 070de8aa..550b2b67 100644 --- a/models/config/server_common.go +++ b/models/config/server_common.go @@ -19,6 +19,7 @@ import ( "strconv" "strings" + "github.com/fatedier/frp/utils/util" ini "github.com/vaughan0/go-ini" ) @@ -52,7 +53,7 @@ type ServerCommonConf struct { TcpMux bool // if PrivilegeAllowPorts is not nil, tcp proxies which remote port exist in this map can be connected - PrivilegeAllowPorts map[int64]struct{} + PrivilegeAllowPorts [][2]int64 MaxPoolCount int64 HeartBeatTimeout int64 UserConnTimeout int64 @@ -198,47 +199,12 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) { return } - cfg.PrivilegeAllowPorts = make(map[int64]struct{}) - tmpStr, ok = conf.Get("common", "privilege_allow_ports") + allowPortsStr, ok := conf.Get("common", "privilege_allow_ports") + // TODO: check if conflicts exist in port ranges if ok { - // e.g. 1000-2000,2001,2002,3000-4000 - portRanges := strings.Split(tmpStr, ",") - for _, portRangeStr := range portRanges { - // 1000-2000 or 2001 - portArray := strings.Split(portRangeStr, "-") - // length: only 1 or 2 is correct - rangeType := len(portArray) - if rangeType == 1 { - // single port - singlePort, errRet := strconv.ParseInt(portArray[0], 10, 64) - if errRet != nil { - err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet) - return - } - ServerCommonCfg.PrivilegeAllowPorts[singlePort] = struct{}{} - } else if rangeType == 2 { - // range ports - min, errRet := strconv.ParseInt(portArray[0], 10, 64) - if errRet != nil { - err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet) - return - } - max, errRet := strconv.ParseInt(portArray[1], 10, 64) - if errRet != nil { - err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet) - return - } - if max < min { - err = fmt.Errorf("Parse conf error: privilege_allow_ports range incorrect") - return - } - for i := min; i <= max; i++ { - cfg.PrivilegeAllowPorts[i] = struct{}{} - } - } else { - err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect") - return - } + cfg.PrivilegeAllowPorts, err = util.GetPortRanges(allowPortsStr) + if err != nil { + return } } } diff --git a/utils/util/util.go b/utils/util/util.go index 3e8bca5a..5c44e51c 100644 --- a/utils/util/util.go +++ b/utils/util/util.go @@ -19,6 +19,8 @@ import ( "crypto/rand" "encoding/hex" "fmt" + "strconv" + "strings" ) // RandId return a rand string used in frp. @@ -45,3 +47,66 @@ func GetAuthKey(token string, timestamp int64) (key string) { data := md5Ctx.Sum(nil) return hex.EncodeToString(data) } + +// for example: rangeStr is "1000-2000,2001,2002,3000-4000", return an array as port ranges. +func GetPortRanges(rangeStr string) (portRanges [][2]int64, err error) { + // for example: 1000-2000,2001,2002,3000-4000 + rangeArray := strings.Split(rangeStr, ",") + for _, portRangeStr := range rangeArray { + // 1000-2000 or 2001 + portArray := strings.Split(portRangeStr, "-") + // length: only 1 or 2 is correct + rangeType := len(portArray) + if rangeType == 1 { + singlePort, err := strconv.ParseInt(portArray[0], 10, 64) + if err != nil { + return [][2]int64{}, fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", err) + } + portRanges = append(portRanges, [2]int64{singlePort, singlePort}) + } else if rangeType == 2 { + min, err := strconv.ParseInt(portArray[0], 10, 64) + if err != nil { + return [][2]int64{}, fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", err) + } + max, err := strconv.ParseInt(portArray[1], 10, 64) + if err != nil { + return [][2]int64{}, fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", err) + } + if max < min { + return [][2]int64{}, fmt.Errorf("Parse conf error: privilege_allow_ports range incorrect") + } + portRanges = append(portRanges, [2]int64{min, max}) + } else { + return [][2]int64{}, fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect") + } + } + return portRanges, nil +} + +func ContainsPort(portRanges [][2]int64, port int64) bool { + for _, pr := range portRanges { + if port >= pr[0] && port <= pr[1] { + return true + } + } + return false +} + +func PortRangesCut(portRanges [][2]int64, port int64) [][2]int64 { + var tmpRanges [][2]int64 + for _, pr := range portRanges { + if port >= pr[0] && port <= pr[1] { + leftRange := [2]int64{pr[0], port - 1} + rightRange := [2]int64{port + 1, pr[1]} + if leftRange[0] <= leftRange[1] { + tmpRanges = append(tmpRanges, leftRange) + } + if rightRange[0] <= rightRange[1] { + tmpRanges = append(tmpRanges, rightRange) + } + } else { + tmpRanges = append(tmpRanges, pr) + } + } + return tmpRanges +} diff --git a/utils/util/util_test.go b/utils/util/util_test.go index 8210c613..17d77547 100644 --- a/utils/util/util_test.go +++ b/utils/util/util_test.go @@ -20,3 +20,67 @@ func TestGetAuthKey(t *testing.T) { t.Log(key) assert.Equal("6df41a43725f0c770fd56379e12acf8c", key) } + +func TestGetPortRanges(t *testing.T) { + assert := assert.New(t) + + rangesStr := "2000-3000,3001,4000-50000" + expect := [][2]int64{ + [2]int64{2000, 3000}, + [2]int64{3001, 3001}, + [2]int64{4000, 50000}, + } + actual, err := GetPortRanges(rangesStr) + assert.Nil(err) + t.Log(actual) + assert.Equal(expect, actual) +} + +func TestContainsPort(t *testing.T) { + assert := assert.New(t) + + rangesStr := "2000-3000,3001,4000-50000" + portRanges, err := GetPortRanges(rangesStr) + assert.Nil(err) + + type Case struct { + Port int64 + Answer bool + } + cases := []Case{ + Case{ + Port: 3001, + Answer: true, + }, + Case{ + Port: 3002, + Answer: false, + }, + Case{ + Port: 44444, + Answer: true, + }, + } + for _, elem := range cases { + ok := ContainsPort(portRanges, elem.Port) + assert.Equal(elem.Answer, ok) + } +} + +func TestPortRangesCut(t *testing.T) { + assert := assert.New(t) + + rangesStr := "2000-3000,3001,4000-50000" + portRanges, err := GetPortRanges(rangesStr) + assert.Nil(err) + + expect := [][2]int64{ + [2]int64{2000, 3000}, + [2]int64{3001, 3001}, + [2]int64{4000, 44443}, + [2]int64{44445, 50000}, + } + actual := PortRangesCut(portRanges, 44444) + t.Log(actual) + assert.Equal(expect, actual) +}