frpc: add admin server for reload configure file

This commit is contained in:
fatedier 2017-07-13 02:20:49 +08:00
parent f63a4f0cdd
commit d246400a71
11 changed files with 546 additions and 111 deletions

60
client/admin.go Normal file
View File

@ -0,0 +1,60 @@
// Copyright 2017 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"fmt"
"net"
"net/http"
"time"
"github.com/fatedier/frp/models/config"
frpNet "github.com/fatedier/frp/utils/net"
"github.com/julienschmidt/httprouter"
)
var (
httpServerReadTimeout = 10 * time.Second
httpServerWriteTimeout = 10 * time.Second
)
func (svr *Service) RunAdminServer(addr string, port int64) (err error) {
// url router
router := httprouter.New()
user, passwd := config.ClientCommonCfg.AdminUser, config.ClientCommonCfg.AdminPwd
// api, see dashboard_api.go
router.GET("/api/reload", frpNet.HttprouterBasicAuth(svr.apiReload, user, passwd))
address := fmt.Sprintf("%s:%d", addr, port)
server := &http.Server{
Addr: address,
Handler: router,
ReadTimeout: httpServerReadTimeout,
WriteTimeout: httpServerWriteTimeout,
}
if address == "" {
address = ":http"
}
ln, err := net.Listen("tcp", address)
if err != nil {
return err
}
go server.Serve(ln)
return
}

78
client/admin_api.go Normal file
View File

@ -0,0 +1,78 @@
// Copyright 2017 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"encoding/json"
"net/http"
"github.com/julienschmidt/httprouter"
ini "github.com/vaughan0/go-ini"
"github.com/fatedier/frp/models/config"
"github.com/fatedier/frp/utils/log"
)
type GeneralResponse struct {
Code int64 `json:"code"`
Msg string `json:"msg"`
}
// api/reload
type ReloadResp struct {
GeneralResponse
}
func (svr *Service) apiReload(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
var (
buf []byte
res ReloadResp
)
defer func() {
log.Info("Http response [/api/reload]: code [%d]", res.Code)
buf, _ = json.Marshal(&res)
w.Write(buf)
}()
log.Info("Http request: [/api/reload]")
conf, err := ini.LoadFile(config.ClientCommonCfg.ConfigFile)
if err != nil {
res.Code = 1
res.Msg = err.Error()
log.Error("reload frpc config file error: %v", err)
return
}
newCommonCfg, err := config.LoadClientCommonConf(conf)
if err != nil {
res.Code = 2
res.Msg = err.Error()
log.Error("reload frpc common section error: %v", err)
return
}
pxyCfgs, vistorCfgs, err := config.LoadProxyConfFromFile(newCommonCfg.User, conf, newCommonCfg.Start)
if err != nil {
res.Code = 3
res.Msg = err.Error()
log.Error("reload frpc proxy config error: %v", err)
return
}
svr.ctl.reloadConf(pxyCfgs, vistorCfgs)
log.Info("success reload conf")
return
}

View File

