From f3876d69bb789d152ba6c50027f486eb561b097b Mon Sep 17 00:00:00 2001 From: Maodanping <673698750@qq.com> Date: Mon, 13 Jun 2016 22:19:24 +0800 Subject: [PATCH] add https proto for reverse proxy --- conf/frps.ini | 3 +- src/frp/cmd/frps/main.go | 15 +- src/frp/models/client/config.go | 2 +- src/frp/models/server/config.go | 31 ++++- src/frp/models/server/server.go | 10 +- src/frp/utils/vhost/vhost_https.go | 217 +++++++++++++++++++++++++++++ 6 files changed, 272 insertions(+), 6 deletions(-) create mode 100644 src/frp/utils/vhost/vhost_https.go diff --git a/conf/frps.ini b/conf/frps.ini index e3087d87..e96fb177 100644 --- a/conf/frps.ini +++ b/conf/frps.ini @@ -4,6 +4,7 @@ bind_addr = 0.0.0.0 bind_port = 7000 # if you want to support virtual host, you must set the http port for listening (optional) vhost_http_port = 80 +vhost_https_port = 443 # if you want to configure or reload frps by dashboard, dashboard_port must be set dashboard_port = 7500 # console or real logFile path like ./frps.log @@ -20,7 +21,7 @@ bind_addr = 0.0.0.0 listen_port = 6000 [web01] -type = http +type = https auth_token = 123 # if proxy type equals http, custom_domains must be set separated by commas custom_domains = web01.yourdomain.com,web01.yourdomain2.com diff --git a/src/frp/cmd/frps/main.go b/src/frp/cmd/frps/main.go index 910827f0..de7f7bed 100644 --- a/src/frp/cmd/frps/main.go +++ b/src/frp/cmd/frps/main.go @@ -143,12 +143,25 @@ func main() { log.Error("Create vhost http listener error, %v", err) os.Exit(1) } - server.VhostMuxer, err = vhost.NewHttpMuxer(vhostListener, 30*time.Second) + server.VhostHttpMuxer, err = vhost.NewHttpMuxer(vhostListener, 30*time.Second) if err != nil { log.Error("Create vhost httpMuxer error, %v", err) } } + // create vhost if VhostHttpPort != 0 + if server.VhostHttpsPort != 0 { + vhostListener, err := conn.Listen(server.BindAddr, server.VhostHttpsPort) + if err != nil { + log.Error("Create vhost https listener error, %v", err) + os.Exit(1) + } + server.VhostHttpsMuxer, err = vhost.NewHttpsMuxer(vhostListener, 30*time.Second) + if err != nil { + log.Error("Create vhost httpsMuxer error, %v", err) + } + } + // create dashboard web server if DashboardPort is set, so it won't be 0 if server.DashboardPort != 0 { err := server.RunDashboardServer(server.BindAddr, server.DashboardPort) diff --git a/src/frp/models/client/config.go b/src/frp/models/client/config.go index 1d19c331..6bf1e40c 100644 --- a/src/frp/models/client/config.go +++ b/src/frp/models/client/config.go @@ -115,7 +115,7 @@ func LoadConf(confFile string) (err error) { proxyClient.Type = "tcp" typeStr, ok := section["type"] if ok { - if typeStr != "tcp" && typeStr != "http" { + if typeStr != "tcp" && typeStr != "http" && typeStr != "https" { return fmt.Errorf("Parse ini file error: proxy [%s] type error", proxyClient.Name) } proxyClient.Type = typeStr diff --git a/src/frp/models/server/config.go b/src/frp/models/server/config.go index 14b97531..4c76c667 100644 --- a/src/frp/models/server/config.go +++ b/src/frp/models/server/config.go @@ -32,6 +32,7 @@ var ( BindAddr string = "0.0.0.0" BindPort int64 = 7000 VhostHttpPort int64 = 0 // if VhostHttpPort equals 0, don't listen a public port for http + VhostHttpsPort int64 = 0 // if VhostHttpsPort equals 0, don't listen a public port for http DashboardPort int64 = 0 // if DashboardPort equals 0, dashboard is not available LogFile string = "console" LogWay string = "console" // console or file @@ -40,7 +41,8 @@ var ( HeartBeatTimeout int64 = 90 UserConnTimeout int64 = 10 - VhostMuxer *vhost.HttpMuxer + VhostHttpMuxer *vhost.HttpMuxer + VhostHttpsMuxer *vhost.HttpsMuxer ProxyServers map[string]*ProxyServer = make(map[string]*ProxyServer) // all proxy servers info and resources ProxyServersMutex sync.RWMutex ) @@ -91,6 +93,14 @@ func loadCommonConf(confFile string) error { VhostHttpPort = 0 } + tmpStr, ok = conf.Get("common", "vhost_https_port") + if ok { + VhostHttpsPort, _ = strconv.ParseInt(tmpStr, 10, 64) + } else { + VhostHttpsPort = 0 + } + vhost.VhostHttpsPort = VhostHttpsPort + tmpStr, ok = conf.Get("common", "dashboard_port") if ok { DashboardPort, _ = strconv.ParseInt(tmpStr, 10, 64) @@ -135,7 +145,7 @@ func loadProxyConf(confFile string) (proxyServers map[string]*ProxyServer, err e proxyServer.Type, ok = section["type"] if ok { - if proxyServer.Type != "tcp" && proxyServer.Type != "http" { + if proxyServer.Type != "tcp" && proxyServer.Type != "http" && proxyServer.Type != "https" { return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] type error", proxyServer.Name) } } else { @@ -179,6 +189,23 @@ func loadProxyConf(confFile string) (proxyServers map[string]*ProxyServer, err e proxyServer.CustomDomains[i] = strings.ToLower(strings.TrimSpace(domain)) + suffix } } + } else if proxyServer.Type == "https" { + // for https + domainStr, ok := section["custom_domains"] + if ok { + var suffix string + if VhostHttpsPort != 443 { + suffix = fmt.Sprintf(":%d", VhostHttpsPort) + } + proxyServer.CustomDomains = strings.Split(domainStr, ",") + if len(proxyServer.CustomDomains) == 0 { + return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] custom_domains must be set when type equals http", proxyServer.Name) + } + for i, domain := range proxyServer.CustomDomains { + proxyServer.CustomDomains[i] = strings.ToLower(strings.TrimSpace(domain)) + suffix + } + log.Info("proxyServer: %+v", proxyServer.CustomDomains) + } } proxyServers[proxyServer.Name] = proxyServer } diff --git a/src/frp/models/server/server.go b/src/frp/models/server/server.go index b9affdfb..d95e6eb5 100644 --- a/src/frp/models/server/server.go +++ b/src/frp/models/server/server.go @@ -100,7 +100,15 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) { p.listeners = append(p.listeners, l) } else if p.Type == "http" { for _, domain := range p.CustomDomains { - l, err := VhostMuxer.Listen(domain) + l, err := VhostHttpMuxer.Listen(domain) + if err != nil { + return err + } + p.listeners = append(p.listeners, l) + } + } else if p.Type == "https" { + for _, domain := range p.CustomDomains { + l, err := VhostHttpsMuxer.Listen(domain) if err != nil { return err } diff --git a/src/frp/utils/vhost/vhost_https.go b/src/frp/utils/vhost/vhost_https.go new file mode 100644 index 00000000..23d81ff8 --- /dev/null +++ b/src/frp/utils/vhost/vhost_https.go @@ -0,0 +1,217 @@ +// Copyright 2016 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vhost + +import ( + _ "bufio" + _ "bytes" + _ "crypto/tls" + "errors" + "fmt" + "frp/utils/conn" + "frp/utils/log" + "io" + _ "io/ioutil" + "net" + _ "net/http" + "strings" + _ "sync" + "time" +) + +var ( + maxHandshake int64 = 65536 // maximum handshake we support (protocol max is 16 MB) + VhostHttpsPort int64 = 443 +) + +const ( + typeClientHello uint8 = 1 // Type client hello +) + +// TLS extension numbers +const ( + extensionServerName uint16 = 0 + extensionStatusRequest uint16 = 5 + extensionSupportedCurves uint16 = 10 + extensionSupportedPoints uint16 = 11 + extensionSignatureAlgorithms uint16 = 13 + extensionALPN uint16 = 16 + extensionSCT uint16 = 18 + extensionSessionTicket uint16 = 35 + extensionNextProtoNeg uint16 = 13172 // not IANA assigned + extensionRenegotiationInfo uint16 = 0xff01 +) + +type HttpsMuxer struct { + *VhostMuxer +} + +/* + RFC document: http://tools.ietf.org/html/rfc5246 +*/ + +func errMsgToLog(format string, a ...interface{}) error { + errMsg := fmt.Sprintf(format, a...) + log.Warn(errMsg) + return errors.New(errMsg) +} + +func readHandshake(rd io.Reader) (string, error) { + + data := make([]byte, 1024) + length, err := rd.Read(data) + if err != nil { + return "", errMsgToLog("read err:%v", err) + } else { + if length < 47 { + return "", errMsgToLog("readHandshake: proto length[%d] is too short", length) + } + } + data = data[:length] + //log.Warn("data: %+v", data) + if uint8(data[5]) != typeClientHello { + return "", errMsgToLog("readHandshake: type[%d] is not clientHello", uint16(data[5])) + } + + //version and random + //tlsVersion := uint16(data[9])<<8 | uint16(data[10]) + //random := data[11:43] + + //session + sessionIdLen := int(data[43]) + if sessionIdLen > 32 || len(data) < 44+sessionIdLen { + return "", errMsgToLog("readHandshake: sessionIdLen[%d] is long", sessionIdLen) + } + data = data[44+sessionIdLen:] + if len(data) < 2 { + return "", errMsgToLog("readHandshake: dataLen[%d] after session is short", len(data)) + } + + // cipher suite numbers + cipherSuiteLen := int(data[0])<<8 | int(data[1]) + if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen { + //return "", errMsgToLog("readHandshake: cipherSuiteLen[%d] is long", sessionIdLen) + return "", errMsgToLog("readHandshake: dataLen[%d] after cipher suite is short", len(data)) + } + data = data[2+cipherSuiteLen:] + if len(data) < 1 { + return "", errMsgToLog("readHandshake: cipherSuiteLen[%d] is long", cipherSuiteLen) + } + + //compression method + compressionMethodsLen := int(data[0]) + if len(data) < 1+compressionMethodsLen { + return "", errMsgToLog("readHandshake: compressionMethodsLen[%d] is long", compressionMethodsLen) + //return false + } + + data = data[1+compressionMethodsLen:] + + if len(data) == 0 { + // ClientHello is optionally followed by extension data + //return true + return "", errMsgToLog("readHandshake: there is no extension data to get servername") + } + if len(data) < 2 { + return "", errMsgToLog("readHandshake: extension dataLen[%d] is too short") + } + + extensionsLength := int(data[0])<<8 | int(data[1]) + data = data[2:] + if extensionsLength != len(data) { + return "", errMsgToLog("readHandshake: extensionsLen[%d] is not equal to dataLen[%d]", extensionsLength, len(data)) + } + for len(data) != 0 { + if len(data) < 4 { + return "", errMsgToLog("readHandshake: extensionsDataLen[%d] is too short", len(data)) + } + extension := uint16(data[0])<<8 | uint16(data[1]) + length := int(data[2])<<8 | int(data[3]) + data = data[4:] + if len(data) < length { + return "", errMsgToLog("readHandshake: extensionLen[%d] is long", length) + //return false + } + + switch extension { + case extensionRenegotiationInfo: + if length != 1 || data[0] != 0 { + return "", errMsgToLog("readHandshake: extension reNegotiationInfoLen[%d] is short", length) + } + case extensionNextProtoNeg: + case extensionStatusRequest: + case extensionServerName: + d := data[:length] + if len(d) < 2 { + return "", errMsgToLog("readHandshake: remiaining dataLen[%d] is short", len(d)) + } + namesLen := int(d[0])<<8 | int(d[1]) + d = d[2:] + if len(d) != namesLen { + return "", errMsgToLog("readHandshake: nameListLen[%d] is not equal to dataLen[%d]", namesLen, len(d)) + } + for len(d) > 0 { + if len(d) < 3 { + return "", errMsgToLog("readHandshake: extension serverNameLen[%d] is short", len(d)) + } + nameType := d[0] + nameLen := int(d[1])<<8 | int(d[2]) + d = d[3:] + if len(d) < nameLen { + return "", errMsgToLog("readHandshake: nameLen[%d] is not equal to dataLen[%d]", nameLen, len(d)) + } + if nameType == 0 { + suffix := "" + if VhostHttpsPort != 443 { + suffix = fmt.Sprintf(":%d", VhostHttpsPort) + } + serverName := string(d[:nameLen]) + domain := strings.ToLower(strings.TrimSpace(serverName)) + suffix + return domain, nil + break + } + d = d[nameLen:] + } + } + data = data[length:] + } + //return "test.codermao.com:8082", nil + return "", errMsgToLog("Unknow error") +} + +func GetHttpsHostname(c *conn.Conn) (sc net.Conn, routerName string, err error) { + log.Info("GetHttpsHostname") + sc, rd := newShareConn(c.TcpConn) + + host, err := readHandshake(rd) + if err != nil { + return sc, "", err + } + /* + if _, ok := c.TcpConn.(*tls.Conn); ok { + log.Warn("convert to tlsConn success") + } else { + log.Warn("convert to tlsConn error") + }*/ + //tcpConn. + log.Debug("GetHttpsHostname: %s", host) + + return sc, host, nil +} + +func NewHttpsMuxer(listener *conn.Listener, timeout time.Duration) (*HttpsMuxer, error) { + mux, err := NewVhostMuxer(listener, GetHttpsHostname, timeout) + return &HttpsMuxer{mux}, err +}