client/proxy: simplify the code (#3465)

This commit is contained in:
fatedier 2023-05-30 22:18:56 +08:00 committed by GitHub
parent 9aef3b9944
commit cceab7e1b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 147 additions and 224 deletions

View File

@ -0,0 +1,47 @@
// Copyright 2023 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"reflect"
"github.com/fatedier/frp/pkg/config"
)
func init() {
pxyConfs := []config.ProxyConf{
&config.TCPProxyConf{},
&config.HTTPProxyConf{},
&config.HTTPSProxyConf{},
&config.STCPProxyConf{},
&config.TCPMuxProxyConf{},
}
for _, cfg := range pxyConfs {
RegisterProxyFactory(reflect.TypeOf(cfg), NewGeneralTCPProxy)
}
}
// GeneralTCPProxy is a general implementation of Proxy interface for TCP protocol.
// If the default GeneralTCPProxy cannot meet the requirements, you can customize
// the implementation of the Proxy interface.
type GeneralTCPProxy struct {
*BaseProxy
}
func NewGeneralTCPProxy(baseProxy *BaseProxy, cfg config.ProxyConf) Proxy {
return &GeneralTCPProxy{
BaseProxy: baseProxy,
}
}

View File

@ -19,6 +19,7 @@ import (
"context" "context"
"io" "io"
"net" "net"
"reflect"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -37,6 +38,12 @@ import (
"github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/util/xlog"
) )
var proxyFactoryRegistry = map[reflect.Type]func(*BaseProxy, config.ProxyConf) Proxy{}
func RegisterProxyFactory(proxyConfType reflect.Type, factory func(*BaseProxy, config.ProxyConf) Proxy) {
proxyFactoryRegistry[proxyConfType] = factory
}
// Proxy defines how to handle work connections for different proxy type. // Proxy defines how to handle work connections for different proxy type.
type Proxy interface { type Proxy interface {
Run() error Run() error
@ -60,233 +67,74 @@ func NewProxy(
} }
baseProxy := BaseProxy{ baseProxy := BaseProxy{
clientCfg: clientCfg, baseProxyConfig: pxyConf.GetBaseConfig(),
limiter: limiter, clientCfg: clientCfg,
msgTransporter: msgTransporter, limiter: limiter,
xl: xlog.FromContextSafe(ctx), msgTransporter: msgTransporter,
ctx: ctx, xl: xlog.FromContextSafe(ctx),
ctx: ctx,
} }
switch cfg := pxyConf.(type) {
case *config.TCPProxyConf: factory := proxyFactoryRegistry[reflect.TypeOf(pxyConf)]
pxy = &TCPProxy{ if factory == nil {
BaseProxy: &baseProxy, return nil
cfg: cfg,
}
case *config.TCPMuxProxyConf:
pxy = &TCPMuxProxy{
BaseProxy: &baseProxy,
cfg: cfg,
}
case *config.UDPProxyConf:
pxy = &UDPProxy{
BaseProxy: &baseProxy,
cfg: cfg,
}
case *config.HTTPProxyConf:
pxy = &HTTPProxy{
BaseProxy: &baseProxy,
cfg: cfg,
}
case *config.HTTPSProxyConf:
pxy = &HTTPSProxy{
BaseProxy: &baseProxy,
cfg: cfg,
}
case *config.STCPProxyConf:
pxy = &STCPProxy{
BaseProxy: &baseProxy,
cfg: cfg,
}
case *config.XTCPProxyConf:
pxy = &XTCPProxy{
BaseProxy: &baseProxy,
cfg: cfg,
}
case *config.SUDPProxyConf:
pxy = &SUDPProxy{
BaseProxy: &baseProxy,
cfg: cfg,
closeCh: make(chan struct{}),
}
} }
return return factory(&baseProxy, pxyConf)
} }
type BaseProxy struct { type BaseProxy struct {
closed bool baseProxyConfig *config.BaseProxyConf
clientCfg config.ClientCommonConf clientCfg config.ClientCommonConf
msgTransporter transport.MessageTransporter msgTransporter transport.MessageTransporter
limiter *rate.Limiter limiter *rate.Limiter
// proxyPlugin is used to handle connections instead of dialing to local service.
// It's only validate for TCP protocol now.
proxyPlugin plugin.Plugin
mu sync.RWMutex mu sync.RWMutex
xl *xlog.Logger xl *xlog.Logger
ctx context.Context ctx context.Context
} }
// TCP func (pxy *BaseProxy) Run() error {
type TCPProxy struct { if pxy.baseProxyConfig.Plugin != "" {
*BaseProxy p, err := plugin.Create(pxy.baseProxyConfig.Plugin, pxy.baseProxyConfig.PluginParams)
cfg *config.TCPProxyConf
proxyPlugin plugin.Plugin
}
func (pxy *TCPProxy) Run() (err error) {
if pxy.cfg.Plugin != "" {
pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams)
if err != nil { if err != nil {
return return err
} }
pxy.proxyPlugin = p
} }
return return nil
} }
func (pxy *TCPProxy) Close() { func (pxy *BaseProxy) Close() {
if pxy.proxyPlugin != nil { if pxy.proxyPlugin != nil {
pxy.proxyPlugin.Close() pxy.proxyPlugin.Close()
} }
} }
func (pxy *TCPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { func (pxy *BaseProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter, pxy.HandleTCPWorkConnection(conn, m, []byte(pxy.clientCfg.Token))
conn, []byte(pxy.clientCfg.Token), m)
}
// TCP Multiplexer
type TCPMuxProxy struct {
*BaseProxy
cfg *config.TCPMuxProxyConf
proxyPlugin plugin.Plugin
}
func (pxy *TCPMuxProxy) Run() (err error) {
if pxy.cfg.Plugin != "" {
pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams)
if err != nil {
return
}
}
return
}
func (pxy *TCPMuxProxy) Close() {
if pxy.proxyPlugin != nil {
pxy.proxyPlugin.Close()
}
}
func (pxy *TCPMuxProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter,
conn, []byte(pxy.clientCfg.Token), m)
}
// HTTP
type HTTPProxy struct {
*BaseProxy
cfg *config.HTTPProxyConf
proxyPlugin plugin.Plugin
}
func (pxy *HTTPProxy) Run() (err error) {
if pxy.cfg.Plugin != "" {
pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams)
if err != nil {
return
}
}
return
}
func (pxy *HTTPProxy) Close() {
if pxy.proxyPlugin != nil {
pxy.proxyPlugin.Close()
}
}
func (pxy *HTTPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter,
conn, []byte(pxy.clientCfg.Token), m)
}
// HTTPS
type HTTPSProxy struct {
*BaseProxy
cfg *config.HTTPSProxyConf
proxyPlugin plugin.Plugin
}
func (pxy *HTTPSProxy) Run() (err error) {
if pxy.cfg.Plugin != "" {
pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams)
if err != nil {
return
}
}
return
}
func (pxy *HTTPSProxy) Close() {
if pxy.proxyPlugin != nil {
pxy.proxyPlugin.Close()
}
}
func (pxy *HTTPSProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter,
conn, []byte(pxy.clientCfg.Token), m)
}
// STCP
type STCPProxy struct {
*BaseProxy
cfg *config.STCPProxyConf
proxyPlugin plugin.Plugin
}
func (pxy *STCPProxy) Run() (err error) {
if pxy.cfg.Plugin != "" {
pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams)
if err != nil {
return
}
}
return
}
func (pxy *STCPProxy) Close() {
if pxy.proxyPlugin != nil {
pxy.proxyPlugin.Close()
}
}
func (pxy *STCPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter,
conn, []byte(pxy.clientCfg.Token), m)
} }
// Common handler for tcp work connections. // Common handler for tcp work connections.
func HandleTCPWorkConnection(ctx context.Context, localInfo *config.LocalSvrConf, proxyPlugin plugin.Plugin, func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWorkConn, encKey []byte) {
baseInfo *config.BaseProxyConf, limiter *rate.Limiter, workConn net.Conn, encKey []byte, m *msg.StartWorkConn, xl := pxy.xl
) { baseConfig := pxy.baseProxyConfig
xl := xlog.FromContextSafe(ctx)
var ( var (
remote io.ReadWriteCloser remote io.ReadWriteCloser
err error err error
) )
remote = workConn remote = workConn
if limiter != nil { if pxy.limiter != nil {
remote = libio.WrapReadWriteCloser(limit.NewReader(workConn, limiter), limit.NewWriter(workConn, limiter), func() error { remote = libio.WrapReadWriteCloser(limit.NewReader(workConn, pxy.limiter), limit.NewWriter(workConn, pxy.limiter), func() error {
return workConn.Close() return workConn.Close()
}) })
} }
xl.Trace("handle tcp work connection, use_encryption: %t, use_compression: %t", xl.Trace("handle tcp work connection, use_encryption: %t, use_compression: %t",
baseInfo.UseEncryption, baseInfo.UseCompression) baseConfig.UseEncryption, baseConfig.UseCompression)
if baseInfo.UseEncryption { if baseConfig.UseEncryption {
remote, err = libio.WithEncryption(remote, encKey) remote, err = libio.WithEncryption(remote, encKey)
if err != nil { if err != nil {
workConn.Close() workConn.Close()
@ -294,13 +142,13 @@ func HandleTCPWorkConnection(ctx context.Context, localInfo *config.LocalSvrConf
return return
} }
} }
if baseInfo.UseCompression { if baseConfig.UseCompression {
remote = libio.WithCompression(remote) remote = libio.WithCompression(remote)
} }
// check if we need to send proxy protocol info // check if we need to send proxy protocol info
var extraInfo []byte var extraInfo []byte
if baseInfo.ProxyProtocolVersion != "" { if baseConfig.ProxyProtocolVersion != "" {
if m.SrcAddr != "" && m.SrcPort != 0 { if m.SrcAddr != "" && m.SrcPort != 0 {
if m.DstAddr == "" { if m.DstAddr == "" {
m.DstAddr = "127.0.0.1" m.DstAddr = "127.0.0.1"
@ -319,9 +167,9 @@ func HandleTCPWorkConnection(ctx context.Context, localInfo *config.LocalSvrConf
h.TransportProtocol = pp.TCPv6 h.TransportProtocol = pp.TCPv6
} }
if baseInfo.ProxyProtocolVersion == "v1" { if baseConfig.ProxyProtocolVersion == "v1" {
h.Version = 1 h.Version = 1
} else if baseInfo.ProxyProtocolVersion == "v2" { } else if baseConfig.ProxyProtocolVersion == "v2" {
h.Version = 2 h.Version = 2
} }
@ -331,21 +179,21 @@ func HandleTCPWorkConnection(ctx context.Context, localInfo *config.LocalSvrConf
} }
} }
if proxyPlugin != nil { if pxy.proxyPlugin != nil {
// if plugin is set, let plugin handle connections first // if plugin is set, let plugin handle connection first
xl.Debug("handle by plugin: %s", proxyPlugin.Name()) xl.Debug("handle by plugin: %s", pxy.proxyPlugin.Name())
proxyPlugin.Handle(remote, workConn, extraInfo) pxy.proxyPlugin.Handle(remote, workConn, extraInfo)
xl.Debug("handle by plugin finished") xl.Debug("handle by plugin finished")
return return
} }
localConn, err := libdial.Dial( localConn, err := libdial.Dial(
net.JoinHostPort(localInfo.LocalIP, strconv.Itoa(localInfo.LocalPort)), net.JoinHostPort(baseConfig.LocalIP, strconv.Itoa(baseConfig.LocalPort)),
libdial.WithTimeout(10*time.Second), libdial.WithTimeout(10*time.Second),
) )
if err != nil { if err != nil {
workConn.Close() workConn.Close()
xl.Error("connect to local service [%s:%d] error: %v", localInfo.LocalIP, localInfo.LocalPort, err) xl.Error("connect to local service [%s:%d] error: %v", baseConfig.LocalIP, baseConfig.LocalPort, err)
return return
} }

View File

@ -17,6 +17,7 @@ package proxy
import ( import (
"io" "io"
"net" "net"
"reflect"
"strconv" "strconv"
"sync" "sync"
"time" "time"
@ -31,6 +32,10 @@ import (
utilnet "github.com/fatedier/frp/pkg/util/net" utilnet "github.com/fatedier/frp/pkg/util/net"
) )
func init() {
RegisterProxyFactory(reflect.TypeOf(&config.SUDPProxyConf{}), NewSUDPProxy)
}
type SUDPProxy struct { type SUDPProxy struct {
*BaseProxy *BaseProxy
@ -41,6 +46,18 @@ type SUDPProxy struct {
closeCh chan struct{} closeCh chan struct{}
} }
func NewSUDPProxy(baseProxy *BaseProxy, cfg config.ProxyConf) Proxy {
unwrapped, ok := cfg.(*config.SUDPProxyConf)
if !ok {
return nil
}
return &SUDPProxy{
BaseProxy: baseProxy,
cfg: unwrapped,
closeCh: make(chan struct{}),
}
}
func (pxy *SUDPProxy) Run() (err error) { func (pxy *SUDPProxy) Run() (err error) {
pxy.localAddr, err = net.ResolveUDPAddr("udp", net.JoinHostPort(pxy.cfg.LocalIP, strconv.Itoa(pxy.cfg.LocalPort))) pxy.localAddr, err = net.ResolveUDPAddr("udp", net.JoinHostPort(pxy.cfg.LocalIP, strconv.Itoa(pxy.cfg.LocalPort)))
if err != nil { if err != nil {

View File

@ -17,6 +17,7 @@ package proxy
import ( import (
"io" "io"
"net" "net"
"reflect"
"strconv" "strconv"
"time" "time"
@ -30,7 +31,10 @@ import (
utilnet "github.com/fatedier/frp/pkg/util/net" utilnet "github.com/fatedier/frp/pkg/util/net"
) )
// UDP func init() {
RegisterProxyFactory(reflect.TypeOf(&config.UDPProxyConf{}), NewUDPProxy)
}
type UDPProxy struct { type UDPProxy struct {
*BaseProxy *BaseProxy
@ -42,6 +46,18 @@ type UDPProxy struct {
// include msg.UDPPacket and msg.Ping // include msg.UDPPacket and msg.Ping
sendCh chan msg.Message sendCh chan msg.Message
workConn net.Conn workConn net.Conn
closed bool
}
func NewUDPProxy(baseProxy *BaseProxy, cfg config.ProxyConf) Proxy {
unwrapped, ok := cfg.(*config.UDPProxyConf)
if !ok {
return nil
}
return &UDPProxy{
BaseProxy: baseProxy,
cfg: unwrapped,
}
} }
func (pxy *UDPProxy) Run() (err error) { func (pxy *UDPProxy) Run() (err error) {

View File

@ -17,6 +17,7 @@ package proxy
import ( import (
"io" "io"
"net" "net"
"reflect"
"time" "time"
fmux "github.com/hashicorp/yamux" fmux "github.com/hashicorp/yamux"
@ -25,32 +26,28 @@ import (
"github.com/fatedier/frp/pkg/config" "github.com/fatedier/frp/pkg/config"
"github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/nathole" "github.com/fatedier/frp/pkg/nathole"
plugin "github.com/fatedier/frp/pkg/plugin/client"
"github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/transport"
utilnet "github.com/fatedier/frp/pkg/util/net" utilnet "github.com/fatedier/frp/pkg/util/net"
) )
// XTCP func init() {
RegisterProxyFactory(reflect.TypeOf(&config.XTCPProxyConf{}), NewXTCPProxy)
}
type XTCPProxy struct { type XTCPProxy struct {
*BaseProxy *BaseProxy
cfg *config.XTCPProxyConf cfg *config.XTCPProxyConf
proxyPlugin plugin.Plugin
} }
func (pxy *XTCPProxy) Run() (err error) { func NewXTCPProxy(baseProxy *BaseProxy, cfg config.ProxyConf) Proxy {
if pxy.cfg.Plugin != "" { unwrapped, ok := cfg.(*config.XTCPProxyConf)
pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams) if !ok {
if err != nil { return nil
return
}
} }
return return &XTCPProxy{
} BaseProxy: baseProxy,
cfg: unwrapped,
func (pxy *XTCPProxy) Close() {
if pxy.proxyPlugin != nil {
pxy.proxyPlugin.Close()
} }
} }
@ -155,8 +152,7 @@ func (pxy *XTCPProxy) listenByKCP(listenConn *net.UDPConn, raddr *net.UDPAddr, s
xl.Error("accept connection error: %v", err) xl.Error("accept connection error: %v", err)
return return
} }
go HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter, go pxy.HandleTCPWorkConnection(muxConn, startWorkConnMsg, []byte(pxy.cfg.Sk))
muxConn, []byte(pxy.cfg.Sk), startWorkConnMsg)
} }
} }
@ -194,7 +190,6 @@ func (pxy *XTCPProxy) listenByQUIC(listenConn *net.UDPConn, _ *net.UDPAddr, star
_ = c.CloseWithError(0, "") _ = c.CloseWithError(0, "")
return return
} }
go HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter, go pxy.HandleTCPWorkConnection(utilnet.QuicStreamToNetConn(stream, c), startWorkConnMsg, []byte(pxy.cfg.Sk))
utilnet.QuicStreamToNetConn(stream, c), []byte(pxy.cfg.Sk), startWorkConnMsg)
} }
} }