2024-05-07 12:28:30 +02:00
|
|
|
//go:build !ios
|
2024-04-08 18:56:52 +02:00
|
|
|
|
|
|
|
package net
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"fmt"
|
|
|
|
"net"
|
|
|
|
"sync"
|
|
|
|
|
|
|
|
"github.com/hashicorp/go-multierror"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
)
|
|
|
|
|
|
|
|
type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error
|
|
|
|
type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error
|
|
|
|
|
|
|
|
var (
|
|
|
|
dialerDialHooksMutex sync.RWMutex
|
|
|
|
dialerDialHooks []DialerDialHookFunc
|
|
|
|
dialerCloseHooksMutex sync.RWMutex
|
|
|
|
dialerCloseHooks []DialerCloseHookFunc
|
|
|
|
)
|
|
|
|
|
|
|
|
// AddDialerHook allows adding a new hook to be executed before dialing.
|
|
|
|
func AddDialerHook(hook DialerDialHookFunc) {
|
|
|
|
dialerDialHooksMutex.Lock()
|
|
|
|
defer dialerDialHooksMutex.Unlock()
|
|
|
|
dialerDialHooks = append(dialerDialHooks, hook)
|
|
|
|
}
|
|
|
|
|
|
|
|
// AddDialerCloseHook allows adding a new hook to be executed on connection close.
|
|
|
|
func AddDialerCloseHook(hook DialerCloseHookFunc) {
|
|
|
|
dialerCloseHooksMutex.Lock()
|
|
|
|
defer dialerCloseHooksMutex.Unlock()
|
|
|
|
dialerCloseHooks = append(dialerCloseHooks, hook)
|
|
|
|
}
|
|
|
|
|
2024-05-07 12:28:30 +02:00
|
|
|
// RemoveDialerHooks removes all dialer hooks.
|
2024-04-08 18:56:52 +02:00
|
|
|
func RemoveDialerHooks() {
|
|
|
|
dialerDialHooksMutex.Lock()
|
|
|
|
defer dialerDialHooksMutex.Unlock()
|
|
|
|
dialerDialHooks = nil
|
|
|
|
|
|
|
|
dialerCloseHooksMutex.Lock()
|
|
|
|
defer dialerCloseHooksMutex.Unlock()
|
|
|
|
dialerCloseHooks = nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// DialContext wraps the net.Dialer's DialContext method to use the custom connection
|
|
|
|
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
2024-04-12 16:53:11 +02:00
|
|
|
if CustomRoutingDisabled() {
|
|
|
|
return d.Dialer.DialContext(ctx, network, address)
|
|
|
|
}
|
|
|
|
|
2024-04-08 18:56:52 +02:00
|
|
|
var resolver *net.Resolver
|
|
|
|
if d.Resolver != nil {
|
|
|
|
resolver = d.Resolver
|
|
|
|
}
|
|
|
|
|
|
|
|
connID := GenerateConnID()
|
|
|
|
if dialerDialHooks != nil {
|
2024-04-09 13:25:14 +02:00
|
|
|
if err := callDialerHooks(ctx, connID, address, resolver); err != nil {
|
2024-04-08 18:56:52 +02:00
|
|
|
log.Errorf("Failed to call dialer hooks: %v", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
conn, err := d.Dialer.DialContext(ctx, network, address)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("dial: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Wrap the connection in Conn to handle Close with hooks
|
|
|
|
return &Conn{Conn: conn, ID: connID}, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// Dial wraps the net.Dialer's Dial method to use the custom connection
|
|
|
|
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
|
|
|
|
return d.DialContext(context.Background(), network, address)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Conn wraps a net.Conn to override the Close method
|
|
|
|
type Conn struct {
|
|
|
|
net.Conn
|
|
|
|
ID ConnectionID
|
|
|
|
}
|
|
|
|
|
|
|
|
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
|
|
|
|
func (c *Conn) Close() error {
|
|
|
|
err := c.Conn.Close()
|
|
|
|
|
|
|
|
dialerCloseHooksMutex.RLock()
|
|
|
|
defer dialerCloseHooksMutex.RUnlock()
|
|
|
|
|
|
|
|
for _, hook := range dialerCloseHooks {
|
|
|
|
if err := hook(c.ID, &c.Conn); err != nil {
|
|
|
|
log.Errorf("Error executing dialer close hook: %v", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2024-04-09 13:25:14 +02:00
|
|
|
func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error {
|
2024-04-08 18:56:52 +02:00
|
|
|
host, _, err := net.SplitHostPort(address)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("split host and port: %w", err)
|
|
|
|
}
|
|
|
|
ips, err := resolver.LookupIPAddr(ctx, host)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("failed to resolve address %s: %w", address, err)
|
|
|
|
}
|
|
|
|
|
|
|
|
log.Debugf("Dialer resolved IPs for %s: %v", address, ips)
|
|
|
|
|
|
|
|
var result *multierror.Error
|
|
|
|
|
|
|
|
dialerDialHooksMutex.RLock()
|
|
|
|
defer dialerDialHooksMutex.RUnlock()
|
|
|
|
for _, hook := range dialerDialHooks {
|
|
|
|
if err := hook(ctx, connID, ips); err != nil {
|
|
|
|
result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return result.ErrorOrNil()
|
|
|
|
}
|
|
|
|
|
|
|
|
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
2024-04-12 16:53:11 +02:00
|
|
|
if CustomRoutingDisabled() {
|
|
|
|
return net.DialUDP(network, laddr, raddr)
|
|
|
|
}
|
|
|
|
|
2024-04-08 18:56:52 +02:00
|
|
|
dialer := NewDialer()
|
|
|
|
dialer.LocalAddr = laddr
|
|
|
|
|
|
|
|
conn, err := dialer.Dial(network, raddr.String())
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
|
|
|
|
}
|
|
|
|
|
|
|
|
udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn)
|
|
|
|
if !ok {
|
|
|
|
if err := conn.Close(); err != nil {
|
|
|
|
log.Errorf("Failed to close connection: %v", err)
|
|
|
|
}
|
|
|
|
return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn)
|
|
|
|
}
|
|
|
|
|
|
|
|
return udpConn, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
|
2024-04-12 16:53:11 +02:00
|
|
|
if CustomRoutingDisabled() {
|
|
|
|
return net.DialTCP(network, laddr, raddr)
|
|
|
|
}
|
|
|
|
|
2024-04-08 18:56:52 +02:00
|
|
|
dialer := NewDialer()
|
|
|
|
dialer.LocalAddr = laddr
|
|
|
|
|
|
|
|
conn, err := dialer.Dial(network, raddr.String())
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
|
|
|
|
}
|
|
|
|
|
|
|
|
tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn)
|
|
|
|
if !ok {
|
|
|
|
if err := conn.Close(); err != nil {
|
|
|
|
log.Errorf("Failed to close connection: %v", err)
|
|
|
|
}
|
|
|
|
return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn)
|
|
|
|
}
|
|
|
|
|
|
|
|
return tcpConn, nil
|
|
|
|
}
|