Remove channel binding logic

This commit is contained in:
Zoltán Papp 2024-05-23 13:24:02 +02:00
parent 0a05f8b4d4
commit 36b2cd16cc
11 changed files with 229 additions and 374 deletions

View File

@ -1,7 +1,6 @@
package client package client
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -19,25 +18,21 @@ const (
serverResponseTimeout = 8 * time.Second serverResponseTimeout = 8 * time.Second
) )
type bufMsg struct { type Msg struct {
bufPtr *[]byte buf []byte
buf []byte
} }
type connContainer struct { type connContainer struct {
conn *Conn conn *Conn
messages chan bufMsg messages chan Msg
} }
// Client Todo:
// - handle automatic reconnection
type Client struct { type Client struct {
log *log.Entry
serverAddress string serverAddress string
peerID string hashedID []byte
channelsPending map[string]chan net.Conn // todo: protect map with mutex conns map[string]*connContainer
channels map[uint16]*connContainer
msgPool sync.Pool
relayConn net.Conn relayConn net.Conn
relayConnState bool relayConnState bool
@ -45,17 +40,12 @@ type Client struct {
} }
func NewClient(serverAddress, peerID string) *Client { func NewClient(serverAddress, peerID string) *Client {
hashedID, hashedStringId := messages.HashID(peerID)
return &Client{ return &Client{
serverAddress: serverAddress, log: log.WithField("client_id", hashedStringId),
peerID: peerID, serverAddress: serverAddress,
channelsPending: make(map[string]chan net.Conn), hashedID: hashedID,
channels: make(map[uint16]*connContainer), conns: make(map[string]*connContainer),
msgPool: sync.Pool{
New: func() any {
buf := make([]byte, bufferSize)
return &buf
},
},
} }
} }
@ -89,31 +79,17 @@ func (c *Client) Connect() error {
return nil return nil
} }
func (c *Client) BindChannel(remotePeerID string) (net.Conn, error) { func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
c.mu.Lock() hashedID, hashedStringID := messages.HashID(dstPeerID)
defer c.mu.Unlock() log.Infof("open connection to peer: %s", hashedStringID)
messageBuffer := make(chan Msg, 2)
conn := NewConn(c, hashedID, c.generateConnReaderFN(messageBuffer))
if c.relayConn == nil { c.conns[hashedStringID] = &connContainer{
return nil, fmt.Errorf("client not connected to the relay server") conn,
} messageBuffer,
bindSuccessChan := make(chan net.Conn, 1)
c.channelsPending[remotePeerID] = bindSuccessChan
msg := messages.MarshalBindNewChannelMsg(remotePeerID)
_, err := c.relayConn.Write(msg)
if err != nil {
log.Errorf("failed to write out bind message: %s", err)
return nil, err
}
ctx, cancel := context.WithTimeout(context.Background(), serverResponseTimeout)
defer cancel()
select {
case <-ctx.Done():
return nil, fmt.Errorf("bind timeout")
case c := <-bindSuccessChan:
return c, nil
} }
return conn, nil
} }
func (c *Client) Close() error { func (c *Client) Close() error {
@ -124,18 +100,15 @@ func (c *Client) Close() error {
return nil return nil
} }
for _, conn := range c.channels {
close(conn.messages)
}
c.channels = make(map[uint16]*connContainer)
c.relayConnState = false c.relayConnState = false
err := c.relayConn.Close() err := c.relayConn.Close()
return err return err
} }
func (c *Client) handShake() error { func (c *Client) handShake() error {
msg, err := messages.MarshalHelloMsg(c.peerID) msg, err := messages.MarshalHelloMsg(c.hashedID)
if err != nil { if err != nil {
log.Errorf("failed to marshal hello message: %s", err)
return err return err
} }
_, err = c.relayConn.Write(msg) _, err = c.relayConn.Write(msg)
@ -171,85 +144,56 @@ func (c *Client) handShake() error {
} }
func (c *Client) readLoop() { func (c *Client) readLoop() {
log := log.WithField("client_id", c.peerID) defer func() {
c.log.Debugf("exit from read loop")
}()
var errExit error var errExit error
var n int var n int
for { for {
bufPtr := c.msgPool.Get().(*[]byte) buf := make([]byte, bufferSize)
buf := *bufPtr
n, errExit = c.relayConn.Read(buf) n, errExit = c.relayConn.Read(buf)
if errExit != nil { if errExit != nil {
log.Debugf("failed to read message from relay server: %s", errExit) if c.relayConnState {
c.freeBuf(bufPtr) c.log.Debugf("failed to read message from relay server: %s", errExit)
}
break break
} }
msgType, err := messages.DetermineServerMsgType(buf[:n]) msgType, err := messages.DetermineServerMsgType(buf[:n])
if err != nil { if err != nil {
log.Errorf("failed to determine message type: %s", err) c.log.Errorf("failed to determine message type: %s", err)
c.freeBuf(bufPtr)
continue continue
} }
switch msgType { switch msgType {
case messages.MsgTypeBindResponse:
channelId, peerId, err := messages.UnmarshalBindResponseMsg(buf[:n])
if err != nil {
log.Errorf("failed to parse bind response message: %v", err)
} else {
c.handleBindResponse(channelId, peerId)
}
c.freeBuf(bufPtr)
continue
case messages.MsgTypeTransport: case messages.MsgTypeTransport:
channelId, err := messages.UnmarshalTransportID(buf[:n]) peerID, err := messages.UnmarshalTransportID(buf[:n])
if err != nil { if err != nil {
log.Errorf("failed to parse transport message: %v", err) c.log.Errorf("failed to parse transport message: %v", err)
c.freeBuf(bufPtr)
continue continue
} }
container, ok := c.channels[channelId] stringID := messages.HashIDToString(peerID)
container, ok := c.conns[stringID]
if !ok { if !ok {
log.Errorf("unexpected transport message for channel: %d", channelId) c.log.Errorf("peer not found: %s", stringID)
c.freeBuf(bufPtr) continue
return
} }
container.messages <- bufMsg{ container.messages <- Msg{
bufPtr,
buf[:n], buf[:n],
} }
} }
} }
if c.relayConnState { if c.relayConnState {
log.Errorf("failed to read message from relay server: %s", errExit) c.log.Errorf("failed to read message from relay server: %s", errExit)
_ = c.relayConn.Close() _ = c.relayConn.Close()
} }
} }
func (c *Client) handleBindResponse(channelId uint16, peerId string) { func (c *Client) writeTo(dstID []byte, payload []byte) (int, error) {
bindSuccessChan, ok := c.channelsPending[peerId] msg := messages.MarshalTransportMsg(dstID, payload)
if !ok {
log.Errorf("unexpected bind response from: %s", peerId)
return
}
delete(c.channelsPending, peerId)
messageBuffer := make(chan bufMsg, 2)
conn := NewConn(c, channelId, c.generateConnReaderFN(messageBuffer))
c.channels[channelId] = &connContainer{
conn,
messageBuffer,
}
log.Debugf("bind success for '%s': %d", peerId, channelId)
bindSuccessChan <- conn
}
func (c *Client) writeTo(channelID uint16, payload []byte) (int, error) {
msg := messages.MarshalTransportMsg(channelID, payload)
n, err := c.relayConn.Write(msg) n, err := c.relayConn.Write(msg)
if err != nil { if err != nil {
log.Errorf("failed to write transport message: %s", err) log.Errorf("failed to write transport message: %s", err)
@ -257,26 +201,19 @@ func (c *Client) writeTo(channelID uint16, payload []byte) (int, error) {
return n, err return n, err
} }
func (c *Client) generateConnReaderFN(messageBufferChan chan bufMsg) func(b []byte) (n int, err error) { func (c *Client) generateConnReaderFN(msgChannel chan Msg) func(b []byte) (n int, err error) {
return func(b []byte) (n int, err error) { return func(b []byte) (n int, err error) {
select { msg, ok := <-msgChannel
case bufMsg, ok := <-messageBufferChan: if !ok {
if !ok { return 0, io.EOF
return 0, io.EOF
}
payload, err := messages.UnmarshalTransportPayload(bufMsg.buf)
if err != nil {
return 0, err
}
n = copy(b, payload)
c.freeBuf(bufMsg.bufPtr)
} }
payload, err := messages.UnmarshalTransportPayload(msg.buf)
if err != nil {
return 0, err
}
n = copy(b, payload)
return n, nil return n, nil
} }
} }
func (c *Client) freeBuf(ptr *[]byte) {
c.msgPool.Put(ptr)
}

View File

@ -6,23 +6,23 @@ import (
) )
type Conn struct { type Conn struct {
client *Client client *Client
channelID uint16 dstID []byte
readerFn func(b []byte) (n int, err error) readerFn func(b []byte) (n int, err error)
} }
func NewConn(client *Client, channelID uint16, readerFn func(b []byte) (n int, err error)) *Conn { func NewConn(client *Client, dstID []byte, readerFn func(b []byte) (n int, err error)) *Conn {
c := &Conn{ c := &Conn{
client: client, client: client,
channelID: channelID, dstID: dstID,
readerFn: readerFn, readerFn: readerFn,
} }
return c return c
} }
func (c *Conn) Write(p []byte) (n int, err error) { func (c *Conn) Write(p []byte) (n int, err error) {
return c.client.writeTo(c.channelID, p) return c.client.writeTo(c.dstID, p)
} }
func (c *Conn) Read(b []byte) (n int, err error) { func (c *Conn) Read(b []byte) (n int, err error) {

43
relay/client/manager.go Normal file
View File

@ -0,0 +1,43 @@
package client
import (
"context"
"sync"
)
type Manager struct {
ctx context.Context
ctxCancel context.CancelFunc
srvAddress string
peerID string
wg sync.WaitGroup
clients map[string]*Client
clientsMutex sync.RWMutex
}
func NewManager(ctx context.Context, serverAddress string, peerID string) *Manager {
ctx, cancel := context.WithCancel(ctx)
return &Manager{
ctx: ctx,
ctxCancel: cancel,
srvAddress: serverAddress,
peerID: peerID,
clients: make(map[string]*Client),
}
}
func (m *Manager) Teardown() {
m.ctxCancel()
m.wg.Wait()
}
func (m *Manager) newSrvConnection(address string) {
if _, ok := m.clients[address]; ok {
return
}
// client := NewClient(address, m.peerID)
//err = client.Connect()
}

View File

@ -1,9 +1,10 @@
package main package main
import ( import (
"github.com/netbirdio/netbird/util"
"os" "os"
"github.com/netbirdio/netbird/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/relay/server"
@ -15,7 +16,7 @@ func init() {
func main() { func main() {
address := "0.0.0.0:1234" address := "10.145.236.1:1235"
srv := server.NewServer() srv := server.NewServer()
err := srv.Listen(address) err := srv.Listen(address)
if err != nil { if err != nil {

20
relay/messages/id.go Normal file
View File

@ -0,0 +1,20 @@
package messages
import (
"crypto/sha256"
"encoding/base64"
)
const (
IDSize = sha256.Size
)
func HashID(peerID string) ([]byte, string) {
idHash := sha256.Sum256([]byte(peerID))
idHashString := base64.StdEncoding.EncodeToString(idHash[:])
return idHash[:], idHashString
}
func HashIDToString(idHash []byte) string {
return base64.StdEncoding.EncodeToString(idHash[:])
}

View File

@ -5,11 +5,9 @@ import (
) )
const ( const (
MsgTypeHello MsgType = 0 MsgTypeHello MsgType = 0
MsgTypeHelloResponse MsgType = 1 MsgTypeHelloResponse MsgType = 1
MsgTypeBindNewChannel MsgType = 2 MsgTypeTransport MsgType = 2
MsgTypeBindResponse MsgType = 3
MsgTypeTransport MsgType = 4
) )
var ( var (
@ -22,10 +20,8 @@ func (m MsgType) String() string {
switch m { switch m {
case MsgTypeHello: case MsgTypeHello:
return "hello" return "hello"
case MsgTypeBindNewChannel: case MsgTypeHelloResponse:
return "bind new channel" return "hello response"
case MsgTypeBindResponse:
return "bind response"
case MsgTypeTransport: case MsgTypeTransport:
return "transport" return "transport"
default: default:
@ -39,8 +35,6 @@ func DetermineClientMsgType(msg []byte) (MsgType, error) {
switch msgType { switch msgType {
case MsgTypeHello: case MsgTypeHello:
return msgType, nil return msgType, nil
case MsgTypeBindNewChannel:
return msgType, nil
case MsgTypeTransport: case MsgTypeTransport:
return msgType, nil return msgType, nil
default: default:
@ -54,8 +48,6 @@ func DetermineServerMsgType(msg []byte) (MsgType, error) {
switch msgType { switch msgType {
case MsgTypeHelloResponse: case MsgTypeHelloResponse:
return msgType, nil return msgType, nil
case MsgTypeBindResponse:
return msgType, nil
case MsgTypeTransport: case MsgTypeTransport:
return msgType, nil return msgType, nil
default: default:
@ -64,21 +56,21 @@ func DetermineServerMsgType(msg []byte) (MsgType, error) {
} }
// MarshalHelloMsg initial hello message // MarshalHelloMsg initial hello message
func MarshalHelloMsg(peerID string) ([]byte, error) { func MarshalHelloMsg(peerID []byte) ([]byte, error) {
if len(peerID) == 0 { if len(peerID) != IDSize {
return nil, fmt.Errorf("invalid peer id") return nil, fmt.Errorf("invalid peerID length")
} }
msg := make([]byte, 1, 1+len(peerID)) msg := make([]byte, 1, 1+len(peerID))
msg[0] = byte(MsgTypeHello) msg[0] = byte(MsgTypeHello)
msg = append(msg, []byte(peerID)...) msg = append(msg, peerID...)
return msg, nil return msg, nil
} }
func UnmarshalHelloMsg(msg []byte) (string, error) { func UnmarshalHelloMsg(msg []byte) ([]byte, error) {
if len(msg) < 2 { if len(msg) < 2 {
return "", fmt.Errorf("invalid 'hello' messge") return nil, fmt.Errorf("invalid 'hello' messge")
} }
return string(msg[1:]), nil return msg[1:], nil
} }
func MarshalHelloResponse() []byte { func MarshalHelloResponse() []byte {
@ -87,71 +79,40 @@ func MarshalHelloResponse() []byte {
return msg return msg
} }
// Bind new channel
func MarshalBindNewChannelMsg(destinationPeerId string) []byte {
msg := make([]byte, 1, 1+len(destinationPeerId))
msg[0] = byte(MsgTypeBindNewChannel)
msg = append(msg, []byte(destinationPeerId)...)
return msg
}
func UnmarshalBindNewChannel(msg []byte) (string, error) {
if len(msg) < 2 {
return "", fmt.Errorf("invalid 'bind new channel' messge")
}
return string(msg[1:]), nil
}
// Bind response
func MarshalBindResponseMsg(channelId uint16, id string) []byte {
data := []byte(id)
msg := make([]byte, 3, 3+len(data))
msg[0] = byte(MsgTypeBindResponse)
msg[1], msg[2] = uint8(channelId>>8), uint8(channelId&0xff)
msg = append(msg, data...)
return msg
}
func UnmarshalBindResponseMsg(buf []byte) (uint16, string, error) {
if len(buf) < 3 {
return 0, "", ErrInvalidMessageLength
}
channelId := uint16(buf[1])<<8 | uint16(buf[2])
peerID := string(buf[3:])
return channelId, peerID, nil
}
// Transport message // Transport message
func MarshalTransportMsg(channelId uint16, payload []byte) []byte { func MarshalTransportMsg(peerID []byte, payload []byte) []byte {
msg := make([]byte, 3, 3+len(payload)) if len(peerID) != IDSize {
return nil
}
msg := make([]byte, 1+IDSize, 1+IDSize+len(payload))
msg[0] = byte(MsgTypeTransport) msg[0] = byte(MsgTypeTransport)
msg[1], msg[2] = uint8(channelId>>8), uint8(channelId&0xff) copy(msg[1:], peerID)
msg = append(msg, payload...) msg = append(msg, payload...)
return msg return msg
} }
func UnmarshalTransportPayload(buf []byte) ([]byte, error) { func UnmarshalTransportPayload(buf []byte) ([]byte, error) {
if len(buf) < 3 { headerSize := 1 + IDSize
if len(buf) < headerSize {
return nil, ErrInvalidMessageLength return nil, ErrInvalidMessageLength
} }
return buf[3:], nil return buf[headerSize:], nil
} }
func UnmarshalTransportID(buf []byte) (uint16, error) { func UnmarshalTransportID(buf []byte) ([]byte, error) {
if len(buf) < 3 { headerSize := 1 + IDSize
return 0, ErrInvalidMessageLength if len(buf) < headerSize {
return nil, ErrInvalidMessageLength
} }
channelId := uint16(buf[1])<<8 | uint16(buf[2]) return buf[1:headerSize], nil
return channelId, nil
} }
func UpdateTransportMsg(msg []byte, channelId uint16) error { func UpdateTransportMsg(msg []byte, peerID []byte) error {
if len(msg) < 3 { if len(msg) < 1+len(peerID) {
return ErrInvalidMessageLength return ErrInvalidMessageLength
} }
msg[1], msg[2] = uint8(channelId>>8), uint8(channelId&0xff) copy(msg[1:], peerID)
return nil return nil
} }

View File

@ -10,15 +10,15 @@ import (
) )
type Listener struct { type Listener struct {
address string address string
conns map[string]*UDPConn
onAcceptFn func(conn net.Conn) onAcceptFn func(conn net.Conn)
conns map[string]*UDPConn
wg sync.WaitGroup
quit chan struct{}
lock sync.Mutex
listener *net.UDPConn listener *net.UDPConn
wg sync.WaitGroup
quit chan struct{}
lock sync.Mutex
} }
func NewListener(address string) listener.Listener { func NewListener(address string) listener.Listener {
@ -34,17 +34,20 @@ func (l *Listener) Listen(onAcceptFn func(conn net.Conn)) error {
l.onAcceptFn = onAcceptFn l.onAcceptFn = onAcceptFn
l.quit = make(chan struct{}) l.quit = make(chan struct{})
addr := &net.UDPAddr{ addr, err := net.ResolveUDPAddr("udp", l.address)
Port: 1234,
IP: net.ParseIP("0.0.0.0"),
}
li, err := net.ListenUDP("udp", addr)
if err != nil { if err != nil {
log.Errorf("%s", err) log.Errorf("invalid listen address '%s': %s", l.address, err)
l.lock.Unlock() l.lock.Unlock()
return err return err
} }
log.Debugf("udp server is listening on address: %s", l.address)
li, err := net.ListenUDP("udp", addr)
if err != nil {
log.Fatalf("%s", err)
l.lock.Unlock()
return err
}
log.Debugf("udp server is listening on address: %s", addr.String())
l.listener = li l.listener = li
l.wg.Add(1) l.wg.Add(1)
go l.readLoop() go l.readLoop()
@ -54,14 +57,18 @@ func (l *Listener) Listen(onAcceptFn func(conn net.Conn)) error {
return nil return nil
} }
// Close todo: prevent multiple call (do not close two times the channel)
func (l *Listener) Close() error { func (l *Listener) Close() error {
l.lock.Lock() l.lock.Lock()
defer l.lock.Unlock() defer l.lock.Unlock()
if l.listener == nil {
return nil
}
close(l.quit) close(l.quit)
err := l.listener.Close() err := l.listener.Close()
l.wg.Wait() l.wg.Wait()
l.listener = nil
return err return err
} }
@ -91,6 +98,5 @@ func (l *Listener) readLoop() {
l.conns[addr.String()] = pConn l.conns[addr.String()] = pConn
go l.onAcceptFn(pConn) go l.onAcceptFn(pConn)
pConn.onNewMsg(buf[:n]) pConn.onNewMsg(buf[:n])
} }
} }

View File

@ -1,113 +1,34 @@
package server package server
import ( import (
"fmt"
"net" "net"
"sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
)
type Participant struct { "github.com/netbirdio/netbird/relay/messages"
ChannelID uint16 )
ChannelIDForeign uint16
ConnForeign net.Conn
Peer *Peer
}
type Peer struct { type Peer struct {
Log *log.Entry Log *log.Entry
id string idS string
idB []byte
conn net.Conn conn net.Conn
pendingParticipantByChannelID map[uint16]*Participant
participantByID map[uint16]*Participant // used for package transfer
participantByPeerID map[string]*Participant // used for channel linking
lastId uint16
lastIdLock sync.Mutex
} }
func NewPeer(id string, conn net.Conn) *Peer { func NewPeer(id []byte, conn net.Conn) *Peer {
log.Debugf("new peer: %v", id)
stringID := messages.HashIDToString(id)
return &Peer{ return &Peer{
Log: log.WithField("peer_id", id), Log: log.WithField("peer_id", stringID),
id: id, idB: id,
conn: conn, idS: stringID,
pendingParticipantByChannelID: make(map[uint16]*Participant), conn: conn,
participantByID: make(map[uint16]*Participant),
participantByPeerID: make(map[string]*Participant),
} }
} }
func (p *Peer) BindChannel(remotePeerId string) uint16 { func (p *Peer) ID() []byte {
ch, ok := p.participantByPeerID[remotePeerId] return p.idB
if ok {
return ch.ChannelID
}
channelID := p.newChannelID()
channel := &Participant{
ChannelID: channelID,
}
p.pendingParticipantByChannelID[channelID] = channel
p.participantByPeerID[remotePeerId] = channel
return channelID
} }
func (p *Peer) UnBindChannel(remotePeerId string) { func (p *Peer) String() string {
pa, ok := p.participantByPeerID[remotePeerId] return p.idS
if !ok {
return
}
p.Log.Debugf("unbind channel with '%s': %d", remotePeerId, pa.ChannelID)
p.pendingParticipantByChannelID[pa.ChannelID] = pa
delete(p.participantByID, pa.ChannelID)
}
func (p *Peer) AddParticipant(peer *Peer, remoteChannelID uint16) (uint16, bool) {
participant, ok := p.participantByPeerID[peer.ID()]
if !ok {
return 0, false
}
participant.ChannelIDForeign = remoteChannelID
participant.ConnForeign = peer.conn
participant.Peer = peer
delete(p.pendingParticipantByChannelID, participant.ChannelID)
p.participantByID[participant.ChannelID] = participant
return participant.ChannelID, true
}
func (p *Peer) DeleteParticipants() {
for _, participant := range p.participantByID {
participant.Peer.UnBindChannel(p.id)
}
}
func (p *Peer) ConnByChannelID(dstID uint16) (uint16, net.Conn, error) {
ch, ok := p.participantByID[dstID]
if !ok {
return 0, nil, fmt.Errorf("destination channel not found")
}
return ch.ChannelIDForeign, ch.ConnForeign, nil
}
func (p *Peer) ID() string {
return p.id
}
func (p *Peer) newChannelID() uint16 {
p.lastIdLock.Lock()
defer p.lastIdLock.Unlock()
for {
p.lastId++
if _, ok := p.pendingParticipantByChannelID[p.lastId]; ok {
continue
}
if _, ok := p.participantByID[p.lastId]; ok {
continue
}
return p.lastId
}
} }

View File

@ -16,7 +16,6 @@ import (
// todo: // todo:
// authentication: provide JWT token via RPC call. The MGM server can forward the token to the agents. // authentication: provide JWT token via RPC call. The MGM server can forward the token to the agents.
// connection timeout handling // connection timeout handling
// implement HA (High Availability) mode
type Server struct { type Server struct {
store *Store store *Store
@ -75,54 +74,35 @@ func (r *Server) accept(conn net.Conn) {
return return
} }
switch msgType { switch msgType {
case messages.MsgTypeBindNewChannel:
dstPeerId, err := messages.UnmarshalBindNewChannel(buf[:n])
if err != nil {
log.Errorf("failed to unmarshal bind new channel message: %s", err)
continue
}
channelID := r.store.Link(peer, dstPeerId)
msg := messages.MarshalBindResponseMsg(channelID, dstPeerId)
_, err = conn.Write(msg)
if err != nil {
peer.Log.Errorf("failed to response to bind request: %s", err)
continue
}
peer.Log.Debugf("bind new channel with '%s', channelID: %d", dstPeerId, channelID)
case messages.MsgTypeTransport: case messages.MsgTypeTransport:
msg := buf[:n] msg := buf[:n]
channelId, err := messages.UnmarshalTransportID(msg) peerID, err := messages.UnmarshalTransportID(msg)
if err != nil { if err != nil {
peer.Log.Errorf("failed to unmarshal transport message: %s", err) peer.Log.Errorf("failed to unmarshal transport message: %s", err)
continue continue
} }
go func() { go func() {
foreignChannelID, remoteConn, err := peer.ConnByChannelID(channelId) stringPeerID := messages.HashIDToString(peerID)
if err != nil { dp, ok := r.store.Peer(stringPeerID)
peer.Log.Errorf("failed to transport message from peer '%s' to '%d': %s", peer.ID(), channelId, err) if !ok {
peer.Log.Errorf("peer not found: %s", stringPeerID)
return return
} }
err := messages.UpdateTransportMsg(msg, peer.ID())
err = transportTo(remoteConn, foreignChannelID, msg)
if err != nil { if err != nil {
peer.Log.Errorf("failed to transport message from peer '%s' to '%d': %s", peer.ID(), channelId, err) peer.Log.Errorf("failed to update transport message: %s", err)
return
} }
_, err = dp.conn.Write(msg)
if err != nil {
peer.Log.Errorf("failed to write transport message to: %s", dp.String())
}
return
}() }()
} }
} }
} }
func transportTo(conn net.Conn, channelID uint16, msg []byte) error {
err := messages.UpdateTransportMsg(msg, channelID)
if err != nil {
return err
}
_, err = conn.Write(msg)
return err
}
func handShake(conn net.Conn) (*Peer, error) { func handShake(conn net.Conn) (*Peer, error) {
buf := make([]byte, 1500) buf := make([]byte, 1500)
n, err := conn.Read(buf) n, err := conn.Read(buf)

View File

@ -5,8 +5,8 @@ import (
) )
type Store struct { type Store struct {
peers map[string]*Peer // Key is the id (public key or sha-256) of the peer peers map[string]*Peer // consider to use [32]byte as key. The Peer(id string) would be faster
peersLock sync.Mutex peersLock sync.RWMutex
} }
func NewStore() *Store { func NewStore() *Store {
@ -18,31 +18,20 @@ func NewStore() *Store {
func (s *Store) AddPeer(peer *Peer) { func (s *Store) AddPeer(peer *Peer) {
s.peersLock.Lock() s.peersLock.Lock()
defer s.peersLock.Unlock() defer s.peersLock.Unlock()
s.peers[peer.ID()] = peer s.peers[peer.String()] = peer
}
func (s *Store) Link(peer *Peer, peerForeignID string) uint16 {
s.peersLock.Lock()
defer s.peersLock.Unlock()
channelId := peer.BindChannel(peerForeignID)
dstPeer, ok := s.peers[peerForeignID]
if !ok {
return channelId
}
foreignChannelID, ok := dstPeer.AddParticipant(peer, channelId)
if !ok {
return channelId
}
peer.AddParticipant(dstPeer, foreignChannelID)
return channelId
} }
func (s *Store) DeletePeer(peer *Peer) { func (s *Store) DeletePeer(peer *Peer) {
s.peersLock.Lock() s.peersLock.Lock()
defer s.peersLock.Unlock() defer s.peersLock.Unlock()
delete(s.peers, peer.ID()) delete(s.peers, peer.String())
peer.DeleteParticipants() }
func (s *Store) Peer(id string) (*Peer, bool) {
s.peersLock.RLock()
defer s.peersLock.RUnlock()
p, ok := s.peers[id]
return p, ok
} }

View File

@ -50,11 +50,6 @@ func TestClient(t *testing.T) {
} }
defer clientPlaceHolder.Close() defer clientPlaceHolder.Close()
_, err = clientAlice.BindChannel("clientPlaceHolder")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
clientBob := client.NewClient(addr, "bob") clientBob := client.NewClient(addr, "bob")
err = clientBob.Connect() err = clientBob.Connect()
if err != nil { if err != nil {
@ -62,12 +57,12 @@ func TestClient(t *testing.T) {
} }
defer clientBob.Close() defer clientBob.Close()
connAliceToBob, err := clientAlice.BindChannel("bob") connAliceToBob, err := clientAlice.OpenConn("bob")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
connBobToAlice, err := clientBob.BindChannel("alice") connBobToAlice, err := clientBob.OpenConn("alice")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@ -154,6 +149,8 @@ func TestRegistrationTimeout(t *testing.T) {
} }
func TestEcho(t *testing.T) { func TestEcho(t *testing.T) {
idAlice := "alice"
idBob := "bob"
addr := "localhost:1234" addr := "localhost:1234"
srv := server.NewServer() srv := server.NewServer()
go func() { go func() {
@ -170,7 +167,7 @@ func TestEcho(t *testing.T) {
} }
}() }()
clientAlice := client.NewClient(addr, "alice") clientAlice := client.NewClient(addr, idAlice)
err := clientAlice.Connect() err := clientAlice.Connect()
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
@ -182,7 +179,7 @@ func TestEcho(t *testing.T) {
} }
}() }()
clientBob := client.NewClient(addr, "bob") clientBob := client.NewClient(addr, idBob)
err = clientBob.Connect() err = clientBob.Connect()
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
@ -194,12 +191,12 @@ func TestEcho(t *testing.T) {
} }
}() }()
connAliceToBob, err := clientAlice.BindChannel("bob") connAliceToBob, err := clientAlice.OpenConn(idBob)
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
connBobToAlice, err := clientBob.BindChannel("alice") connBobToAlice, err := clientBob.OpenConn(idAlice)
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@ -262,7 +259,7 @@ func TestBindToUnavailabePeer(t *testing.T) {
} }
}() }()
_, err = clientAlice.BindChannel("bob") _, err = clientAlice.OpenConn("bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
@ -292,7 +289,7 @@ func TestBindReconnect(t *testing.T) {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
_, err = clientAlice.BindChannel("bob") _, err = clientAlice.OpenConn("bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
@ -303,7 +300,7 @@ func TestBindReconnect(t *testing.T) {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
chBob, err := clientBob.BindChannel("alice") chBob, err := clientBob.OpenConn("alice")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
@ -320,7 +317,7 @@ func TestBindReconnect(t *testing.T) {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
chAlice, err := clientAlice.BindChannel("bob") chAlice, err := clientAlice.OpenConn("bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }