mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-18 19:09:09 +02:00
Add initial relay code
This commit is contained in:
8
relay/server/listener/listener.go
Normal file
8
relay/server/listener/listener.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package listener
|
||||
|
||||
import "net"
|
||||
|
||||
type Listener interface {
|
||||
Listen(func(conn net.Conn)) error
|
||||
Close() error
|
||||
}
|
80
relay/server/listener/tcp/listener.go
Normal file
80
relay/server/listener/tcp/listener.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/server/listener"
|
||||
)
|
||||
|
||||
// Listener
|
||||
// Is it just demo code. It does not work in real life environment because the TCP is a streaming protocol, adn
|
||||
// it does not handle framing.
|
||||
type Listener struct {
|
||||
address string
|
||||
|
||||
onAcceptFn func(conn net.Conn)
|
||||
wg sync.WaitGroup
|
||||
quit chan struct{}
|
||||
listener net.Listener
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func NewListener(address string) listener.Listener {
|
||||
return &Listener{
|
||||
address: address,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) Listen(onAcceptFn func(conn net.Conn)) error {
|
||||
l.lock.Lock()
|
||||
|
||||
l.onAcceptFn = onAcceptFn
|
||||
l.quit = make(chan struct{})
|
||||
|
||||
li, err := net.Listen("tcp", l.address)
|
||||
if err != nil {
|
||||
log.Errorf("failed to listen on address: %s, %s", l.address, err)
|
||||
l.lock.Unlock()
|
||||
return err
|
||||
}
|
||||
log.Debugf("TCP server is listening on address: %s", l.address)
|
||||
l.listener = li
|
||||
l.wg.Add(1)
|
||||
go l.acceptLoop()
|
||||
|
||||
l.lock.Unlock()
|
||||
<-l.quit
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close todo: prevent multiple call (do not close two times the channel)
|
||||
func (l *Listener) Close() error {
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
close(l.quit)
|
||||
err := l.listener.Close()
|
||||
l.wg.Wait()
|
||||
return err
|
||||
}
|
||||
|
||||
func (l *Listener) acceptLoop() {
|
||||
defer l.wg.Done()
|
||||
|
||||
for {
|
||||
conn, err := l.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-l.quit:
|
||||
return
|
||||
default:
|
||||
log.Errorf("failed to accept connection: %s", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
go l.onAcceptFn(conn)
|
||||
}
|
||||
}
|
82
relay/server/listener/ws/listener.go
Normal file
82
relay/server/listener/ws/listener.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/server/listener"
|
||||
)
|
||||
|
||||
var (
|
||||
upgrader = websocket.Upgrader{} // use default options
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
address string
|
||||
|
||||
wg sync.WaitGroup
|
||||
server *http.Server
|
||||
acceptFn func(conn net.Conn)
|
||||
}
|
||||
|
||||
func NewListener(address string) listener.Listener {
|
||||
return &Listener{
|
||||
address: address,
|
||||
}
|
||||
}
|
||||
|
||||
// Listen todo: prevent multiple call
|
||||
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
|
||||
l.acceptFn = acceptFn
|
||||
http.HandleFunc("/", l.onAccept)
|
||||
|
||||
l.server = &http.Server{
|
||||
Addr: l.address,
|
||||
}
|
||||
|
||||
log.Debugf("WS server is listening on address: %s", l.address)
|
||||
err := l.server.ListenAndServe()
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (l *Listener) Close() error {
|
||||
if l.server == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
log.Debugf("closing WS server")
|
||||
if err := l.server.Shutdown(ctx); err != nil {
|
||||
return fmt.Errorf("server shutdown failed: %v", err)
|
||||
}
|
||||
|
||||
l.wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Listener) onAccept(writer http.ResponseWriter, request *http.Request) {
|
||||
l.wg.Add(1)
|
||||
defer l.wg.Done()
|
||||
|
||||
wsConn, err := upgrader.Upgrade(writer, request, nil)
|
||||
if err != nil {
|
||||
log.Errorf("failed to upgrade connection: %s", err)
|
||||
return
|
||||
}
|
||||
conn := NewConn(wsConn)
|
||||
l.acceptFn(conn)
|
||||
return
|
||||
}
|
52
relay/server/listener/ws/server_conn.go
Normal file
52
relay/server/listener/ws/server_conn.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
*websocket.Conn
|
||||
}
|
||||
|
||||
func NewConn(wsConn *websocket.Conn) *Conn {
|
||||
return &Conn{
|
||||
wsConn,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
t, r, err := c.NextReader()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if t != websocket.BinaryMessage {
|
||||
log.Errorf("unexpected message type: %d", t)
|
||||
return 0, fmt.Errorf("unexpected message type")
|
||||
}
|
||||
|
||||
return r.Read(b)
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
err := c.WriteMessage(websocket.BinaryMessage, b)
|
||||
return len(b), err
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
errR := c.SetReadDeadline(t)
|
||||
errW := c.SetWriteDeadline(t)
|
||||
|
||||
if errR != nil {
|
||||
return errR
|
||||
}
|
||||
|
||||
if errW != nil {
|
||||
return errW
|
||||
}
|
||||
return nil
|
||||
}
|
113
relay/server/peer.go
Normal file
113
relay/server/peer.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type Participant struct {
|
||||
ChannelID uint16
|
||||
ChannelIDForeign uint16
|
||||
ConnForeign net.Conn
|
||||
Peer *Peer
|
||||
}
|
||||
|
||||
type Peer struct {
|
||||
Log *log.Entry
|
||||
id string
|
||||
conn net.Conn
|
||||
|
||||
pendingParticipantByChannelID map[uint16]*Participant
|
||||
participantByID map[uint16]*Participant // used for package transfer
|
||||
participantByPeerID map[string]*Participant // used for channel linking
|
||||
|
||||
lastId uint16
|
||||
lastIdLock sync.Mutex
|
||||
}
|
||||
|
||||
func NewPeer(id string, conn net.Conn) *Peer {
|
||||
return &Peer{
|
||||
Log: log.WithField("peer_id", id),
|
||||
id: id,
|
||||
conn: conn,
|
||||
pendingParticipantByChannelID: make(map[uint16]*Participant),
|
||||
participantByID: make(map[uint16]*Participant),
|
||||
participantByPeerID: make(map[string]*Participant),
|
||||
}
|
||||
}
|
||||
func (p *Peer) BindChannel(remotePeerId string) uint16 {
|
||||
ch, ok := p.participantByPeerID[remotePeerId]
|
||||
if ok {
|
||||
return ch.ChannelID
|
||||
}
|
||||
|
||||
channelID := p.newChannelID()
|
||||
channel := &Participant{
|
||||
ChannelID: channelID,
|
||||
}
|
||||
p.pendingParticipantByChannelID[channelID] = channel
|
||||
p.participantByPeerID[remotePeerId] = channel
|
||||
return channelID
|
||||
}
|
||||
|
||||
func (p *Peer) UnBindChannel(remotePeerId string) {
|
||||
pa, ok := p.participantByPeerID[remotePeerId]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
p.Log.Debugf("unbind channel with '%s': %d", remotePeerId, pa.ChannelID)
|
||||
p.pendingParticipantByChannelID[pa.ChannelID] = pa
|
||||
delete(p.participantByID, pa.ChannelID)
|
||||
}
|
||||
|
||||
func (p *Peer) AddParticipant(peer *Peer, remoteChannelID uint16) (uint16, bool) {
|
||||
participant, ok := p.participantByPeerID[peer.ID()]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
participant.ChannelIDForeign = remoteChannelID
|
||||
participant.ConnForeign = peer.conn
|
||||
participant.Peer = peer
|
||||
|
||||
delete(p.pendingParticipantByChannelID, participant.ChannelID)
|
||||
p.participantByID[participant.ChannelID] = participant
|
||||
return participant.ChannelID, true
|
||||
}
|
||||
|
||||
func (p *Peer) DeleteParticipants() {
|
||||
for _, participant := range p.participantByID {
|
||||
participant.Peer.UnBindChannel(p.id)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Peer) ConnByChannelID(dstID uint16) (uint16, net.Conn, error) {
|
||||
ch, ok := p.participantByID[dstID]
|
||||
if !ok {
|
||||
return 0, nil, fmt.Errorf("destination channel not found")
|
||||
}
|
||||
|
||||
return ch.ChannelIDForeign, ch.ConnForeign, nil
|
||||
}
|
||||
|
||||
func (p *Peer) ID() string {
|
||||
return p.id
|
||||
}
|
||||
|
||||
func (p *Peer) newChannelID() uint16 {
|
||||
p.lastIdLock.Lock()
|
||||
defer p.lastIdLock.Unlock()
|
||||
for {
|
||||
p.lastId++
|
||||
if _, ok := p.pendingParticipantByChannelID[p.lastId]; ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := p.participantByID[p.lastId]; ok {
|
||||
continue
|
||||
}
|
||||
return p.lastId
|
||||
}
|
||||
}
|
149
relay/server/server.go
Normal file
149
relay/server/server.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
"github.com/netbirdio/netbird/relay/server/listener"
|
||||
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
||||
)
|
||||
|
||||
// Server
|
||||
// todo:
|
||||
// authentication: provide JWT token via RPC call. The MGM server can forward the token to the agents.
|
||||
// connection timeout handling
|
||||
// implement HA (High Availability) mode
|
||||
type Server struct {
|
||||
store *Store
|
||||
|
||||
listener listener.Listener
|
||||
}
|
||||
|
||||
func NewServer() *Server {
|
||||
return &Server{
|
||||
store: NewStore(),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Server) Listen(address string) error {
|
||||
r.listener = ws.NewListener(address)
|
||||
return r.listener.Listen(r.accept)
|
||||
}
|
||||
|
||||
func (r *Server) Close() error {
|
||||
if r.listener == nil {
|
||||
return nil
|
||||
}
|
||||
return r.listener.Close()
|
||||
}
|
||||
|
||||
func (r *Server) accept(conn net.Conn) {
|
||||
peer, err := handShake(conn)
|
||||
if err != nil {
|
||||
log.Errorf("failed to handshake wiht %s: %s", conn.RemoteAddr(), err)
|
||||
cErr := conn.Close()
|
||||
if cErr != nil {
|
||||
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
peer.Log.Debugf("on new connection: %s", conn.RemoteAddr())
|
||||
|
||||
r.store.AddPeer(peer)
|
||||
defer func() {
|
||||
peer.Log.Debugf("teardown connection")
|
||||
r.store.DeletePeer(peer)
|
||||
}()
|
||||
|
||||
buf := make([]byte, 65535) // todo: optimize buffer size
|
||||
for {
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
peer.Log.Errorf("failed to read message: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineClientMsgType(buf[:n])
|
||||
if err != nil {
|
||||
log.Errorf("failed to determine message type: %s", err)
|
||||
return
|
||||
}
|
||||
switch msgType {
|
||||
case messages.MsgTypeBindNewChannel:
|
||||
dstPeerId, err := messages.UnmarshalBindNewChannel(buf[:n])
|
||||
if err != nil {
|
||||
log.Errorf("failed to unmarshal bind new channel message: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
channelID := r.store.Link(peer, dstPeerId)
|
||||
|
||||
msg := messages.MarshalBindResponseMsg(channelID, dstPeerId)
|
||||
_, err = conn.Write(msg)
|
||||
if err != nil {
|
||||
peer.Log.Errorf("failed to response to bind request: %s", err)
|
||||
continue
|
||||
}
|
||||
peer.Log.Debugf("bind new channel with '%s', channelID: %d", dstPeerId, channelID)
|
||||
case messages.MsgTypeTransport:
|
||||
msg := buf[:n]
|
||||
channelId, err := messages.UnmarshalTransportID(msg)
|
||||
if err != nil {
|
||||
peer.Log.Errorf("failed to unmarshal transport message: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
foreignChannelID, remoteConn, err := peer.ConnByChannelID(channelId)
|
||||
if err != nil {
|
||||
peer.Log.Errorf("failed to transport message from peer '%s' to '%d': %s", peer.ID(), channelId, err)
|
||||
continue
|
||||
}
|
||||
|
||||
err = transportTo(remoteConn, foreignChannelID, msg)
|
||||
if err != nil {
|
||||
peer.Log.Errorf("failed to transport message from peer '%s' to '%d': %s", peer.ID(), channelId, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func transportTo(conn net.Conn, channelID uint16, msg []byte) error {
|
||||
err := messages.UpdateTransportMsg(msg, channelID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = conn.Write(msg)
|
||||
return err
|
||||
}
|
||||
|
||||
func handShake(conn net.Conn) (*Peer, error) {
|
||||
buf := make([]byte, 65535) // todo: reduce the buffer size
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
log.Errorf("failed to read message: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
msgType, err := messages.DetermineClientMsgType(buf[:n])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if msgType != messages.MsgTypeHello {
|
||||
tErr := fmt.Errorf("invalid message type")
|
||||
log.Errorf("failed to handshake: %s", tErr)
|
||||
return nil, tErr
|
||||
}
|
||||
peerId, err := messages.UnmarshalHelloMsg(buf[:n])
|
||||
if err != nil {
|
||||
log.Errorf("failed to handshake: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
p := NewPeer(peerId, conn)
|
||||
return p, nil
|
||||
}
|
48
relay/server/store.go
Normal file
48
relay/server/store.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
peers map[string]*Peer // Key is the id (public key or sha-256) of the peer
|
||||
peersLock sync.Mutex
|
||||
}
|
||||
|
||||
func NewStore() *Store {
|
||||
return &Store{
|
||||
peers: make(map[string]*Peer),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Store) AddPeer(peer *Peer) {
|
||||
s.peersLock.Lock()
|
||||
defer s.peersLock.Unlock()
|
||||
s.peers[peer.ID()] = peer
|
||||
}
|
||||
|
||||
func (s *Store) Link(peer *Peer, peerForeignID string) uint16 {
|
||||
s.peersLock.Lock()
|
||||
defer s.peersLock.Unlock()
|
||||
|
||||
channelId := peer.BindChannel(peerForeignID)
|
||||
dstPeer, ok := s.peers[peerForeignID]
|
||||
if !ok {
|
||||
return channelId
|
||||
}
|
||||
|
||||
foreignChannelID, ok := dstPeer.AddParticipant(peer, channelId)
|
||||
if !ok {
|
||||
return channelId
|
||||
}
|
||||
peer.AddParticipant(dstPeer, foreignChannelID)
|
||||
return channelId
|
||||
}
|
||||
|
||||
func (s *Store) DeletePeer(peer *Peer) {
|
||||
s.peersLock.Lock()
|
||||
defer s.peersLock.Unlock()
|
||||
|
||||
delete(s.peers, peer.ID())
|
||||
peer.DeleteParticipants()
|
||||
}
|
Reference in New Issue
Block a user