From 6a1f15b25ed3fe5b71b097a187d0975c31445a73 Mon Sep 17 00:00:00 2001 From: fatedier Date: Thu, 25 Apr 2019 12:01:57 +0800 Subject: [PATCH] support proxy protocol in unix_domain_socket --- client/proxy/proxy.go | 65 ++++++++++++++++------------- models/plugin/http_proxy.go | 2 +- models/plugin/https2http.go | 7 +--- models/plugin/plugin.go | 2 +- models/plugin/socks5.go | 2 +- models/plugin/static_file.go | 2 +- models/plugin/unix_domain_socket.go | 5 ++- 7 files changed, 45 insertions(+), 40 deletions(-) diff --git a/client/proxy/proxy.go b/client/proxy/proxy.go index c68fe708..bc230f40 100644 --- a/client/proxy/proxy.go +++ b/client/proxy/proxy.go @@ -503,10 +503,43 @@ func HandleTcpWorkConnection(localInfo *config.LocalSvrConf, proxyPlugin plugin. remote = frpIo.WithCompression(remote) } + // check if we need to send proxy protocol info + var extraInfo []byte + if baseInfo.ProxyProtocolVersion != "" { + if m.SrcAddr != "" && m.SrcPort != 0 { + if m.DstAddr == "" { + m.DstAddr = "127.0.0.1" + } + h := &pp.Header{ + Command: pp.PROXY, + SourceAddress: net.ParseIP(m.SrcAddr), + SourcePort: m.SrcPort, + DestinationAddress: net.ParseIP(m.DstAddr), + DestinationPort: m.DstPort, + } + + if h.SourceAddress.To16() == nil { + h.TransportProtocol = pp.TCPv4 + } else { + h.TransportProtocol = pp.TCPv6 + } + + if baseInfo.ProxyProtocolVersion == "v1" { + h.Version = 1 + } else if baseInfo.ProxyProtocolVersion == "v2" { + h.Version = 2 + } + + buf := bytes.NewBuffer(nil) + h.WriteTo(buf) + extraInfo = buf.Bytes() + } + } + if proxyPlugin != nil { // if plugin is set, let plugin handle connections first workConn.Debug("handle by plugin: %s", proxyPlugin.Name()) - proxyPlugin.Handle(remote, workConn) + proxyPlugin.Handle(remote, workConn, extraInfo) workConn.Debug("handle by plugin finished") return } else { @@ -520,34 +553,8 @@ func HandleTcpWorkConnection(localInfo *config.LocalSvrConf, proxyPlugin plugin. workConn.Debug("join connections, localConn(l[%s] r[%s]) workConn(l[%s] r[%s])", localConn.LocalAddr().String(), localConn.RemoteAddr().String(), workConn.LocalAddr().String(), workConn.RemoteAddr().String()) - // check if we need to send proxy protocol info - if baseInfo.ProxyProtocolVersion != "" { - if m.SrcAddr != "" && m.SrcPort != 0 { - if m.DstAddr == "" { - m.DstAddr = "127.0.0.1" - } - h := &pp.Header{ - Command: pp.PROXY, - SourceAddress: net.ParseIP(m.SrcAddr), - SourcePort: m.SrcPort, - DestinationAddress: net.ParseIP(m.DstAddr), - DestinationPort: m.DstPort, - } - - if h.SourceAddress.To16() == nil { - h.TransportProtocol = pp.TCPv4 - } else { - h.TransportProtocol = pp.TCPv6 - } - - if baseInfo.ProxyProtocolVersion == "v1" { - h.Version = 1 - } else if baseInfo.ProxyProtocolVersion == "v2" { - h.Version = 2 - } - - h.WriteTo(localConn) - } + if len(extraInfo) > 0 { + localConn.Write(extraInfo) } frpIo.Join(localConn, remote) diff --git a/models/plugin/http_proxy.go b/models/plugin/http_proxy.go index a9ff6ef7..3afa2cb8 100644 --- a/models/plugin/http_proxy.go +++ b/models/plugin/http_proxy.go @@ -64,7 +64,7 @@ func (hp *HttpProxy) Name() string { return PluginHttpProxy } -func (hp *HttpProxy) Handle(conn io.ReadWriteCloser, realConn frpNet.Conn) { +func (hp *HttpProxy) Handle(conn io.ReadWriteCloser, realConn frpNet.Conn, extraBufToLocal []byte) { wrapConn := frpNet.WrapReadWriteCloserToConn(conn, realConn) sc, rd := gnet.NewSharedConn(wrapConn) diff --git a/models/plugin/https2http.go b/models/plugin/https2http.go index 746995fe..6e84ad62 100644 --- a/models/plugin/https2http.go +++ b/models/plugin/https2http.go @@ -100,16 +100,11 @@ func (p *HTTPS2HTTPPlugin) genTLSConfig() (*tls.Config, error) { return config, nil } -func (p *HTTPS2HTTPPlugin) Handle(conn io.ReadWriteCloser, realConn frpNet.Conn) { +func (p *HTTPS2HTTPPlugin) Handle(conn io.ReadWriteCloser, realConn frpNet.Conn, extraBufToLocal []byte) { wrapConn := frpNet.WrapReadWriteCloserToConn(conn, realConn) p.l.PutConn(wrapConn) } -func (p *HTTPS2HTTPPlugin) handleRequest(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("hello")) - return -} - func (p *HTTPS2HTTPPlugin) Name() string { return PluginHTTPS2HTTP } diff --git a/models/plugin/plugin.go b/models/plugin/plugin.go index 653e48a2..cfad5510 100644 --- a/models/plugin/plugin.go +++ b/models/plugin/plugin.go @@ -46,7 +46,7 @@ func Create(name string, params map[string]string) (p Plugin, err error) { type Plugin interface { Name() string - Handle(conn io.ReadWriteCloser, realConn frpNet.Conn) + Handle(conn io.ReadWriteCloser, realConn frpNet.Conn, extraBufToLocal []byte) Close() error } diff --git a/models/plugin/socks5.go b/models/plugin/socks5.go index fba9f5df..447602a9 100644 --- a/models/plugin/socks5.go +++ b/models/plugin/socks5.go @@ -53,7 +53,7 @@ func NewSocks5Plugin(params map[string]string) (p Plugin, err error) { return } -func (sp *Socks5Plugin) Handle(conn io.ReadWriteCloser, realConn frpNet.Conn) { +func (sp *Socks5Plugin) Handle(conn io.ReadWriteCloser, realConn frpNet.Conn, extraBufToLocal []byte) { defer conn.Close() wrapConn := frpNet.WrapReadWriteCloserToConn(conn, realConn) sp.Server.ServeConn(wrapConn) diff --git a/models/plugin/static_file.go b/models/plugin/static_file.go index 52b0c0c6..080ff74f 100644 --- a/models/plugin/static_file.go +++ b/models/plugin/static_file.go @@ -72,7 +72,7 @@ func NewStaticFilePlugin(params map[string]string) (Plugin, error) { return sp, nil } -func (sp *StaticFilePlugin) Handle(conn io.ReadWriteCloser, realConn frpNet.Conn) { +func (sp *StaticFilePlugin) Handle(conn io.ReadWriteCloser, realConn frpNet.Conn, extraBufToLocal []byte) { wrapConn := frpNet.WrapReadWriteCloserToConn(conn, realConn) sp.l.PutConn(wrapConn) } diff --git a/models/plugin/unix_domain_socket.go b/models/plugin/unix_domain_socket.go index b1ce6226..86833e25 100644 --- a/models/plugin/unix_domain_socket.go +++ b/models/plugin/unix_domain_socket.go @@ -53,11 +53,14 @@ func NewUnixDomainSocketPlugin(params map[string]string) (p Plugin, err error) { return } -func (uds *UnixDomainSocketPlugin) Handle(conn io.ReadWriteCloser, realConn frpNet.Conn) { +func (uds *UnixDomainSocketPlugin) Handle(conn io.ReadWriteCloser, realConn frpNet.Conn, extraBufToLocal []byte) { localConn, err := net.DialUnix("unix", nil, uds.UnixAddr) if err != nil { return } + if len(extraBufToLocal) > 0 { + localConn.Write(extraBufToLocal) + } frpIo.Join(localConn, conn) }