@ -388,7 +388,7 @@ func (ctl *Control) manager() {
ctl.Warn("[%s] start error: %s", m.ProxyName, m.Error) ctl.Warn("[%s] start error: %s", m.ProxyName, m.Error)
continue continue
} }
cfg, ok := ctl.pxyCfgs[m.ProxyName] cfg, ok := ctl.getProxyConf(m.ProxyName)
if !ok { if !ok {
// it will never go to this branch now // it will never go to this branch now
ctl.Warn("[%s] no proxy conf found", m.ProxyName) ctl.Warn("[%s] no proxy conf found", m.ProxyName)
@ -424,20 +424,36 @@ func (ctl *Control) controler() {
maxDelayTime := 30 * time.Second maxDelayTime := 30 * time.Second
delayTime := time.Second delayTime := time.Second
checkInterval := 30 * time.Second checkInterval := 10 * time.Second
checkProxyTicker := time.NewTicker(checkInterval) checkProxyTicker := time.NewTicker(checkInterval)
for { for {
select { select {
case <-checkProxyTicker.C: case <-checkProxyTicker.C:
// Every 30 seconds, check which proxy registered failed and reregister it to server. // Every 10 seconds, check which proxy registered failed and reregister it to server.
ctl.mu.RLock()
for _, cfg := range ctl.pxyCfgs { for _, cfg := range ctl.pxyCfgs {
if _, exist := ctl.getProxy(cfg.GetName()); !exist { if _, exist := ctl.proxies[cfg.GetName()]; !exist {
ctl.Info("try to reregister proxy [%s]", cfg.GetName()) ctl.Info("try to register proxy [%s]", cfg.GetName())
var newProxyMsg msg.NewProxy var newProxyMsg msg.NewProxy
cfg.UnMarshalToMsg(&newProxyMsg) cfg.UnMarshalToMsg(&newProxyMsg)
ctl.sendCh <- &newProxyMsg ctl.sendCh <- &newProxyMsg
} }
} }
for _, cfg := range ctl.vistorCfgs {
if _, exist := ctl.vistors[cfg.GetName()]; !exist {
ctl.Info("try to start vistor [%s]", cfg.GetName())
vistor := NewVistor(ctl, cfg)
err = vistor.Run()
if err != nil {
vistor.Warn("start error: %v", err)
continue
}
ctl.vistors[cfg.GetName()] = vistor
vistor.Info("start vistor success")
}
}
ctl.mu.RUnlock()
case _, ok := <-ctl.closedCh: case _, ok := <-ctl.closedCh:
// we won't get any variable from this channel // we won't get any variable from this channel
if !ok { if !ok {
@ -485,11 +501,13 @@ func (ctl *Control) controler() {
go ctl.reader() go ctl.reader()
// send NewProxy message for all configured proxies // send NewProxy message for all configured proxies
ctl.mu.RLock()
for _, cfg := range ctl.pxyCfgs { for _, cfg := range ctl.pxyCfgs {
var newProxyMsg msg.NewProxy var newProxyMsg msg.NewProxy
cfg.UnMarshalToMsg(&newProxyMsg) cfg.UnMarshalToMsg(&newProxyMsg)
ctl.sendCh <- &newProxyMsg ctl.sendCh <- &newProxyMsg
} }
ctl.mu.RUnlock()
checkProxyTicker.Stop() checkProxyTicker.Stop()
checkProxyTicker = time.NewTicker(checkInterval) checkProxyTicker = time.NewTicker(checkInterval)
@ -522,3 +540,82 @@ func (ctl *Control) addProxy(name string, pxy Proxy) {
defer ctl.mu.Unlock() defer ctl.mu.Unlock()
ctl.proxies[name] = pxy ctl.proxies[name] = pxy
} }
func (ctl *Control) getProxyConf(name string) (conf config.ProxyConf, ok bool) {
ctl.mu.RLock()
defer ctl.mu.RUnlock()
conf, ok = ctl.pxyCfgs[name]
return
}
func (ctl *Control) reloadConf(pxyCfgs map[string]config.ProxyConf, vistorCfgs map[string]config.ProxyConf) {
ctl.mu.Lock()
defer ctl.mu.Unlock()
removedPxyNames := make([]string, 0)
for name, oldCfg := range ctl.pxyCfgs {
del := false
cfg, ok := pxyCfgs[name]
if !ok {
del = true
} else {
if !oldCfg.Compare(cfg) {
del = true
}
}
if del {
removedPxyNames = append(removedPxyNames, name)
delete(ctl.pxyCfgs, name)
if pxy, ok := ctl.proxies[name]; ok {
pxy.Close()
}
delete(ctl.proxies, name)
ctl.sendCh <- &msg.CloseProxy{
ProxyName: name,
}
}
}
ctl.Info("proxy removed: %v", removedPxyNames)
addedPxyNames := make([]string, 0)
for name, cfg := range pxyCfgs {
if _, ok := ctl.pxyCfgs[name]; !ok {
ctl.pxyCfgs[name] = cfg
addedPxyNames = append(addedPxyNames, name)
}
}
ctl.Info("proxy added: %v", addedPxyNames)
removedVistorName := make([]string, 0)
for name, oldVistorCfg := range ctl.vistorCfgs {
del := false
cfg, ok := vistorCfgs[name]
if !ok {
del = true
} else {
if !oldVistorCfg.Compare(cfg) {
del = true
}
}
if del {
removedVistorName = append(removedVistorName, name)
delete(ctl.vistorCfgs, name)
if vistor, ok := ctl.vistors[name]; ok {
vistor.Close()
}
delete(ctl.vistors, name)
}
}
ctl.Info("vistor removed: %v", removedVistorName)
addedVistorName := make([]string, 0)
for name, vistorCfg := range vistorCfgs {
if _, ok := ctl.vistorCfgs[name]; !ok {
ctl.vistorCfgs[name] = vistorCfg
addedVistorName = append(addedVistorName, name)
}
}
ctl.Info("vistor added: %v", addedVistorName)
}

View File

@ -14,7 +14,10 @@
package client package client
import "github.com/fatedier/frp/models/config" import (
"github.com/fatedier/frp/models/config"
"github.com/fatedier/frp/utils/log"
)
type Service struct { type Service struct {
// manager control connection with server // manager control connection with server
@ -38,6 +41,14 @@ func (svr *Service) Run() error {
return err return err
} }
if config.ClientCommonCfg.AdminPort != 0 {
err = svr.RunAdminServer(config.ClientCommonCfg.AdminAddr, config.ClientCommonCfg.AdminPort)
if err != nil {
log.Warn("run admin server error: %v", err)
}
log.Info("admin server listen on %s:%d", config.ClientCommonCfg.AdminAddr, config.ClientCommonCfg.AdminPort)
}
<-svr.closedCh <-svr.closedCh
return nil return nil
} }

View File

@ -54,7 +54,7 @@ Options:
func main() { func main() {
var err error var err error
confFile := "./frpc.ini" confFile := "./frps.ini"
// the configures parsed from file will be replaced by those from command line if exist // the configures parsed from file will be replaced by those from command line if exist
args, err := docopt.Parse(usage, nil, true, version.Full(), false) args, err := docopt.Parse(usage, nil, true, version.Full(), false)
@ -73,6 +73,7 @@ func main() {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
} }
config.ClientCommonCfg.ConfigFile = confFile
if args["-L"] != nil { if args["-L"] != nil {
if args["-L"].(string) == "console" { if args["-L"].(string) == "console" {

View File

@ -20,6 +20,12 @@ log_max_days = 3
# for authentication # for authentication
privilege_token = 12345678 privilege_token = 12345678
# set admin address for control frpc's action by http api such as reload
admin_addr = 127.0.0.1
admin_port = 7400
admin_user = admin
admin_pwd = admin
# connections will be established in advance, default value is zero # connections will be established in advance, default value is zero
pool_count = 5 pool_count = 5

View File

@ -16,7 +16,7 @@ kcp_bind_port = 7000
vhost_http_port = 80 vhost_http_port = 80
vhost_https_port = 443 vhost_https_port = 443
# if you want to configure or reload frps by dashboard, dashboard_port must be set # set dashboard_port to view dashboard of frps
dashboard_port = 7500 dashboard_port = 7500
# dashboard user and pwd for basic auth protect, if not set, both default value is admin # dashboard user and pwd for basic auth protect, if not set, both default value is admin

View File

@ -36,6 +36,10 @@ type ClientCommonConf struct {
LogLevel string LogLevel string
LogMaxDays int64 LogMaxDays int64
PrivilegeToken string PrivilegeToken string
AdminAddr string
AdminPort int64
AdminUser string
AdminPwd string
PoolCount int PoolCount int
TcpMux bool TcpMux bool
User string User string
@ -57,6 +61,10 @@ func GetDeaultClientCommonConf() *ClientCommonConf {
LogLevel: "info", LogLevel: "info",
LogMaxDays: 3, LogMaxDays: 3,
PrivilegeToken: "", PrivilegeToken: "",
AdminAddr: "127.0.0.1",
AdminPort: 0,
AdminUser: "",
AdminPwd: "",
PoolCount: 1, PoolCount: 1,
TcpMux: true, TcpMux: true,
User: "", User: "",
@ -111,7 +119,9 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) {
tmpStr, ok = conf.Get("common", "log_max_days") tmpStr, ok = conf.Get("common", "log_max_days")
if ok { if ok {
cfg.LogMaxDays, _ = strconv.ParseInt(tmpStr, 10, 64) if v, err = strconv.ParseInt(tmpStr, 10, 64); err == nil {
cfg.LogMaxDays = v
}
} }
tmpStr, ok = conf.Get("common", "privilege_token") tmpStr, ok = conf.Get("common", "privilege_token")
@ -119,6 +129,28 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) {
cfg.PrivilegeToken = tmpStr cfg.PrivilegeToken = tmpStr
} }
tmpStr, ok = conf.Get("common", "admin_addr")
if ok {
cfg.AdminAddr = tmpStr
}
tmpStr, ok = conf.Get("common", "admin_port")
if ok {
if v, err = strconv.ParseInt(tmpStr, 10, 64); err == nil {
cfg.AdminPort = v
}
}
tmpStr, ok = conf.Get("common", "admin_user")
if ok {
cfg.AdminUser = tmpStr
}
tmpStr, ok = conf.Get("common", "admin_pwd")
if ok {
cfg.AdminPwd = tmpStr
}
tmpStr, ok = conf.Get("common", "pool_count") tmpStr, ok = conf.Get("common", "pool_count")
if ok { if ok {
v, err = strconv.ParseInt(tmpStr, 10, 64) v, err = strconv.ParseInt(tmpStr, 10, 64)
@ -145,7 +177,7 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) {
if ok { if ok {
proxyNames := strings.Split(tmpStr, ",") proxyNames := strings.Split(tmpStr, ",")
for _, name := range proxyNames { for _, name := range proxyNames {
cfg.Start[name] = struct{}{} cfg.Start[strings.TrimSpace(name)] = struct{}{}
} }
} }

View File

@ -56,6 +56,7 @@ type ProxyConf interface {
LoadFromFile(name string, conf ini.Section) error LoadFromFile(name string, conf ini.Section) error
UnMarshalToMsg(pMsg *msg.NewProxy) UnMarshalToMsg(pMsg *msg.NewProxy)
Check() error Check() error
Compare(conf ProxyConf) bool
} }
func NewProxyConf(pMsg *msg.NewProxy) (cfg ProxyConf, err error) { func NewProxyConf(pMsg *msg.NewProxy) (cfg ProxyConf, err error) {
@ -105,6 +106,16 @@ func (cfg *BaseProxyConf) GetBaseInfo() *BaseProxyConf {
return cfg return cfg
} }
func (cfg *BaseProxyConf) compare(cmp *BaseProxyConf) bool {
if cfg.ProxyName != cmp.ProxyName ||
cfg.ProxyType != cmp.ProxyType ||
cfg.UseEncryption != cmp.UseEncryption ||
cfg.UseCompression != cmp.UseCompression {
return false
}
return true
}
func (cfg *BaseProxyConf) LoadFromMsg(pMsg *msg.NewProxy) { func (cfg *BaseProxyConf) LoadFromMsg(pMsg *msg.NewProxy) {
cfg.ProxyName = pMsg.ProxyName cfg.ProxyName = pMsg.ProxyName
cfg.ProxyType = pMsg.ProxyType cfg.ProxyType = pMsg.ProxyType
@ -149,8 +160,16 @@ type BindInfoConf struct {
RemotePort int64 `json:"remote_port"` RemotePort int64 `json:"remote_port"`
} }
func (cfg *BindInfoConf) compare(cmp *BindInfoConf) bool {
if cfg.BindAddr != cmp.BindAddr ||
cfg.RemotePort != cmp.RemotePort {
return false
}
return true
}
func (cfg *BindInfoConf) LoadFromMsg(pMsg *msg.NewProxy) { func (cfg *BindInfoConf) LoadFromMsg(pMsg *msg.NewProxy) {
cfg.BindAddr = ServerCommonCfg.BindAddr cfg.BindAddr = ServerCommonCfg.ProxyBindAddr
cfg.RemotePort = pMsg.RemotePort cfg.RemotePort = pMsg.RemotePort
} }
@ -188,6 +207,14 @@ type DomainConf struct {
SubDomain string `json:"sub_domain"` SubDomain string `json:"sub_domain"`
} }
func (cfg *DomainConf) compare(cmp *DomainConf) bool {
if strings.Join(cfg.CustomDomains, " ") != strings.Join(cmp.CustomDomains, " ") ||
cfg.SubDomain != cmp.SubDomain {
return false
}
return true
}
func (cfg *DomainConf) LoadFromMsg(pMsg *msg.NewProxy) { func (cfg *DomainConf) LoadFromMsg(pMsg *msg.NewProxy) {
cfg.CustomDomains = pMsg.CustomDomains cfg.CustomDomains = pMsg.CustomDomains
cfg.SubDomain = pMsg.SubDomain cfg.SubDomain = pMsg.SubDomain
@ -246,6 +273,14 @@ type LocalSvrConf struct {
LocalPort int `json:"-"` LocalPort int `json:"-"`
} }
func (cfg *LocalSvrConf) compare(cmp *LocalSvrConf) bool {
if cfg.LocalIp != cmp.LocalIp ||
cfg.LocalPort != cmp.LocalPort {
return false
}
return true
}
func (cfg *LocalSvrConf) LoadFromFile(name string, section ini.Section) (err error) { func (cfg *LocalSvrConf) LoadFromFile(name string, section ini.Section) (err error) {
if cfg.LocalIp = section["local_ip"]; cfg.LocalIp == "" { if cfg.LocalIp = section["local_ip"]; cfg.LocalIp == "" {
cfg.LocalIp = "127.0.0.1" cfg.LocalIp = "127.0.0.1"
@ -266,6 +301,20 @@ type PluginConf struct {
PluginParams map[string]string `json:"-"` PluginParams map[string]string `json:"-"`
} }
func (cfg *PluginConf) compare(cmp *PluginConf) bool {
if cfg.Plugin != cmp.Plugin ||
len(cfg.PluginParams) != len(cmp.PluginParams) {
return false
}
for k, v := range cfg.PluginParams {
value, ok := cmp.PluginParams[k]
if !ok || v != value {
return false
}
}
return true
}
func (cfg *PluginConf) LoadFromFile(name string, section ini.Section) (err error) { func (cfg *PluginConf) LoadFromFile(name string, section ini.Section) (err error) {
cfg.Plugin = section["plugin"] cfg.Plugin = section["plugin"]
cfg.PluginParams = make(map[string]string) cfg.PluginParams = make(map[string]string)
@ -291,6 +340,21 @@ type TcpProxyConf struct {
PluginConf PluginConf
} }
func (cfg *TcpProxyConf) Compare(cmp ProxyConf) bool {
cmpConf, ok := cmp.(*TcpProxyConf)
if !ok {
return false
}
if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) ||
!cfg.BindInfoConf.compare(&cmpConf.BindInfoConf) ||
!cfg.LocalSvrConf.compare(&cmpConf.LocalSvrConf) ||
!cfg.PluginConf.compare(&cmpConf.PluginConf) {
return false
}
return true
}
func (cfg *TcpProxyConf) LoadFromMsg(pMsg *msg.NewProxy) { func (cfg *TcpProxyConf) LoadFromMsg(pMsg *msg.NewProxy) {
cfg.BaseProxyConf.LoadFromMsg(pMsg) cfg.BaseProxyConf.LoadFromMsg(pMsg)
cfg.BindInfoConf.LoadFromMsg(pMsg) cfg.BindInfoConf.LoadFromMsg(pMsg)
@ -330,6 +394,20 @@ type UdpProxyConf struct {
LocalSvrConf LocalSvrConf
} }
func (cfg *UdpProxyConf) Compare(cmp ProxyConf) bool {
cmpConf, ok := cmp.(*UdpProxyConf)
if !ok {
return false
}
if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) ||
!cfg.BindInfoConf.compare(&cmpConf.BindInfoConf) ||
!cfg.LocalSvrConf.compare(&cmpConf.LocalSvrConf) {
return false
}
return true
}
func (cfg *UdpProxyConf) LoadFromMsg(pMsg *msg.NewProxy) { func (cfg *UdpProxyConf) LoadFromMsg(pMsg *msg.NewProxy) {
cfg.BaseProxyConf.LoadFromMsg(pMsg) cfg.BaseProxyConf.LoadFromMsg(pMsg)
cfg.BindInfoConf.LoadFromMsg(pMsg) cfg.BindInfoConf.LoadFromMsg(pMsg)
@ -372,6 +450,25 @@ type HttpProxyConf struct {
HttpPwd string `json:"-"` HttpPwd string `json:"-"`
} }
func (cfg *HttpProxyConf) Compare(cmp ProxyConf) bool {
cmpConf, ok := cmp.(*HttpProxyConf)
if !ok {
return false
}
if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) ||
!cfg.DomainConf.compare(&cmpConf.DomainConf) ||
!cfg.LocalSvrConf.compare(&cmpConf.LocalSvrConf) ||
!cfg.PluginConf.compare(&cmpConf.PluginConf) ||
strings.Join(cfg.Locations, " ") != strings.Join(cmpConf.Locations, " ") ||
cfg.HostHeaderRewrite != cmpConf.HostHeaderRewrite ||
cfg.HttpUser != cmpConf.HttpUser ||
cfg.HttpPwd != cmpConf.HttpPwd {
return false
}
return true
}
func (cfg *HttpProxyConf) LoadFromMsg(pMsg *msg.NewProxy) { func (cfg *HttpProxyConf) LoadFromMsg(pMsg *msg.NewProxy) {
cfg.BaseProxyConf.LoadFromMsg(pMsg) cfg.BaseProxyConf.LoadFromMsg(pMsg)
cfg.DomainConf.LoadFromMsg(pMsg) cfg.DomainConf.LoadFromMsg(pMsg)
@ -438,6 +535,21 @@ type HttpsProxyConf struct {
PluginConf PluginConf
} }
func (cfg *HttpsProxyConf) Compare(cmp ProxyConf) bool {
cmpConf, ok := cmp.(*HttpsProxyConf)
if !ok {
return false
}
if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) ||
!cfg.DomainConf.compare(&cmpConf.DomainConf) ||
!cfg.LocalSvrConf.compare(&cmpConf.LocalSvrConf) ||
!cfg.PluginConf.compare(&cmpConf.PluginConf) {
return false
}
return true
}
func (cfg *HttpsProxyConf) LoadFromMsg(pMsg *msg.NewProxy) { func (cfg *HttpsProxyConf) LoadFromMsg(pMsg *msg.NewProxy) {
cfg.BaseProxyConf.LoadFromMsg(pMsg) cfg.BaseProxyConf.LoadFromMsg(pMsg)
cfg.DomainConf.LoadFromMsg(pMsg) cfg.DomainConf.LoadFromMsg(pMsg)
@ -488,6 +600,25 @@ type StcpProxyConf struct {
BindPort int `json:"bind_port"` BindPort int `json:"bind_port"`
} }
func (cfg *StcpProxyConf) Compare(cmp ProxyConf) bool {
cmpConf, ok := cmp.(*StcpProxyConf)
if !ok {
return false
}
if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) ||
!cfg.LocalSvrConf.compare(&cmpConf.LocalSvrConf) ||
!cfg.PluginConf.compare(&cmpConf.PluginConf) ||
cfg.Role != cmpConf.Role ||
cfg.Sk != cmpConf.Sk ||
cfg.ServerName != cmpConf.ServerName ||
cfg.BindAddr != cmpConf.BindAddr ||
cfg.BindPort != cmpConf.BindPort {
return false
}
return true
}
// Only for role server. // Only for role server.
func (cfg *StcpProxyConf) LoadFromMsg(pMsg *msg.NewProxy) { func (cfg *StcpProxyConf) LoadFromMsg(pMsg *msg.NewProxy) {
cfg.BaseProxyConf.LoadFromMsg(pMsg) cfg.BaseProxyConf.LoadFromMsg(pMsg)

View File

@ -15,16 +15,14 @@
package server package server
import ( import (
"compress/gzip"
"fmt" "fmt"
"io"
"net" "net"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/fatedier/frp/assets" "github.com/fatedier/frp/assets"
"github.com/fatedier/frp/models/config" "github.com/fatedier/frp/models/config"
frpNet "github.com/fatedier/frp/utils/net"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
) )
@ -38,20 +36,24 @@ func RunDashboardServer(addr string, port int64) (err error) {
// url router // url router
router := httprouter.New() router := httprouter.New()
user, passwd := config.ServerCommonCfg.DashboardUser, config.ServerCommonCfg.DashboardPwd
// api, see dashboard_api.go // api, see dashboard_api.go
router.GET("/api/serverinfo", httprouterBasicAuth(apiServerInfo)) router.GET("/api/serverinfo", frpNet.HttprouterBasicAuth(apiServerInfo, user, passwd))
router.GET("/api/proxy/tcp", httprouterBasicAuth(apiProxyTcp)) router.GET("/api/proxy/tcp", frpNet.HttprouterBasicAuth(apiProxyTcp, user, passwd))
router.GET("/api/proxy/udp", httprouterBasicAuth(apiProxyUdp)) router.GET("/api/proxy/udp", frpNet.HttprouterBasicAuth(apiProxyUdp, user, passwd))
router.GET("/api/proxy/http", httprouterBasicAuth(apiProxyHttp)) router.GET("/api/proxy/http", frpNet.HttprouterBasicAuth(apiProxyHttp, user, passwd))
router.GET("/api/proxy/https", httprouterBasicAuth(apiProxyHttps)) router.GET("/api/proxy/https", frpNet.HttprouterBasicAuth(apiProxyHttps, user, passwd))
router.GET("/api/proxy/traffic/:name", httprouterBasicAuth(apiProxyTraffic)) router.GET("/api/proxy/traffic/:name", frpNet.HttprouterBasicAuth(apiProxyTraffic, user, passwd))
// view // view
router.Handler("GET", "/favicon.ico", http.FileServer(assets.FileSystem)) router.Handler("GET", "/favicon.ico", http.FileServer(assets.FileSystem))
router.Handler("GET", "/static/*filepath", MakeGzipHandler(basicAuthWraper(http.StripPrefix("/static/", http.FileServer(assets.FileSystem))))) router.Handler("GET", "/static/*filepath", frpNet.MakeHttpGzipHandler(
router.HandlerFunc("GET", "/", basicAuth(func(w http.ResponseWriter, r *http.Request) { frpNet.NewHttpBasicAuthWraper(http.StripPrefix("/static/", http.FileServer(assets.FileSystem)), user, passwd)))
router.HandlerFunc("GET", "/", frpNet.HttpBasicAuth(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/static/", http.StatusMovedPermanently) http.Redirect(w, r, "/static/", http.StatusMovedPermanently)
})) }, user, passwd))
address := fmt.Sprintf("%s:%d", addr, port) address := fmt.Sprintf("%s:%d", addr, port)
server := &http.Server{ server := &http.Server{
@ -71,91 +73,3 @@ func RunDashboardServer(addr string, port int64) (err error) {
go server.Serve(ln) go server.Serve(ln)
return return
} }
func use(h http.HandlerFunc, middleware ...func(http.HandlerFunc) http.HandlerFunc) http.HandlerFunc {
for _, m := range middleware {
h = m(h)
}
return h
}
type AuthWraper struct {
h http.Handler
user string
passwd string
}
func (aw *AuthWraper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
user, passwd, hasAuth := r.BasicAuth()
if (aw.user == "" && aw.passwd == "") || (hasAuth && user == aw.user && passwd == aw.passwd) {
aw.h.ServeHTTP(w, r)
} else {
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
}
}
func basicAuthWraper(h http.Handler) http.Handler {
return &AuthWraper{
h: h,
user: config.ServerCommonCfg.DashboardUser,
passwd: config.ServerCommonCfg.DashboardPwd,
}
}
func basicAuth(h http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, passwd, hasAuth := r.BasicAuth()
if (config.ServerCommonCfg.DashboardUser == "" && config.ServerCommonCfg.DashboardPwd == "") ||
(hasAuth && user == config.ServerCommonCfg.DashboardUser && passwd == config.ServerCommonCfg.DashboardPwd) {
h.ServeHTTP(w, r)
} else {
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
}
}
}
func httprouterBasicAuth(h httprouter.Handle) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
user, passwd, hasAuth := r.BasicAuth()
if (config.ServerCommonCfg.DashboardUser == "" && config.ServerCommonCfg.DashboardPwd == "") ||
(hasAuth && user == config.ServerCommonCfg.DashboardUser && passwd == config.ServerCommonCfg.DashboardPwd) {
h(w, r, ps)
} else {
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
}
}
}
type GzipWraper struct {
h http.Handler
}
func (gw *GzipWraper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
gw.h.ServeHTTP(w, r)
return
}
w.Header().Set("Content-Encoding", "gzip")
gz := gzip.NewWriter(w)
defer gz.Close()
gzr := gzipResponseWriter{Writer: gz, ResponseWriter: w}
gw.h.ServeHTTP(gzr, r)
}
func MakeGzipHandler(h http.Handler) http.Handler {
return &GzipWraper{
h: h,
}
}
type gzipResponseWriter struct {
io.Writer
http.ResponseWriter
}
func (w gzipResponseWriter) Write(b []byte) (int, error) {
return w.Writer.Write(b)
}

105
utils/net/http.go Normal file
View File

@ -0,0 +1,105 @@
// Copyright 2017 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package net
import (
"compress/gzip"
"io"
"net/http"
"strings"
"github.com/julienschmidt/httprouter"
)
type HttpAuthWraper struct {
h http.Handler
user string
passwd string
}
func NewHttpBasicAuthWraper(h http.Handler, user, passwd string) http.Handler {
return &HttpAuthWraper{
h: h,
user: user,
passwd: passwd,
}
}
func (aw *HttpAuthWraper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
user, passwd, hasAuth := r.BasicAuth()
if (aw.user == "" && aw.passwd == "") || (hasAuth && user == aw.user && passwd == aw.passwd) {
aw.h.ServeHTTP(w, r)
} else {
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
}
}
func HttpBasicAuth(h http.HandlerFunc, user, passwd string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
reqUser, reqPasswd, hasAuth := r.BasicAuth()
if (user == "" && passwd == "") ||
(hasAuth && reqUser == user && reqPasswd == passwd) {
h.ServeHTTP(w, r)
} else {
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
}
}
}
func HttprouterBasicAuth(h httprouter.Handle, user, passwd string) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
reqUser, reqPasswd, hasAuth := r.BasicAuth()
if (user == "" && passwd == "") ||
(hasAuth && reqUser == user && reqPasswd == passwd) {
h(w, r, ps)
} else {
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
}
}
}
type HttpGzipWraper struct {
h http.Handler
}
func (gw *HttpGzipWraper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
gw.h.ServeHTTP(w, r)
return
}
w.Header().Set("Content-Encoding", "gzip")
gz := gzip.NewWriter(w)
defer gz.Close()
gzr := gzipResponseWriter{Writer: gz, ResponseWriter: w}
gw.h.ServeHTTP(gzr, r)
}
func MakeHttpGzipHandler(h http.Handler) http.Handler {
return &HttpGzipWraper{
h: h,
}
}
type gzipResponseWriter struct {
io.Writer
http.ResponseWriter
}
func (w gzipResponseWriter) Write(b []byte) (int, error) {
return w.Writer.Write(b)
}