Use forked Wireguard-go for custom bind (#823)

Update go version to 1.20
Use forked wireguard-go repo because of custom Bind implementation
This commit is contained in:
Zoltan Papp
2023-04-27 17:50:45 +02:00
committed by GitHub
parent afaa3fbe4f
commit 7f5e1c623e
10 changed files with 303 additions and 198 deletions

View File

@@ -1,98 +1,132 @@
package bind
import (
"errors"
"fmt"
"net"
"net/netip"
"runtime"
"sync"
"syscall"
"github.com/pion/stun"
"github.com/pion/transport/v2"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
"golang.org/x/net/ipv4"
wgConn "golang.zx2c4.com/wireguard/conn"
)
// ICEBind is the userspace implementation of WireGuard's conn.Bind interface using ice.UDPMux of the pion/ice library
type ICEBind struct {
// below fields, initialized on open
ipv4 net.PacketConn
udpMux *UniversalUDPMuxDefault
// below are fields initialized on creation
transportNet transport.Net
mu sync.Mutex
type receiverCreator struct {
iceBind *ICEBind
}
func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn)
}
type ICEBind struct {
*wgConn.StdNetBind
muUDPMux sync.Mutex
transportNet transport.Net
udpMux *UniversalUDPMuxDefault
}
// NewICEBind create a new instance of ICEBind with a given transportNet function.
// The transportNet can be nil.
func NewICEBind(transportNet transport.Net) *ICEBind {
return &ICEBind{
ib := &ICEBind{
transportNet: transportNet,
mu: sync.Mutex{},
}
rc := receiverCreator{
ib,
}
ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc)
return ib
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (b *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.udpMux == nil {
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
if s.udpMux == nil {
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
return b.udpMux, nil
return s.udpMux, nil
}
// Open creates a WireGuard socket and an instance of UDPMux that is used to glue up ICE and WireGuard for hole punching
func (b *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
b.mu.Lock()
defer b.mu.Unlock()
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
if b.ipv4 != nil {
return nil, 0, conn.ErrBindAlreadyOpen
}
var err error
b.ipv4, _, err = listenNet("udp4", int(uport))
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
b.udpMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: b.ipv4, Net: b.transportNet})
portAddr, err := netip.ParseAddrPort(b.ipv4.LocalAddr().String())
if err != nil {
return nil, 0, err
}
log.Infof("opened ICEBind on %s", b.ipv4.LocalAddr().String())
return []conn.ReceiveFunc{
b.makeReceiveIPv4(b.ipv4),
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
Net: s.transportNet,
},
portAddr.Port(), nil
}
func listenNet(network string, port int) (net.PacketConn, int, error) {
c, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
if err != nil {
return nil, 0, err
}
lAddr := c.LocalAddr()
uAddr, err := net.ResolveUDPAddr(
lAddr.Network(),
lAddr.String(),
)
if err != nil {
return nil, 0, err
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := ipv4MsgsPool.Get().(*[]ipv4.Message)
defer ipv4MsgsPool.Put(msgs)
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
}
var numMsgs int
if runtime.GOOS == "linux" {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
// todo: handle err
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
if ok {
sizes[i] = 0
} else {
sizes[i] = msg.N
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
}
return c, uAddr.Port, nil
}
func parseSTUNMessage(raw []byte) (*stun.Message, error) {
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
for i := range buffers {
if !stun.IsMessage(buffers[i]) {
continue
}
msg, err := s.parseSTUNMessage(buffers[i][:n])
if err != nil {
buffers[i] = []byte{}
return true, err
}
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
if muxErr != nil {
log.Warnf("failed to handle STUN packet")
}
buffers[i] = []byte{}
return true, nil
}
return false, nil
}
func (s *ICEBind) parseSTUNMessage(raw []byte) (*stun.Message, error) {
msg := &stun.Message{
Raw: raw,
}
@@ -102,107 +136,3 @@ func parseSTUNMessage(raw []byte) (*stun.Message, error) {
return msg, nil
}
func (b *ICEBind) makeReceiveIPv4(c net.PacketConn) conn.ReceiveFunc {
return func(buff []byte) (int, conn.Endpoint, error) {
n, endpoint, err := c.ReadFrom(buff)
if err != nil {
return 0, nil, err
}
e, err := netip.ParseAddrPort(endpoint.String())
if err != nil {
return 0, nil, err
}
if !stun.IsMessage(buff) {
// WireGuard traffic
return n, (conn.StdNetEndpoint)(netip.AddrPortFrom(e.Addr(), e.Port())), nil
}
msg, err := parseSTUNMessage(buff[:n])
if err != nil {
return 0, nil, err
}
err = b.udpMux.HandleSTUNMessage(msg, endpoint)
if err != nil {
log.Warnf("failed to handle packet")
}
// discard packets because they are STUN related
return 0, nil, nil //todo proper return
}
}
// Close closes the WireGuard socket and UDPMux
func (b *ICEBind) Close() error {
b.mu.Lock()
defer b.mu.Unlock()
var err1, err2 error
if b.ipv4 != nil {
c := b.ipv4
b.ipv4 = nil
err1 = c.Close()
}
if b.udpMux != nil {
m := b.udpMux
b.udpMux = nil
err2 = m.Close()
}
if err1 != nil {
return err1
}
return err2
}
// SetMark sets the mark for each packet sent through this Bind.
// This mark is passed to the kernel as the socket option SO_MARK.
func (b *ICEBind) SetMark(mark uint32) error {
return nil
}
// Send bytes to the remote endpoint (peer)
func (b *ICEBind) Send(buff []byte, endpoint conn.Endpoint) error {
nend, ok := endpoint.(conn.StdNetEndpoint)
if !ok {
return conn.ErrWrongEndpointType
}
addrPort := netip.AddrPort(nend)
_, err := b.ipv4.WriteTo(buff, &net.UDPAddr{
IP: addrPort.Addr().AsSlice(),
Port: int(addrPort.Port()),
Zone: addrPort.Addr().Zone(),
})
return err
}
// ParseEndpoint creates a new endpoint from a string.
func (b *ICEBind) ParseEndpoint(s string) (ep conn.Endpoint, err error) {
e, err := netip.ParseAddrPort(s)
return asEndpoint(e), err
}
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
// This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates,
// but Endpoints are immutable, so we can re-use them.
var endpointPool = sync.Pool{
New: func() any {
return make(map[netip.AddrPort]conn.Endpoint)
},
}
// asEndpoint returns an Endpoint containing ap.
func asEndpoint(ap netip.AddrPort) conn.Endpoint {
m := endpointPool.Get().(map[netip.AddrPort]conn.Endpoint)
defer endpointPool.Put(m)
e, ok := m[ap]
if !ok {
e = conn.Endpoint(conn.StdNetEndpoint(ap))
m[ap] = e
}
return e
}