mirror of
https://github.com/fatedier/frp.git
synced 2024-12-14 10:41:01 +01:00
vhost: set DisableKeepAlives = false and fix websocket not work
This commit is contained in:
parent
c842558ace
commit
46f809d711
@ -17,6 +17,7 @@ package vhost
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
@ -59,20 +60,25 @@ func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) *
|
|||||||
req.URL.Scheme = "http"
|
req.URL.Scheme = "http"
|
||||||
url := req.Context().Value(RouteInfoURL).(string)
|
url := req.Context().Value(RouteInfoURL).(string)
|
||||||
oldHost := util.GetHostFromAddr(req.Context().Value(RouteInfoHost).(string))
|
oldHost := util.GetHostFromAddr(req.Context().Value(RouteInfoHost).(string))
|
||||||
host := rp.GetRealHost(oldHost, url)
|
rc := rp.GetRouteConfig(oldHost, url)
|
||||||
if host != "" {
|
if rc != nil {
|
||||||
req.Host = host
|
if rc.RewriteHost != "" {
|
||||||
|
req.Host = rc.RewriteHost
|
||||||
}
|
}
|
||||||
req.URL.Host = req.Host
|
// Set {domain}.{location} as URL host here to let http transport reuse connections.
|
||||||
|
req.URL.Host = rc.Domain + "." + base64.StdEncoding.EncodeToString([]byte(rc.Location))
|
||||||
|
|
||||||
headers := rp.GetHeaders(oldHost, url)
|
for k, v := range rc.Headers {
|
||||||
for k, v := range headers {
|
|
||||||
req.Header.Set(k, v)
|
req.Header.Set(k, v)
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
req.URL.Host = req.Host
|
||||||
|
}
|
||||||
|
|
||||||
},
|
},
|
||||||
Transport: &http.Transport{
|
Transport: &http.Transport{
|
||||||
ResponseHeaderTimeout: rp.responseHeaderTimeout,
|
ResponseHeaderTimeout: rp.responseHeaderTimeout,
|
||||||
DisableKeepAlives: true,
|
IdleConnTimeout: 60 * time.Second,
|
||||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
url := ctx.Value(RouteInfoURL).(string)
|
url := ctx.Value(RouteInfoURL).(string)
|
||||||
host := util.GetHostFromAddr(ctx.Value(RouteInfoHost).(string))
|
host := util.GetHostFromAddr(ctx.Value(RouteInfoHost).(string))
|
||||||
@ -107,6 +113,14 @@ func (rp *HTTPReverseProxy) UnRegister(domain string, location string) {
|
|||||||
rp.vhostRouter.Del(domain, location)
|
rp.vhostRouter.Del(domain, location)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rp *HTTPReverseProxy) GetRouteConfig(domain string, location string) *RouteConfig {
|
||||||
|
vr, ok := rp.getVhost(domain, location)
|
||||||
|
if ok {
|
||||||
|
return vr.payload.(*RouteConfig)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (rp *HTTPReverseProxy) GetRealHost(domain string, location string) (host string) {
|
func (rp *HTTPReverseProxy) GetRealHost(domain string, location string) (host string) {
|
||||||
vr, ok := rp.getVhost(domain, location)
|
vr, ok := rp.getVhost(domain, location)
|
||||||
if ok {
|
if ok {
|
||||||
|
@ -139,6 +139,7 @@ func TestHealthCheck(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
httpSvc3 := mock.NewHTTPServer(15005, func(w http.ResponseWriter, r *http.Request) {
|
httpSvc3 := mock.NewHTTPServer(15005, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(time.Second)
|
||||||
w.Write([]byte("http3"))
|
w.Write([]byte("http3"))
|
||||||
})
|
})
|
||||||
err = httpSvc3.Start()
|
err = httpSvc3.Start()
|
||||||
@ -147,6 +148,7 @@ func TestHealthCheck(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
httpSvc4 := mock.NewHTTPServer(15006, func(w http.ResponseWriter, r *http.Request) {
|
httpSvc4 := mock.NewHTTPServer(15006, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(time.Second)
|
||||||
w.Write([]byte("http4"))
|
w.Write([]byte("http4"))
|
||||||
})
|
})
|
||||||
err = httpSvc4.Start()
|
err = httpSvc4.Start()
|
||||||
@ -277,16 +279,30 @@ func TestHealthCheck(t *testing.T) {
|
|||||||
|
|
||||||
// ****** load balancing type http ******
|
// ****** load balancing type http ******
|
||||||
result = make([]string, 0)
|
result = make([]string, 0)
|
||||||
|
var wait sync.WaitGroup
|
||||||
|
var mu sync.Mutex
|
||||||
|
wait.Add(2)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer wait.Done()
|
||||||
|
code, body, _, err := util.SendHTTPMsg("GET", "http://127.0.0.1:14000/xxx", "test.balancing.com", nil, "")
|
||||||
|
assert.NoError(err)
|
||||||
|
assert.Equal(200, code)
|
||||||
|
mu.Lock()
|
||||||
|
result = append(result, body)
|
||||||
|
mu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer wait.Done()
|
||||||
code, body, _, err = util.SendHTTPMsg("GET", "http://127.0.0.1:14000/xxx", "test.balancing.com", nil, "")
|
code, body, _, err = util.SendHTTPMsg("GET", "http://127.0.0.1:14000/xxx", "test.balancing.com", nil, "")
|
||||||
assert.NoError(err)
|
assert.NoError(err)
|
||||||
assert.Equal(200, code)
|
assert.Equal(200, code)
|
||||||
|
mu.Lock()
|
||||||
result = append(result, body)
|
result = append(result, body)
|
||||||
|
mu.Unlock()
|
||||||
code, body, _, err = util.SendHTTPMsg("GET", "http://127.0.0.1:14000/xxx", "test.balancing.com", nil, "")
|
}()
|
||||||
assert.NoError(err)
|
wait.Wait()
|
||||||
assert.Equal(200, code)
|
|
||||||
result = append(result, body)
|
|
||||||
|
|
||||||
assert.Contains(result, "http3")
|
assert.Contains(result, "http3")
|
||||||
assert.Contains(result, "http4")
|
assert.Contains(result, "http4")
|
||||||
|
Loading…
Reference in New Issue
Block a user