mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-14 10:50:45 +01:00
Add udp listener and did some change for debug purpose.
This commit is contained in:
parent
d4eaec5cbd
commit
9ac5a1ed3f
@ -10,12 +10,12 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/client/dialer/ws"
|
||||
"github.com/netbirdio/netbird/relay/client/dialer/udp"
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
)
|
||||
|
||||
const (
|
||||
bufferSize = 65535 // optimise the buffer size
|
||||
bufferSize = 1500 // optimise the buffer size
|
||||
)
|
||||
|
||||
type connContainer struct {
|
||||
@ -52,7 +52,7 @@ func NewClient(serverAddress, peerID string) *Client {
|
||||
}
|
||||
|
||||
func (c *Client) Connect() error {
|
||||
conn, err := ws.Dial(c.serverAddress)
|
||||
conn, err := udp.Dial(c.serverAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -128,6 +128,7 @@ func (c *Client) readLoop() {
|
||||
buf := c.msgPool.Get().([]byte)
|
||||
n, errExit = c.relayConn.Read(buf)
|
||||
if errExit != nil {
|
||||
log.Debugf("failed to read message from relay server: %s", errExit)
|
||||
break
|
||||
}
|
||||
|
||||
@ -155,7 +156,8 @@ func (c *Client) readLoop() {
|
||||
c.msgPool.Put(buf)
|
||||
continue
|
||||
}
|
||||
c.handleTransport(channelId, buf[:n])
|
||||
go c.handleTransport(channelId, buf[:n])
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -34,13 +34,11 @@ func (c *Conn) Close() error {
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return c.client.relayConn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return c.client.relayConn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
|
14
relay/client/dialer/udp/udp.go
Normal file
14
relay/client/dialer/udp/udp.go
Normal file
@ -0,0 +1,14 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
func Dial(address string) (net.Conn, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return net.DialUDP("udp", nil, udpAddr)
|
||||
}
|
@ -17,6 +17,21 @@ var (
|
||||
|
||||
type MsgType byte
|
||||
|
||||
func (m MsgType) String() string {
|
||||
switch m {
|
||||
case MsgTypeHello:
|
||||
return "hello"
|
||||
case MsgTypeBindNewChannel:
|
||||
return "bind new channel"
|
||||
case MsgTypeBindResponse:
|
||||
return "bind response"
|
||||
case MsgTypeTransport:
|
||||
return "transport"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func DetermineClientMsgType(msg []byte) (MsgType, error) {
|
||||
// todo: validate magic byte
|
||||
msgType := MsgType(msg[0])
|
||||
@ -41,7 +56,7 @@ func DetermineServerMsgType(msg []byte) (MsgType, error) {
|
||||
case MsgTypeTransport:
|
||||
return msgType, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("invalid msg type: %s", msg)
|
||||
return 0, fmt.Errorf("invalid msg type, len: %d", len(msg))
|
||||
}
|
||||
}
|
||||
|
||||
|
96
relay/server/listener/udp/listener.go
Normal file
96
relay/server/listener/udp/listener.go
Normal file
@ -0,0 +1,96 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/server/listener"
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
address string
|
||||
|
||||
onAcceptFn func(conn net.Conn)
|
||||
|
||||
conns map[string]*UDPConn
|
||||
wg sync.WaitGroup
|
||||
quit chan struct{}
|
||||
lock sync.Mutex
|
||||
listener *net.UDPConn
|
||||
}
|
||||
|
||||
func NewListener(address string) listener.Listener {
|
||||
return &Listener{
|
||||
address: address,
|
||||
conns: make(map[string]*UDPConn),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) Listen(onAcceptFn func(conn net.Conn)) error {
|
||||
l.lock.Lock()
|
||||
|
||||
l.onAcceptFn = onAcceptFn
|
||||
l.quit = make(chan struct{})
|
||||
|
||||
addr := &net.UDPAddr{
|
||||
Port: 1234,
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
}
|
||||
li, err := net.ListenUDP("udp", addr)
|
||||
if err != nil {
|
||||
log.Errorf("%s", err)
|
||||
l.lock.Unlock()
|
||||
return err
|
||||
}
|
||||
log.Debugf("udp server is listening on address: %s", l.address)
|
||||
l.listener = li
|
||||
l.wg.Add(1)
|
||||
go l.readLoop()
|
||||
|
||||
l.lock.Unlock()
|
||||
<-l.quit
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close todo: prevent multiple call (do not close two times the channel)
|
||||
func (l *Listener) Close() error {
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
close(l.quit)
|
||||
err := l.listener.Close()
|
||||
l.wg.Wait()
|
||||
return err
|
||||
}
|
||||
|
||||
func (l *Listener) readLoop() {
|
||||
defer l.wg.Done()
|
||||
|
||||
for {
|
||||
buf := make([]byte, 1500)
|
||||
n, addr, err := l.listener.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
select {
|
||||
case <-l.quit:
|
||||
return
|
||||
default:
|
||||
log.Errorf("failed to accept connection: %s", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
pConn, ok := l.conns[addr.String()]
|
||||
if ok {
|
||||
pConn.onNewMsg(buf[:n])
|
||||
continue
|
||||
}
|
||||
|
||||
pConn = NewConn(l.listener, addr)
|
||||
l.conns[addr.String()] = pConn
|
||||
go l.onAcceptFn(pConn)
|
||||
pConn.onNewMsg(buf[:n])
|
||||
|
||||
}
|
||||
}
|
68
relay/server/listener/udp/udp_conn.go
Normal file
68
relay/server/listener/udp/udp_conn.go
Normal file
@ -0,0 +1,68 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type UDPConn struct {
|
||||
*net.UDPConn
|
||||
addr *net.UDPAddr
|
||||
msgChannel chan []byte
|
||||
}
|
||||
|
||||
func NewConn(conn *net.UDPConn, addr *net.UDPAddr) *UDPConn {
|
||||
return &UDPConn{
|
||||
UDPConn: conn,
|
||||
addr: addr,
|
||||
msgChannel: make(chan []byte),
|
||||
}
|
||||
}
|
||||
|
||||
func (u *UDPConn) Read(b []byte) (n int, err error) {
|
||||
msg, ok := <-u.msgChannel
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
n = copy(b, msg)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (u *UDPConn) Write(b []byte) (n int, err error) {
|
||||
return u.UDPConn.WriteTo(b, u.addr)
|
||||
}
|
||||
|
||||
func (u *UDPConn) Close() error {
|
||||
//TODO implement me
|
||||
//panic("implement me")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UDPConn) LocalAddr() net.Addr {
|
||||
return u.UDPConn.LocalAddr()
|
||||
}
|
||||
|
||||
func (u *UDPConn) RemoteAddr() net.Addr {
|
||||
return u.addr
|
||||
}
|
||||
|
||||
func (u *UDPConn) SetDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (u *UDPConn) SetReadDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (u *UDPConn) SetWriteDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (u *UDPConn) onNewMsg(b []byte) {
|
||||
u.msgChannel <- b
|
||||
}
|
@ -2,6 +2,7 @@ package ws
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
@ -10,11 +11,13 @@ import (
|
||||
|
||||
type Conn struct {
|
||||
*websocket.Conn
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewConn(wsConn *websocket.Conn) *Conn {
|
||||
return &Conn{
|
||||
wsConn,
|
||||
Conn: wsConn,
|
||||
}
|
||||
}
|
||||
|
||||
@ -33,7 +36,9 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
c.mu.Lock()
|
||||
err := c.WriteMessage(websocket.BinaryMessage, b)
|
||||
c.mu.Unlock()
|
||||
return len(b), err
|
||||
}
|
||||
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
"github.com/netbirdio/netbird/relay/server/listener"
|
||||
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
||||
"github.com/netbirdio/netbird/relay/server/listener/udp"
|
||||
)
|
||||
|
||||
// Server
|
||||
@ -30,7 +30,7 @@ func NewServer() *Server {
|
||||
}
|
||||
|
||||
func (r *Server) Listen(address string) error {
|
||||
r.listener = ws.NewListener(address)
|
||||
r.listener = udp.NewListener(address)
|
||||
return r.listener.Listen(r.accept)
|
||||
}
|
||||
|
||||
@ -51,7 +51,7 @@ func (r *Server) accept(conn net.Conn) {
|
||||
}
|
||||
return
|
||||
}
|
||||
peer.Log.Debugf("on new connection: %s", conn.RemoteAddr())
|
||||
peer.Log.Debugf("peer connected from: %s", conn.RemoteAddr())
|
||||
|
||||
r.store.AddPeer(peer)
|
||||
defer func() {
|
||||
@ -59,8 +59,8 @@ func (r *Server) accept(conn net.Conn) {
|
||||
r.store.DeletePeer(peer)
|
||||
}()
|
||||
|
||||
buf := make([]byte, 65535) // todo: optimize buffer size
|
||||
for {
|
||||
buf := make([]byte, 1500) // todo: optimize buffer size
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
@ -98,18 +98,18 @@ func (r *Server) accept(conn net.Conn) {
|
||||
peer.Log.Errorf("failed to unmarshal transport message: %s", err)
|
||||
continue
|
||||
}
|
||||
go func() {
|
||||
foreignChannelID, remoteConn, err := peer.ConnByChannelID(channelId)
|
||||
if err != nil {
|
||||
peer.Log.Errorf("failed to transport message from peer '%s' to '%d': %s", peer.ID(), channelId, err)
|
||||
return
|
||||
}
|
||||
|
||||
foreignChannelID, remoteConn, err := peer.ConnByChannelID(channelId)
|
||||
if err != nil {
|
||||
peer.Log.Errorf("failed to transport message from peer '%s' to '%d': %s", peer.ID(), channelId, err)
|
||||
continue
|
||||
}
|
||||
|
||||
err = transportTo(remoteConn, foreignChannelID, msg)
|
||||
if err != nil {
|
||||
peer.Log.Errorf("failed to transport message from peer '%s' to '%d': %s", peer.ID(), channelId, err)
|
||||
continue
|
||||
}
|
||||
err = transportTo(remoteConn, foreignChannelID, msg)
|
||||
if err != nil {
|
||||
peer.Log.Errorf("failed to transport message from peer '%s' to '%d': %s", peer.ID(), channelId, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -173,7 +173,7 @@ func TestBindToUnavailabePeer(t *testing.T) {
|
||||
go func() {
|
||||
err := srv.Listen(addr)
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind server: %s", err)
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user