Add udp listener and did some change for debug purpose.

This commit is contained in:
Zoltan Papp 2024-05-19 12:41:06 +02:00
parent d4eaec5cbd
commit 9ac5a1ed3f
9 changed files with 224 additions and 26 deletions

View File

@ -10,12 +10,12 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/client/dialer/ws"
"github.com/netbirdio/netbird/relay/client/dialer/udp"
"github.com/netbirdio/netbird/relay/messages"
)
const (
bufferSize = 65535 // optimise the buffer size
bufferSize = 1500 // optimise the buffer size
)
type connContainer struct {
@ -52,7 +52,7 @@ func NewClient(serverAddress, peerID string) *Client {
}
func (c *Client) Connect() error {
conn, err := ws.Dial(c.serverAddress)
conn, err := udp.Dial(c.serverAddress)
if err != nil {
return err
}
@ -128,6 +128,7 @@ func (c *Client) readLoop() {
buf := c.msgPool.Get().([]byte)
n, errExit = c.relayConn.Read(buf)
if errExit != nil {
log.Debugf("failed to read message from relay server: %s", errExit)
break
}
@ -155,7 +156,8 @@ func (c *Client) readLoop() {
c.msgPool.Put(buf)
continue
}
c.handleTransport(channelId, buf[:n])
go c.handleTransport(channelId, buf[:n])
}
}

View File

@ -34,13 +34,11 @@ func (c *Conn) Close() error {
}
func (c *Conn) LocalAddr() net.Addr {
//TODO implement me
panic("implement me")
return c.client.relayConn.LocalAddr()
}
func (c *Conn) RemoteAddr() net.Addr {
//TODO implement me
panic("implement me")
return c.client.relayConn.RemoteAddr()
}
func (c *Conn) SetDeadline(t time.Time) error {

View File

@ -0,0 +1,14 @@
package udp
import (
"net"
)
func Dial(address string) (net.Conn, error) {
udpAddr, err := net.ResolveUDPAddr("udp", address)
if err != nil {
return nil, err
}
return net.DialUDP("udp", nil, udpAddr)
}

View File

@ -17,6 +17,21 @@ var (
type MsgType byte
func (m MsgType) String() string {
switch m {
case MsgTypeHello:
return "hello"
case MsgTypeBindNewChannel:
return "bind new channel"
case MsgTypeBindResponse:
return "bind response"
case MsgTypeTransport:
return "transport"
default:
return "unknown"
}
}
func DetermineClientMsgType(msg []byte) (MsgType, error) {
// todo: validate magic byte
msgType := MsgType(msg[0])
@ -41,7 +56,7 @@ func DetermineServerMsgType(msg []byte) (MsgType, error) {
case MsgTypeTransport:
return msgType, nil
default:
return 0, fmt.Errorf("invalid msg type: %s", msg)
return 0, fmt.Errorf("invalid msg type, len: %d", len(msg))
}
}

View File

@ -0,0 +1,96 @@
package udp
import (
"net"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/server/listener"
)
type Listener struct {
address string
onAcceptFn func(conn net.Conn)
conns map[string]*UDPConn
wg sync.WaitGroup
quit chan struct{}
lock sync.Mutex
listener *net.UDPConn
}
func NewListener(address string) listener.Listener {
return &Listener{
address: address,
conns: make(map[string]*UDPConn),
}
}
func (l *Listener) Listen(onAcceptFn func(conn net.Conn)) error {
l.lock.Lock()
l.onAcceptFn = onAcceptFn
l.quit = make(chan struct{})
addr := &net.UDPAddr{
Port: 1234,
IP: net.ParseIP("0.0.0.0"),
}
li, err := net.ListenUDP("udp", addr)
if err != nil {
log.Errorf("%s", err)
l.lock.Unlock()
return err
}
log.Debugf("udp server is listening on address: %s", l.address)
l.listener = li
l.wg.Add(1)
go l.readLoop()
l.lock.Unlock()
<-l.quit
return nil
}
// Close todo: prevent multiple call (do not close two times the channel)
func (l *Listener) Close() error {
l.lock.Lock()
defer l.lock.Unlock()
close(l.quit)
err := l.listener.Close()
l.wg.Wait()
return err
}
func (l *Listener) readLoop() {
defer l.wg.Done()
for {
buf := make([]byte, 1500)
n, addr, err := l.listener.ReadFromUDP(buf)
if err != nil {
select {
case <-l.quit:
return
default:
log.Errorf("failed to accept connection: %s", err)
continue
}
}
pConn, ok := l.conns[addr.String()]
if ok {
pConn.onNewMsg(buf[:n])
continue
}
pConn = NewConn(l.listener, addr)
l.conns[addr.String()] = pConn
go l.onAcceptFn(pConn)
pConn.onNewMsg(buf[:n])
}
}

View File

@ -0,0 +1,68 @@
package udp
import (
"io"
"net"
"time"
)
type UDPConn struct {
*net.UDPConn
addr *net.UDPAddr
msgChannel chan []byte
}
func NewConn(conn *net.UDPConn, addr *net.UDPAddr) *UDPConn {
return &UDPConn{
UDPConn: conn,
addr: addr,
msgChannel: make(chan []byte),
}
}
func (u *UDPConn) Read(b []byte) (n int, err error) {
msg, ok := <-u.msgChannel
if !ok {
return 0, io.EOF
}
n = copy(b, msg)
return n, nil
}
func (u *UDPConn) Write(b []byte) (n int, err error) {
return u.UDPConn.WriteTo(b, u.addr)
}
func (u *UDPConn) Close() error {
//TODO implement me
//panic("implement me")
return nil
}
func (u *UDPConn) LocalAddr() net.Addr {
return u.UDPConn.LocalAddr()
}
func (u *UDPConn) RemoteAddr() net.Addr {
return u.addr
}
func (u *UDPConn) SetDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (u *UDPConn) SetReadDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (u *UDPConn) SetWriteDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (u *UDPConn) onNewMsg(b []byte) {
u.msgChannel <- b
}

View File

@ -2,6 +2,7 @@ package ws
import (
"fmt"
"sync"
"time"
"github.com/gorilla/websocket"
@ -10,11 +11,13 @@ import (
type Conn struct {
*websocket.Conn
mu sync.Mutex
}
func NewConn(wsConn *websocket.Conn) *Conn {
return &Conn{
wsConn,
Conn: wsConn,
}
}
@ -33,7 +36,9 @@ func (c *Conn) Read(b []byte) (n int, err error) {
}
func (c *Conn) Write(b []byte) (int, error) {
c.mu.Lock()
err := c.WriteMessage(websocket.BinaryMessage, b)
c.mu.Unlock()
return len(b), err
}

View File

@ -9,7 +9,7 @@ import (
"github.com/netbirdio/netbird/relay/messages"
"github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/listener/ws"
"github.com/netbirdio/netbird/relay/server/listener/udp"
)
// Server
@ -30,7 +30,7 @@ func NewServer() *Server {
}
func (r *Server) Listen(address string) error {
r.listener = ws.NewListener(address)
r.listener = udp.NewListener(address)
return r.listener.Listen(r.accept)
}
@ -51,7 +51,7 @@ func (r *Server) accept(conn net.Conn) {
}
return
}
peer.Log.Debugf("on new connection: %s", conn.RemoteAddr())
peer.Log.Debugf("peer connected from: %s", conn.RemoteAddr())
r.store.AddPeer(peer)
defer func() {
@ -59,8 +59,8 @@ func (r *Server) accept(conn net.Conn) {
r.store.DeletePeer(peer)
}()
buf := make([]byte, 65535) // todo: optimize buffer size
for {
buf := make([]byte, 1500) // todo: optimize buffer size
n, err := conn.Read(buf)
if err != nil {
if err != io.EOF {
@ -98,18 +98,18 @@ func (r *Server) accept(conn net.Conn) {
peer.Log.Errorf("failed to unmarshal transport message: %s", err)
continue
}
go func() {
foreignChannelID, remoteConn, err := peer.ConnByChannelID(channelId)
if err != nil {
peer.Log.Errorf("failed to transport message from peer '%s' to '%d': %s", peer.ID(), channelId, err)
return
}
foreignChannelID, remoteConn, err := peer.ConnByChannelID(channelId)
if err != nil {
peer.Log.Errorf("failed to transport message from peer '%s' to '%d': %s", peer.ID(), channelId, err)
continue
}
err = transportTo(remoteConn, foreignChannelID, msg)
if err != nil {
peer.Log.Errorf("failed to transport message from peer '%s' to '%d': %s", peer.ID(), channelId, err)
continue
}
err = transportTo(remoteConn, foreignChannelID, msg)
if err != nil {
peer.Log.Errorf("failed to transport message from peer '%s' to '%d': %s", peer.ID(), channelId, err)
}
}()
}
}
}

View File

@ -173,7 +173,7 @@ func TestBindToUnavailabePeer(t *testing.T) {
go func() {
err := srv.Listen(addr)
if err != nil {
t.Errorf("failed to bind server: %s", err)
t.Fatalf("failed to bind server: %s", err)
}
}()