Expect correct conn type (#1801)

This commit is contained in:
Viktor Liu 2024-04-05 00:10:32 +02:00 committed by GitHub
parent 3d2a2377c6
commit 3461b1bb90
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 56 additions and 44 deletions

View File

@ -1,10 +1,7 @@
package net package net
import ( import (
"fmt"
"net" "net"
log "github.com/sirupsen/logrus"
) )
// Dialer extends the standard net.Dialer with the ability to execute hooks before // Dialer extends the standard net.Dialer with the ability to execute hooks before
@ -22,43 +19,3 @@ func NewDialer() *Dialer {
return dialer return dialer
} }
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
}
udpConn, ok := conn.(*net.UDPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDP connection, got different type")
}
return udpConn, nil
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
}
tcpConn, ok := conn.(*net.TCPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected TCP connection, got different type")
}
return tcpConn, nil
}

View File

@ -121,3 +121,43 @@ func calliDialerHooks(ctx context.Context, connID ConnectionID, address string,
return result.ErrorOrNil() return result.ErrorOrNil()
} }
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
}
udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn)
}
return udpConn, nil
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
}
tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn)
}
return tcpConn, nil
}

15
util/net/dialer_mobile.go Normal file
View File

@ -0,0 +1,15 @@
//go:build android || ios
package net
import (
"net"
)
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
return net.DialUDP(network, laddr, raddr)
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
return net.DialTCP(network, laddr, raddr)
}

View File

@ -156,7 +156,7 @@ func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) {
if err := packetConn.Close(); err != nil { if err := packetConn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err) log.Errorf("Failed to close connection: %v", err)
} }
return nil, fmt.Errorf("expected UDPConn, got different type") return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn)
} }
return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil