diff --git a/client/control.go b/client/control.go index 04be13ce..90d2d4ab 100644 --- a/client/control.go +++ b/client/control.go @@ -255,6 +255,7 @@ func (ctl *Control) login() (err error) { return nil } +// connectServer return a new connection to frps func (ctl *Control) connectServer() (conn frpNet.Conn, err error) { if g.GlbClientCfg.TcpMux { stream, errRet := ctl.session.OpenStream() diff --git a/client/health.go b/client/health.go index ad58554d..8e84a6f8 100644 --- a/client/health.go +++ b/client/health.go @@ -15,18 +15,133 @@ package client import ( - "github.com/fatedier/frp/models/config" + "context" + "net" + "net/http" + "time" ) type HealthCheckMonitor struct { - cfg config.HealthCheckConf + checkType string + interval time.Duration + timeout time.Duration + maxFailedTimes int + + // For tcp + addr string + + // For http + url string + + failedTimes uint64 + statusOK bool + statusNormalFn func() + statusFailedFn func() + + ctx context.Context + cancel context.CancelFunc } -func NewHealthCheckMonitor(cfg *config.HealthCheckConf) *HealthCheckMonitor { +func NewHealthCheckMonitor(checkType string, intervalS int, timeoutS int, maxFailedTimes int, addr string, url string, + statusNormalFn func(), statusFailedFn func()) *HealthCheckMonitor { + + if intervalS <= 0 { + intervalS = 10 + } + if timeoutS <= 0 { + timeoutS = 3 + } + if maxFailedTimes <= 0 { + maxFailedTimes = 1 + } + ctx, cancel := context.WithCancel(context.Background()) return &HealthCheckMonitor{ - cfg: *cfg, + checkType: checkType, + interval: time.Duration(intervalS) * time.Second, + timeout: time.Duration(timeoutS) * time.Second, + maxFailedTimes: maxFailedTimes, + addr: addr, + url: url, + statusOK: false, + statusNormalFn: statusNormalFn, + statusFailedFn: statusFailedFn, + ctx: ctx, + cancel: cancel, } } func (monitor *HealthCheckMonitor) Start() { + go monitor.checkWorker() +} + +func (monitor *HealthCheckMonitor) Stop() { + monitor.cancel() +} + +func (monitor *HealthCheckMonitor) checkWorker() { + for { + ctx, cancel := context.WithDeadline(monitor.ctx, time.Now().Add(monitor.timeout)) + ok := monitor.doCheck(ctx) + + // check if this monitor has been closed + select { + case <-ctx.Done(): + cancel() + return + default: + cancel() + } + + if ok { + if !monitor.statusOK && monitor.statusNormalFn != nil { + monitor.statusOK = true + monitor.statusNormalFn() + } + } else { + monitor.failedTimes++ + if monitor.statusOK && int(monitor.failedTimes) >= monitor.maxFailedTimes && monitor.statusFailedFn != nil { + monitor.statusOK = false + monitor.statusFailedFn() + } + } + + time.Sleep(monitor.interval) + } +} + +func (monitor *HealthCheckMonitor) doCheck(ctx context.Context) bool { + switch monitor.checkType { + case "tcp": + return monitor.doTcpCheck(ctx) + case "http": + return monitor.doHttpCheck(ctx) + default: + return false + } +} + +func (monitor *HealthCheckMonitor) doTcpCheck(ctx context.Context) bool { + var d net.Dialer + conn, err := d.DialContext(ctx, "tcp", monitor.addr) + if err != nil { + return false + } + conn.Close() + return true +} + +func (monitor *HealthCheckMonitor) doHttpCheck(ctx context.Context) bool { + req, err := http.NewRequest("GET", monitor.url, nil) + if err != nil { + return false + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return false + } + + if resp.StatusCode/100 != 2 { + return false + } + return true } diff --git a/client/proxy_manager.go b/client/proxy_manager.go index cfa56fc5..bd193bb8 100644 --- a/client/proxy_manager.go +++ b/client/proxy_manager.go @@ -18,7 +18,6 @@ const ( ProxyStatusWaitStart = "wait start" ProxyStatusRunning = "running" ProxyStatusCheckFailed = "check failed" - ProxyStatusCheckSuccess = "check success" ProxyStatusClosed = "closed" ) diff --git a/models/config/proxy.go b/models/config/proxy.go index b600be5c..9ea680a4 100644 --- a/models/config/proxy.go +++ b/models/config/proxy.go @@ -381,6 +381,8 @@ func (cfg *LocalSvrConf) checkForCli() (err error) { // Health check info type HealthCheckConf struct { HealthCheckType string `json:"health_check_type"` // tcp | http + HealthCheckTimeout int `json:"health_check_timeout"` + HealthCheckMaxFailed int `json:"health_check_max_failed"` HealthCheckIntervalS int `json:"health_check_interval_s"` HealthCheckUrl string `json:"health_check_url"`