diff --git a/client/proxy/general_tcp.go b/client/proxy/general_tcp.go new file mode 100644 index 00000000..7efe476f --- /dev/null +++ b/client/proxy/general_tcp.go @@ -0,0 +1,47 @@ +// Copyright 2023 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "reflect" + + "github.com/fatedier/frp/pkg/config" +) + +func init() { + pxyConfs := []config.ProxyConf{ + &config.TCPProxyConf{}, + &config.HTTPProxyConf{}, + &config.HTTPSProxyConf{}, + &config.STCPProxyConf{}, + &config.TCPMuxProxyConf{}, + } + for _, cfg := range pxyConfs { + RegisterProxyFactory(reflect.TypeOf(cfg), NewGeneralTCPProxy) + } +} + +// GeneralTCPProxy is a general implementation of Proxy interface for TCP protocol. +// If the default GeneralTCPProxy cannot meet the requirements, you can customize +// the implementation of the Proxy interface. +type GeneralTCPProxy struct { + *BaseProxy +} + +func NewGeneralTCPProxy(baseProxy *BaseProxy, cfg config.ProxyConf) Proxy { + return &GeneralTCPProxy{ + BaseProxy: baseProxy, + } +} diff --git a/client/proxy/proxy.go b/client/proxy/proxy.go index b336e13a..82d5d9f2 100644 --- a/client/proxy/proxy.go +++ b/client/proxy/proxy.go @@ -19,6 +19,7 @@ import ( "context" "io" "net" + "reflect" "strconv" "strings" "sync" @@ -37,6 +38,12 @@ import ( "github.com/fatedier/frp/pkg/util/xlog" ) +var proxyFactoryRegistry = map[reflect.Type]func(*BaseProxy, config.ProxyConf) Proxy{} + +func RegisterProxyFactory(proxyConfType reflect.Type, factory func(*BaseProxy, config.ProxyConf) Proxy) { + proxyFactoryRegistry[proxyConfType] = factory +} + // Proxy defines how to handle work connections for different proxy type. type Proxy interface { Run() error @@ -60,233 +67,74 @@ func NewProxy( } baseProxy := BaseProxy{ - clientCfg: clientCfg, - limiter: limiter, - msgTransporter: msgTransporter, - xl: xlog.FromContextSafe(ctx), - ctx: ctx, + baseProxyConfig: pxyConf.GetBaseConfig(), + clientCfg: clientCfg, + limiter: limiter, + msgTransporter: msgTransporter, + xl: xlog.FromContextSafe(ctx), + ctx: ctx, } - switch cfg := pxyConf.(type) { - case *config.TCPProxyConf: - pxy = &TCPProxy{ - BaseProxy: &baseProxy, - cfg: cfg, - } - case *config.TCPMuxProxyConf: - pxy = &TCPMuxProxy{ - BaseProxy: &baseProxy, - cfg: cfg, - } - case *config.UDPProxyConf: - pxy = &UDPProxy{ - BaseProxy: &baseProxy, - cfg: cfg, - } - case *config.HTTPProxyConf: - pxy = &HTTPProxy{ - BaseProxy: &baseProxy, - cfg: cfg, - } - case *config.HTTPSProxyConf: - pxy = &HTTPSProxy{ - BaseProxy: &baseProxy, - cfg: cfg, - } - case *config.STCPProxyConf: - pxy = &STCPProxy{ - BaseProxy: &baseProxy, - cfg: cfg, - } - case *config.XTCPProxyConf: - pxy = &XTCPProxy{ - BaseProxy: &baseProxy, - cfg: cfg, - } - case *config.SUDPProxyConf: - pxy = &SUDPProxy{ - BaseProxy: &baseProxy, - cfg: cfg, - closeCh: make(chan struct{}), - } + + factory := proxyFactoryRegistry[reflect.TypeOf(pxyConf)] + if factory == nil { + return nil } - return + return factory(&baseProxy, pxyConf) } type BaseProxy struct { - closed bool - clientCfg config.ClientCommonConf - msgTransporter transport.MessageTransporter - limiter *rate.Limiter + baseProxyConfig *config.BaseProxyConf + clientCfg config.ClientCommonConf + msgTransporter transport.MessageTransporter + limiter *rate.Limiter + // proxyPlugin is used to handle connections instead of dialing to local service. + // It's only validate for TCP protocol now. + proxyPlugin plugin.Plugin mu sync.RWMutex xl *xlog.Logger ctx context.Context } -// TCP -type TCPProxy struct { - *BaseProxy - - cfg *config.TCPProxyConf - proxyPlugin plugin.Plugin -} - -func (pxy *TCPProxy) Run() (err error) { - if pxy.cfg.Plugin != "" { - pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams) +func (pxy *BaseProxy) Run() error { + if pxy.baseProxyConfig.Plugin != "" { + p, err := plugin.Create(pxy.baseProxyConfig.Plugin, pxy.baseProxyConfig.PluginParams) if err != nil { - return + return err } + pxy.proxyPlugin = p } - return + return nil } -func (pxy *TCPProxy) Close() { +func (pxy *BaseProxy) Close() { if pxy.proxyPlugin != nil { pxy.proxyPlugin.Close() } } -func (pxy *TCPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { - HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter, - conn, []byte(pxy.clientCfg.Token), m) -} - -// TCP Multiplexer -type TCPMuxProxy struct { - *BaseProxy - - cfg *config.TCPMuxProxyConf - proxyPlugin plugin.Plugin -} - -func (pxy *TCPMuxProxy) Run() (err error) { - if pxy.cfg.Plugin != "" { - pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams) - if err != nil { - return - } - } - return -} - -func (pxy *TCPMuxProxy) Close() { - if pxy.proxyPlugin != nil { - pxy.proxyPlugin.Close() - } -} - -func (pxy *TCPMuxProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { - HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter, - conn, []byte(pxy.clientCfg.Token), m) -} - -// HTTP -type HTTPProxy struct { - *BaseProxy - - cfg *config.HTTPProxyConf - proxyPlugin plugin.Plugin -} - -func (pxy *HTTPProxy) Run() (err error) { - if pxy.cfg.Plugin != "" { - pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams) - if err != nil { - return - } - } - return -} - -func (pxy *HTTPProxy) Close() { - if pxy.proxyPlugin != nil { - pxy.proxyPlugin.Close() - } -} - -func (pxy *HTTPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { - HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter, - conn, []byte(pxy.clientCfg.Token), m) -} - -// HTTPS -type HTTPSProxy struct { - *BaseProxy - - cfg *config.HTTPSProxyConf - proxyPlugin plugin.Plugin -} - -func (pxy *HTTPSProxy) Run() (err error) { - if pxy.cfg.Plugin != "" { - pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams) - if err != nil { - return - } - } - return -} - -func (pxy *HTTPSProxy) Close() { - if pxy.proxyPlugin != nil { - pxy.proxyPlugin.Close() - } -} - -func (pxy *HTTPSProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { - HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter, - conn, []byte(pxy.clientCfg.Token), m) -} - -// STCP -type STCPProxy struct { - *BaseProxy - - cfg *config.STCPProxyConf - proxyPlugin plugin.Plugin -} - -func (pxy *STCPProxy) Run() (err error) { - if pxy.cfg.Plugin != "" { - pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams) - if err != nil { - return - } - } - return -} - -func (pxy *STCPProxy) Close() { - if pxy.proxyPlugin != nil { - pxy.proxyPlugin.Close() - } -} - -func (pxy *STCPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { - HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter, - conn, []byte(pxy.clientCfg.Token), m) +func (pxy *BaseProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { + pxy.HandleTCPWorkConnection(conn, m, []byte(pxy.clientCfg.Token)) } // Common handler for tcp work connections. -func HandleTCPWorkConnection(ctx context.Context, localInfo *config.LocalSvrConf, proxyPlugin plugin.Plugin, - baseInfo *config.BaseProxyConf, limiter *rate.Limiter, workConn net.Conn, encKey []byte, m *msg.StartWorkConn, -) { - xl := xlog.FromContextSafe(ctx) +func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWorkConn, encKey []byte) { + xl := pxy.xl + baseConfig := pxy.baseProxyConfig var ( remote io.ReadWriteCloser err error ) remote = workConn - if limiter != nil { - remote = libio.WrapReadWriteCloser(limit.NewReader(workConn, limiter), limit.NewWriter(workConn, limiter), func() error { + if pxy.limiter != nil { + remote = libio.WrapReadWriteCloser(limit.NewReader(workConn, pxy.limiter), limit.NewWriter(workConn, pxy.limiter), func() error { return workConn.Close() }) } xl.Trace("handle tcp work connection, use_encryption: %t, use_compression: %t", - baseInfo.UseEncryption, baseInfo.UseCompression) - if baseInfo.UseEncryption { + baseConfig.UseEncryption, baseConfig.UseCompression) + if baseConfig.UseEncryption { remote, err = libio.WithEncryption(remote, encKey) if err != nil { workConn.Close() @@ -294,13 +142,13 @@ func HandleTCPWorkConnection(ctx context.Context, localInfo *config.LocalSvrConf return } } - if baseInfo.UseCompression { + if baseConfig.UseCompression { remote = libio.WithCompression(remote) } // check if we need to send proxy protocol info var extraInfo []byte - if baseInfo.ProxyProtocolVersion != "" { + if baseConfig.ProxyProtocolVersion != "" { if m.SrcAddr != "" && m.SrcPort != 0 { if m.DstAddr == "" { m.DstAddr = "127.0.0.1" @@ -319,9 +167,9 @@ func HandleTCPWorkConnection(ctx context.Context, localInfo *config.LocalSvrConf h.TransportProtocol = pp.TCPv6 } - if baseInfo.ProxyProtocolVersion == "v1" { + if baseConfig.ProxyProtocolVersion == "v1" { h.Version = 1 - } else if baseInfo.ProxyProtocolVersion == "v2" { + } else if baseConfig.ProxyProtocolVersion == "v2" { h.Version = 2 } @@ -331,21 +179,21 @@ func HandleTCPWorkConnection(ctx context.Context, localInfo *config.LocalSvrConf } } - if proxyPlugin != nil { - // if plugin is set, let plugin handle connections first - xl.Debug("handle by plugin: %s", proxyPlugin.Name()) - proxyPlugin.Handle(remote, workConn, extraInfo) + if pxy.proxyPlugin != nil { + // if plugin is set, let plugin handle connection first + xl.Debug("handle by plugin: %s", pxy.proxyPlugin.Name()) + pxy.proxyPlugin.Handle(remote, workConn, extraInfo) xl.Debug("handle by plugin finished") return } localConn, err := libdial.Dial( - net.JoinHostPort(localInfo.LocalIP, strconv.Itoa(localInfo.LocalPort)), + net.JoinHostPort(baseConfig.LocalIP, strconv.Itoa(baseConfig.LocalPort)), libdial.WithTimeout(10*time.Second), ) if err != nil { workConn.Close() - xl.Error("connect to local service [%s:%d] error: %v", localInfo.LocalIP, localInfo.LocalPort, err) + xl.Error("connect to local service [%s:%d] error: %v", baseConfig.LocalIP, baseConfig.LocalPort, err) return } diff --git a/client/proxy/sudp.go b/client/proxy/sudp.go index ff88bf74..9d61e228 100644 --- a/client/proxy/sudp.go +++ b/client/proxy/sudp.go @@ -17,6 +17,7 @@ package proxy import ( "io" "net" + "reflect" "strconv" "sync" "time" @@ -31,6 +32,10 @@ import ( utilnet "github.com/fatedier/frp/pkg/util/net" ) +func init() { + RegisterProxyFactory(reflect.TypeOf(&config.SUDPProxyConf{}), NewSUDPProxy) +} + type SUDPProxy struct { *BaseProxy @@ -41,6 +46,18 @@ type SUDPProxy struct { closeCh chan struct{} } +func NewSUDPProxy(baseProxy *BaseProxy, cfg config.ProxyConf) Proxy { + unwrapped, ok := cfg.(*config.SUDPProxyConf) + if !ok { + return nil + } + return &SUDPProxy{ + BaseProxy: baseProxy, + cfg: unwrapped, + closeCh: make(chan struct{}), + } +} + func (pxy *SUDPProxy) Run() (err error) { pxy.localAddr, err = net.ResolveUDPAddr("udp", net.JoinHostPort(pxy.cfg.LocalIP, strconv.Itoa(pxy.cfg.LocalPort))) if err != nil { diff --git a/client/proxy/udp.go b/client/proxy/udp.go index 0dc11794..adf92766 100644 --- a/client/proxy/udp.go +++ b/client/proxy/udp.go @@ -17,6 +17,7 @@ package proxy import ( "io" "net" + "reflect" "strconv" "time" @@ -30,7 +31,10 @@ import ( utilnet "github.com/fatedier/frp/pkg/util/net" ) -// UDP +func init() { + RegisterProxyFactory(reflect.TypeOf(&config.UDPProxyConf{}), NewUDPProxy) +} + type UDPProxy struct { *BaseProxy @@ -42,6 +46,18 @@ type UDPProxy struct { // include msg.UDPPacket and msg.Ping sendCh chan msg.Message workConn net.Conn + closed bool +} + +func NewUDPProxy(baseProxy *BaseProxy, cfg config.ProxyConf) Proxy { + unwrapped, ok := cfg.(*config.UDPProxyConf) + if !ok { + return nil + } + return &UDPProxy{ + BaseProxy: baseProxy, + cfg: unwrapped, + } } func (pxy *UDPProxy) Run() (err error) { diff --git a/client/proxy/xtcp.go b/client/proxy/xtcp.go index a25dc185..4ba8d50c 100644 --- a/client/proxy/xtcp.go +++ b/client/proxy/xtcp.go @@ -17,6 +17,7 @@ package proxy import ( "io" "net" + "reflect" "time" fmux "github.com/hashicorp/yamux" @@ -25,32 +26,28 @@ import ( "github.com/fatedier/frp/pkg/config" "github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/nathole" - plugin "github.com/fatedier/frp/pkg/plugin/client" "github.com/fatedier/frp/pkg/transport" utilnet "github.com/fatedier/frp/pkg/util/net" ) -// XTCP +func init() { + RegisterProxyFactory(reflect.TypeOf(&config.XTCPProxyConf{}), NewXTCPProxy) +} + type XTCPProxy struct { *BaseProxy - cfg *config.XTCPProxyConf - proxyPlugin plugin.Plugin + cfg *config.XTCPProxyConf } -func (pxy *XTCPProxy) Run() (err error) { - if pxy.cfg.Plugin != "" { - pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams) - if err != nil { - return - } +func NewXTCPProxy(baseProxy *BaseProxy, cfg config.ProxyConf) Proxy { + unwrapped, ok := cfg.(*config.XTCPProxyConf) + if !ok { + return nil } - return -} - -func (pxy *XTCPProxy) Close() { - if pxy.proxyPlugin != nil { - pxy.proxyPlugin.Close() + return &XTCPProxy{ + BaseProxy: baseProxy, + cfg: unwrapped, } } @@ -155,8 +152,7 @@ func (pxy *XTCPProxy) listenByKCP(listenConn *net.UDPConn, raddr *net.UDPAddr, s xl.Error("accept connection error: %v", err) return } - go HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter, - muxConn, []byte(pxy.cfg.Sk), startWorkConnMsg) + go pxy.HandleTCPWorkConnection(muxConn, startWorkConnMsg, []byte(pxy.cfg.Sk)) } } @@ -194,7 +190,6 @@ func (pxy *XTCPProxy) listenByQUIC(listenConn *net.UDPConn, _ *net.UDPAddr, star _ = c.CloseWithError(0, "") return } - go HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter, - utilnet.QuicStreamToNetConn(stream, c), []byte(pxy.cfg.Sk), startWorkConnMsg) + go pxy.HandleTCPWorkConnection(utilnet.QuicStreamToNetConn(stream, c), startWorkConnMsg, []byte(pxy.cfg.Sk)) } }