From 36b2cd16ccc53350fae358b16917218a50a4753f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zolt=C3=A1n=20Papp?= Date: Thu, 23 May 2024 13:24:02 +0200 Subject: [PATCH] Remove channel binding logic --- relay/client/client.go | 167 ++++++++------------------ relay/client/conn.go | 16 +-- relay/client/manager.go | 43 +++++++ relay/cmd/main.go | 5 +- relay/messages/id.go | 20 +++ relay/messages/message.go | 99 +++++---------- relay/server/listener/udp/listener.go | 36 +++--- relay/server/peer.go | 109 +++-------------- relay/server/server.go | 46 ++----- relay/server/store.go | 35 ++---- relay/test/client_test.go | 27 ++--- 11 files changed, 229 insertions(+), 374 deletions(-) create mode 100644 relay/client/manager.go create mode 100644 relay/messages/id.go diff --git a/relay/client/client.go b/relay/client/client.go index 130cb8592..d574e04a1 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -1,7 +1,6 @@ package client import ( - "context" "fmt" "io" "net" @@ -19,25 +18,21 @@ const ( serverResponseTimeout = 8 * time.Second ) -type bufMsg struct { - bufPtr *[]byte - buf []byte +type Msg struct { + buf []byte } type connContainer struct { conn *Conn - messages chan bufMsg + messages chan Msg } -// Client Todo: -// - handle automatic reconnection type Client struct { + log *log.Entry serverAddress string - peerID string + hashedID []byte - channelsPending map[string]chan net.Conn // todo: protect map with mutex - channels map[uint16]*connContainer - msgPool sync.Pool + conns map[string]*connContainer relayConn net.Conn relayConnState bool @@ -45,17 +40,12 @@ type Client struct { } func NewClient(serverAddress, peerID string) *Client { + hashedID, hashedStringId := messages.HashID(peerID) return &Client{ - serverAddress: serverAddress, - peerID: peerID, - channelsPending: make(map[string]chan net.Conn), - channels: make(map[uint16]*connContainer), - msgPool: sync.Pool{ - New: func() any { - buf := make([]byte, bufferSize) - return &buf - }, - }, + log: log.WithField("client_id", hashedStringId), + serverAddress: serverAddress, + hashedID: hashedID, + conns: make(map[string]*connContainer), } } @@ -89,31 +79,17 @@ func (c *Client) Connect() error { return nil } -func (c *Client) BindChannel(remotePeerID string) (net.Conn, error) { - c.mu.Lock() - defer c.mu.Unlock() +func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) { + hashedID, hashedStringID := messages.HashID(dstPeerID) + log.Infof("open connection to peer: %s", hashedStringID) + messageBuffer := make(chan Msg, 2) + conn := NewConn(c, hashedID, c.generateConnReaderFN(messageBuffer)) - if c.relayConn == nil { - return nil, fmt.Errorf("client not connected to the relay server") - } - - bindSuccessChan := make(chan net.Conn, 1) - c.channelsPending[remotePeerID] = bindSuccessChan - msg := messages.MarshalBindNewChannelMsg(remotePeerID) - _, err := c.relayConn.Write(msg) - if err != nil { - log.Errorf("failed to write out bind message: %s", err) - return nil, err - } - - ctx, cancel := context.WithTimeout(context.Background(), serverResponseTimeout) - defer cancel() - select { - case <-ctx.Done(): - return nil, fmt.Errorf("bind timeout") - case c := <-bindSuccessChan: - return c, nil + c.conns[hashedStringID] = &connContainer{ + conn, + messageBuffer, } + return conn, nil } func (c *Client) Close() error { @@ -124,18 +100,15 @@ func (c *Client) Close() error { return nil } - for _, conn := range c.channels { - close(conn.messages) - } - c.channels = make(map[uint16]*connContainer) c.relayConnState = false err := c.relayConn.Close() return err } func (c *Client) handShake() error { - msg, err := messages.MarshalHelloMsg(c.peerID) + msg, err := messages.MarshalHelloMsg(c.hashedID) if err != nil { + log.Errorf("failed to marshal hello message: %s", err) return err } _, err = c.relayConn.Write(msg) @@ -171,85 +144,56 @@ func (c *Client) handShake() error { } func (c *Client) readLoop() { - log := log.WithField("client_id", c.peerID) + defer func() { + c.log.Debugf("exit from read loop") + }() var errExit error var n int for { - bufPtr := c.msgPool.Get().(*[]byte) - buf := *bufPtr + buf := make([]byte, bufferSize) n, errExit = c.relayConn.Read(buf) if errExit != nil { - log.Debugf("failed to read message from relay server: %s", errExit) - c.freeBuf(bufPtr) + if c.relayConnState { + c.log.Debugf("failed to read message from relay server: %s", errExit) + } break } msgType, err := messages.DetermineServerMsgType(buf[:n]) if err != nil { - log.Errorf("failed to determine message type: %s", err) - c.freeBuf(bufPtr) + c.log.Errorf("failed to determine message type: %s", err) continue } switch msgType { - case messages.MsgTypeBindResponse: - channelId, peerId, err := messages.UnmarshalBindResponseMsg(buf[:n]) - if err != nil { - log.Errorf("failed to parse bind response message: %v", err) - } else { - c.handleBindResponse(channelId, peerId) - } - c.freeBuf(bufPtr) - continue case messages.MsgTypeTransport: - channelId, err := messages.UnmarshalTransportID(buf[:n]) + peerID, err := messages.UnmarshalTransportID(buf[:n]) if err != nil { - log.Errorf("failed to parse transport message: %v", err) - c.freeBuf(bufPtr) + c.log.Errorf("failed to parse transport message: %v", err) continue } - container, ok := c.channels[channelId] + stringID := messages.HashIDToString(peerID) + + container, ok := c.conns[stringID] if !ok { - log.Errorf("unexpected transport message for channel: %d", channelId) - c.freeBuf(bufPtr) - return + c.log.Errorf("peer not found: %s", stringID) + continue } - container.messages <- bufMsg{ - bufPtr, + container.messages <- Msg{ buf[:n], } } } if c.relayConnState { - log.Errorf("failed to read message from relay server: %s", errExit) + c.log.Errorf("failed to read message from relay server: %s", errExit) _ = c.relayConn.Close() } } -func (c *Client) handleBindResponse(channelId uint16, peerId string) { - bindSuccessChan, ok := c.channelsPending[peerId] - if !ok { - log.Errorf("unexpected bind response from: %s", peerId) - return - } - delete(c.channelsPending, peerId) - - messageBuffer := make(chan bufMsg, 2) - conn := NewConn(c, channelId, c.generateConnReaderFN(messageBuffer)) - - c.channels[channelId] = &connContainer{ - conn, - messageBuffer, - } - log.Debugf("bind success for '%s': %d", peerId, channelId) - - bindSuccessChan <- conn -} - -func (c *Client) writeTo(channelID uint16, payload []byte) (int, error) { - msg := messages.MarshalTransportMsg(channelID, payload) +func (c *Client) writeTo(dstID []byte, payload []byte) (int, error) { + msg := messages.MarshalTransportMsg(dstID, payload) n, err := c.relayConn.Write(msg) if err != nil { log.Errorf("failed to write transport message: %s", err) @@ -257,26 +201,19 @@ func (c *Client) writeTo(channelID uint16, payload []byte) (int, error) { return n, err } -func (c *Client) generateConnReaderFN(messageBufferChan chan bufMsg) func(b []byte) (n int, err error) { +func (c *Client) generateConnReaderFN(msgChannel chan Msg) func(b []byte) (n int, err error) { return func(b []byte) (n int, err error) { - select { - case bufMsg, ok := <-messageBufferChan: - if !ok { - return 0, io.EOF - } - - payload, err := messages.UnmarshalTransportPayload(bufMsg.buf) - if err != nil { - return 0, err - } - - n = copy(b, payload) - c.freeBuf(bufMsg.bufPtr) + msg, ok := <-msgChannel + if !ok { + return 0, io.EOF } + + payload, err := messages.UnmarshalTransportPayload(msg.buf) + if err != nil { + return 0, err + } + + n = copy(b, payload) return n, nil } } - -func (c *Client) freeBuf(ptr *[]byte) { - c.msgPool.Put(ptr) -} diff --git a/relay/client/conn.go b/relay/client/conn.go index d450c3f30..aea07ff9b 100644 --- a/relay/client/conn.go +++ b/relay/client/conn.go @@ -6,23 +6,23 @@ import ( ) type Conn struct { - client *Client - channelID uint16 - readerFn func(b []byte) (n int, err error) + client *Client + dstID []byte + readerFn func(b []byte) (n int, err error) } -func NewConn(client *Client, channelID uint16, readerFn func(b []byte) (n int, err error)) *Conn { +func NewConn(client *Client, dstID []byte, readerFn func(b []byte) (n int, err error)) *Conn { c := &Conn{ - client: client, - channelID: channelID, - readerFn: readerFn, + client: client, + dstID: dstID, + readerFn: readerFn, } return c } func (c *Conn) Write(p []byte) (n int, err error) { - return c.client.writeTo(c.channelID, p) + return c.client.writeTo(c.dstID, p) } func (c *Conn) Read(b []byte) (n int, err error) { diff --git a/relay/client/manager.go b/relay/client/manager.go new file mode 100644 index 000000000..4d1aeca79 --- /dev/null +++ b/relay/client/manager.go @@ -0,0 +1,43 @@ +package client + +import ( + "context" + "sync" +) + +type Manager struct { + ctx context.Context + ctxCancel context.CancelFunc + srvAddress string + peerID string + + wg sync.WaitGroup + + clients map[string]*Client + clientsMutex sync.RWMutex +} + +func NewManager(ctx context.Context, serverAddress string, peerID string) *Manager { + ctx, cancel := context.WithCancel(ctx) + return &Manager{ + ctx: ctx, + ctxCancel: cancel, + srvAddress: serverAddress, + peerID: peerID, + clients: make(map[string]*Client), + } +} + +func (m *Manager) Teardown() { + m.ctxCancel() + m.wg.Wait() +} + +func (m *Manager) newSrvConnection(address string) { + if _, ok := m.clients[address]; ok { + return + } + + // client := NewClient(address, m.peerID) + //err = client.Connect() +} diff --git a/relay/cmd/main.go b/relay/cmd/main.go index cfae8232a..b89ab26ef 100644 --- a/relay/cmd/main.go +++ b/relay/cmd/main.go @@ -1,9 +1,10 @@ package main import ( - "github.com/netbirdio/netbird/util" "os" + "github.com/netbirdio/netbird/util" + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/relay/server" @@ -15,7 +16,7 @@ func init() { func main() { - address := "0.0.0.0:1234" + address := "10.145.236.1:1235" srv := server.NewServer() err := srv.Listen(address) if err != nil { diff --git a/relay/messages/id.go b/relay/messages/id.go new file mode 100644 index 000000000..d37dc37f8 --- /dev/null +++ b/relay/messages/id.go @@ -0,0 +1,20 @@ +package messages + +import ( + "crypto/sha256" + "encoding/base64" +) + +const ( + IDSize = sha256.Size +) + +func HashID(peerID string) ([]byte, string) { + idHash := sha256.Sum256([]byte(peerID)) + idHashString := base64.StdEncoding.EncodeToString(idHash[:]) + return idHash[:], idHashString +} + +func HashIDToString(idHash []byte) string { + return base64.StdEncoding.EncodeToString(idHash[:]) +} diff --git a/relay/messages/message.go b/relay/messages/message.go index 02945e2f1..c71d203b1 100644 --- a/relay/messages/message.go +++ b/relay/messages/message.go @@ -5,11 +5,9 @@ import ( ) const ( - MsgTypeHello MsgType = 0 - MsgTypeHelloResponse MsgType = 1 - MsgTypeBindNewChannel MsgType = 2 - MsgTypeBindResponse MsgType = 3 - MsgTypeTransport MsgType = 4 + MsgTypeHello MsgType = 0 + MsgTypeHelloResponse MsgType = 1 + MsgTypeTransport MsgType = 2 ) var ( @@ -22,10 +20,8 @@ func (m MsgType) String() string { switch m { case MsgTypeHello: return "hello" - case MsgTypeBindNewChannel: - return "bind new channel" - case MsgTypeBindResponse: - return "bind response" + case MsgTypeHelloResponse: + return "hello response" case MsgTypeTransport: return "transport" default: @@ -39,8 +35,6 @@ func DetermineClientMsgType(msg []byte) (MsgType, error) { switch msgType { case MsgTypeHello: return msgType, nil - case MsgTypeBindNewChannel: - return msgType, nil case MsgTypeTransport: return msgType, nil default: @@ -54,8 +48,6 @@ func DetermineServerMsgType(msg []byte) (MsgType, error) { switch msgType { case MsgTypeHelloResponse: return msgType, nil - case MsgTypeBindResponse: - return msgType, nil case MsgTypeTransport: return msgType, nil default: @@ -64,21 +56,21 @@ func DetermineServerMsgType(msg []byte) (MsgType, error) { } // MarshalHelloMsg initial hello message -func MarshalHelloMsg(peerID string) ([]byte, error) { - if len(peerID) == 0 { - return nil, fmt.Errorf("invalid peer id") +func MarshalHelloMsg(peerID []byte) ([]byte, error) { + if len(peerID) != IDSize { + return nil, fmt.Errorf("invalid peerID length") } msg := make([]byte, 1, 1+len(peerID)) msg[0] = byte(MsgTypeHello) - msg = append(msg, []byte(peerID)...) + msg = append(msg, peerID...) return msg, nil } -func UnmarshalHelloMsg(msg []byte) (string, error) { +func UnmarshalHelloMsg(msg []byte) ([]byte, error) { if len(msg) < 2 { - return "", fmt.Errorf("invalid 'hello' messge") + return nil, fmt.Errorf("invalid 'hello' messge") } - return string(msg[1:]), nil + return msg[1:], nil } func MarshalHelloResponse() []byte { @@ -87,71 +79,40 @@ func MarshalHelloResponse() []byte { return msg } -// Bind new channel - -func MarshalBindNewChannelMsg(destinationPeerId string) []byte { - msg := make([]byte, 1, 1+len(destinationPeerId)) - msg[0] = byte(MsgTypeBindNewChannel) - msg = append(msg, []byte(destinationPeerId)...) - return msg -} - -func UnmarshalBindNewChannel(msg []byte) (string, error) { - if len(msg) < 2 { - return "", fmt.Errorf("invalid 'bind new channel' messge") - } - return string(msg[1:]), nil -} - -// Bind response - -func MarshalBindResponseMsg(channelId uint16, id string) []byte { - data := []byte(id) - msg := make([]byte, 3, 3+len(data)) - msg[0] = byte(MsgTypeBindResponse) - msg[1], msg[2] = uint8(channelId>>8), uint8(channelId&0xff) - msg = append(msg, data...) - return msg -} - -func UnmarshalBindResponseMsg(buf []byte) (uint16, string, error) { - if len(buf) < 3 { - return 0, "", ErrInvalidMessageLength - } - channelId := uint16(buf[1])<<8 | uint16(buf[2]) - peerID := string(buf[3:]) - return channelId, peerID, nil -} - // Transport message -func MarshalTransportMsg(channelId uint16, payload []byte) []byte { - msg := make([]byte, 3, 3+len(payload)) +func MarshalTransportMsg(peerID []byte, payload []byte) []byte { + if len(peerID) != IDSize { + return nil + } + + msg := make([]byte, 1+IDSize, 1+IDSize+len(payload)) msg[0] = byte(MsgTypeTransport) - msg[1], msg[2] = uint8(channelId>>8), uint8(channelId&0xff) + copy(msg[1:], peerID) msg = append(msg, payload...) return msg } func UnmarshalTransportPayload(buf []byte) ([]byte, error) { - if len(buf) < 3 { + headerSize := 1 + IDSize + if len(buf) < headerSize { return nil, ErrInvalidMessageLength } - return buf[3:], nil + return buf[headerSize:], nil } -func UnmarshalTransportID(buf []byte) (uint16, error) { - if len(buf) < 3 { - return 0, ErrInvalidMessageLength +func UnmarshalTransportID(buf []byte) ([]byte, error) { + headerSize := 1 + IDSize + if len(buf) < headerSize { + return nil, ErrInvalidMessageLength } - channelId := uint16(buf[1])<<8 | uint16(buf[2]) - return channelId, nil + return buf[1:headerSize], nil } -func UpdateTransportMsg(msg []byte, channelId uint16) error { - if len(msg) < 3 { +func UpdateTransportMsg(msg []byte, peerID []byte) error { + if len(msg) < 1+len(peerID) { return ErrInvalidMessageLength } - msg[1], msg[2] = uint8(channelId>>8), uint8(channelId&0xff) + copy(msg[1:], peerID) return nil } diff --git a/relay/server/listener/udp/listener.go b/relay/server/listener/udp/listener.go index df7fa4c64..400b68a88 100644 --- a/relay/server/listener/udp/listener.go +++ b/relay/server/listener/udp/listener.go @@ -10,15 +10,15 @@ import ( ) type Listener struct { - address string - + address string + conns map[string]*UDPConn onAcceptFn func(conn net.Conn) - conns map[string]*UDPConn - wg sync.WaitGroup - quit chan struct{} - lock sync.Mutex listener *net.UDPConn + + wg sync.WaitGroup + quit chan struct{} + lock sync.Mutex } func NewListener(address string) listener.Listener { @@ -34,17 +34,20 @@ func (l *Listener) Listen(onAcceptFn func(conn net.Conn)) error { l.onAcceptFn = onAcceptFn l.quit = make(chan struct{}) - addr := &net.UDPAddr{ - Port: 1234, - IP: net.ParseIP("0.0.0.0"), - } - li, err := net.ListenUDP("udp", addr) + addr, err := net.ResolveUDPAddr("udp", l.address) if err != nil { - log.Errorf("%s", err) + log.Errorf("invalid listen address '%s': %s", l.address, err) l.lock.Unlock() return err } - log.Debugf("udp server is listening on address: %s", l.address) + + li, err := net.ListenUDP("udp", addr) + if err != nil { + log.Fatalf("%s", err) + l.lock.Unlock() + return err + } + log.Debugf("udp server is listening on address: %s", addr.String()) l.listener = li l.wg.Add(1) go l.readLoop() @@ -54,14 +57,18 @@ func (l *Listener) Listen(onAcceptFn func(conn net.Conn)) error { return nil } -// Close todo: prevent multiple call (do not close two times the channel) func (l *Listener) Close() error { l.lock.Lock() defer l.lock.Unlock() + if l.listener == nil { + return nil + } + close(l.quit) err := l.listener.Close() l.wg.Wait() + l.listener = nil return err } @@ -91,6 +98,5 @@ func (l *Listener) readLoop() { l.conns[addr.String()] = pConn go l.onAcceptFn(pConn) pConn.onNewMsg(buf[:n]) - } } diff --git a/relay/server/peer.go b/relay/server/peer.go index 2e40cbb12..d4b98b9b4 100644 --- a/relay/server/peer.go +++ b/relay/server/peer.go @@ -1,113 +1,34 @@ package server import ( - "fmt" "net" - "sync" log "github.com/sirupsen/logrus" -) -type Participant struct { - ChannelID uint16 - ChannelIDForeign uint16 - ConnForeign net.Conn - Peer *Peer -} + "github.com/netbirdio/netbird/relay/messages" +) type Peer struct { Log *log.Entry - id string + idS string + idB []byte conn net.Conn - - pendingParticipantByChannelID map[uint16]*Participant - participantByID map[uint16]*Participant // used for package transfer - participantByPeerID map[string]*Participant // used for channel linking - - lastId uint16 - lastIdLock sync.Mutex } -func NewPeer(id string, conn net.Conn) *Peer { +func NewPeer(id []byte, conn net.Conn) *Peer { + log.Debugf("new peer: %v", id) + stringID := messages.HashIDToString(id) return &Peer{ - Log: log.WithField("peer_id", id), - id: id, - conn: conn, - pendingParticipantByChannelID: make(map[uint16]*Participant), - participantByID: make(map[uint16]*Participant), - participantByPeerID: make(map[string]*Participant), + Log: log.WithField("peer_id", stringID), + idB: id, + idS: stringID, + conn: conn, } } -func (p *Peer) BindChannel(remotePeerId string) uint16 { - ch, ok := p.participantByPeerID[remotePeerId] - if ok { - return ch.ChannelID - } - - channelID := p.newChannelID() - channel := &Participant{ - ChannelID: channelID, - } - p.pendingParticipantByChannelID[channelID] = channel - p.participantByPeerID[remotePeerId] = channel - return channelID +func (p *Peer) ID() []byte { + return p.idB } -func (p *Peer) UnBindChannel(remotePeerId string) { - pa, ok := p.participantByPeerID[remotePeerId] - if !ok { - return - } - - p.Log.Debugf("unbind channel with '%s': %d", remotePeerId, pa.ChannelID) - p.pendingParticipantByChannelID[pa.ChannelID] = pa - delete(p.participantByID, pa.ChannelID) -} - -func (p *Peer) AddParticipant(peer *Peer, remoteChannelID uint16) (uint16, bool) { - participant, ok := p.participantByPeerID[peer.ID()] - if !ok { - return 0, false - } - participant.ChannelIDForeign = remoteChannelID - participant.ConnForeign = peer.conn - participant.Peer = peer - - delete(p.pendingParticipantByChannelID, participant.ChannelID) - p.participantByID[participant.ChannelID] = participant - return participant.ChannelID, true -} - -func (p *Peer) DeleteParticipants() { - for _, participant := range p.participantByID { - participant.Peer.UnBindChannel(p.id) - } -} - -func (p *Peer) ConnByChannelID(dstID uint16) (uint16, net.Conn, error) { - ch, ok := p.participantByID[dstID] - if !ok { - return 0, nil, fmt.Errorf("destination channel not found") - } - - return ch.ChannelIDForeign, ch.ConnForeign, nil -} - -func (p *Peer) ID() string { - return p.id -} - -func (p *Peer) newChannelID() uint16 { - p.lastIdLock.Lock() - defer p.lastIdLock.Unlock() - for { - p.lastId++ - if _, ok := p.pendingParticipantByChannelID[p.lastId]; ok { - continue - } - if _, ok := p.participantByID[p.lastId]; ok { - continue - } - return p.lastId - } +func (p *Peer) String() string { + return p.idS } diff --git a/relay/server/server.go b/relay/server/server.go index a66341e51..4b05975ec 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -16,7 +16,6 @@ import ( // todo: // authentication: provide JWT token via RPC call. The MGM server can forward the token to the agents. // connection timeout handling -// implement HA (High Availability) mode type Server struct { store *Store @@ -75,54 +74,35 @@ func (r *Server) accept(conn net.Conn) { return } switch msgType { - case messages.MsgTypeBindNewChannel: - dstPeerId, err := messages.UnmarshalBindNewChannel(buf[:n]) - if err != nil { - log.Errorf("failed to unmarshal bind new channel message: %s", err) - continue - } - - channelID := r.store.Link(peer, dstPeerId) - - msg := messages.MarshalBindResponseMsg(channelID, dstPeerId) - _, err = conn.Write(msg) - if err != nil { - peer.Log.Errorf("failed to response to bind request: %s", err) - continue - } - peer.Log.Debugf("bind new channel with '%s', channelID: %d", dstPeerId, channelID) case messages.MsgTypeTransport: msg := buf[:n] - channelId, err := messages.UnmarshalTransportID(msg) + peerID, err := messages.UnmarshalTransportID(msg) if err != nil { peer.Log.Errorf("failed to unmarshal transport message: %s", err) continue } go func() { - foreignChannelID, remoteConn, err := peer.ConnByChannelID(channelId) - if err != nil { - peer.Log.Errorf("failed to transport message from peer '%s' to '%d': %s", peer.ID(), channelId, err) + stringPeerID := messages.HashIDToString(peerID) + dp, ok := r.store.Peer(stringPeerID) + if !ok { + peer.Log.Errorf("peer not found: %s", stringPeerID) return } - - err = transportTo(remoteConn, foreignChannelID, msg) + err := messages.UpdateTransportMsg(msg, peer.ID()) if err != nil { - peer.Log.Errorf("failed to transport message from peer '%s' to '%d': %s", peer.ID(), channelId, err) + peer.Log.Errorf("failed to update transport message: %s", err) + return } + _, err = dp.conn.Write(msg) + if err != nil { + peer.Log.Errorf("failed to write transport message to: %s", dp.String()) + } + return }() } } } -func transportTo(conn net.Conn, channelID uint16, msg []byte) error { - err := messages.UpdateTransportMsg(msg, channelID) - if err != nil { - return err - } - _, err = conn.Write(msg) - return err -} - func handShake(conn net.Conn) (*Peer, error) { buf := make([]byte, 1500) n, err := conn.Read(buf) diff --git a/relay/server/store.go b/relay/server/store.go index fdbc118e4..f785f4d0e 100644 --- a/relay/server/store.go +++ b/relay/server/store.go @@ -5,8 +5,8 @@ import ( ) type Store struct { - peers map[string]*Peer // Key is the id (public key or sha-256) of the peer - peersLock sync.Mutex + peers map[string]*Peer // consider to use [32]byte as key. The Peer(id string) would be faster + peersLock sync.RWMutex } func NewStore() *Store { @@ -18,31 +18,20 @@ func NewStore() *Store { func (s *Store) AddPeer(peer *Peer) { s.peersLock.Lock() defer s.peersLock.Unlock() - s.peers[peer.ID()] = peer -} - -func (s *Store) Link(peer *Peer, peerForeignID string) uint16 { - s.peersLock.Lock() - defer s.peersLock.Unlock() - - channelId := peer.BindChannel(peerForeignID) - dstPeer, ok := s.peers[peerForeignID] - if !ok { - return channelId - } - - foreignChannelID, ok := dstPeer.AddParticipant(peer, channelId) - if !ok { - return channelId - } - peer.AddParticipant(dstPeer, foreignChannelID) - return channelId + s.peers[peer.String()] = peer } func (s *Store) DeletePeer(peer *Peer) { s.peersLock.Lock() defer s.peersLock.Unlock() - delete(s.peers, peer.ID()) - peer.DeleteParticipants() + delete(s.peers, peer.String()) +} + +func (s *Store) Peer(id string) (*Peer, bool) { + s.peersLock.RLock() + defer s.peersLock.RUnlock() + + p, ok := s.peers[id] + return p, ok } diff --git a/relay/test/client_test.go b/relay/test/client_test.go index 87a07fa94..675962eef 100644 --- a/relay/test/client_test.go +++ b/relay/test/client_test.go @@ -50,11 +50,6 @@ func TestClient(t *testing.T) { } defer clientPlaceHolder.Close() - _, err = clientAlice.BindChannel("clientPlaceHolder") - if err != nil { - t.Fatalf("failed to bind channel: %s", err) - } - clientBob := client.NewClient(addr, "bob") err = clientBob.Connect() if err != nil { @@ -62,12 +57,12 @@ func TestClient(t *testing.T) { } defer clientBob.Close() - connAliceToBob, err := clientAlice.BindChannel("bob") + connAliceToBob, err := clientAlice.OpenConn("bob") if err != nil { t.Fatalf("failed to bind channel: %s", err) } - connBobToAlice, err := clientBob.BindChannel("alice") + connBobToAlice, err := clientBob.OpenConn("alice") if err != nil { t.Fatalf("failed to bind channel: %s", err) } @@ -154,6 +149,8 @@ func TestRegistrationTimeout(t *testing.T) { } func TestEcho(t *testing.T) { + idAlice := "alice" + idBob := "bob" addr := "localhost:1234" srv := server.NewServer() go func() { @@ -170,7 +167,7 @@ func TestEcho(t *testing.T) { } }() - clientAlice := client.NewClient(addr, "alice") + clientAlice := client.NewClient(addr, idAlice) err := clientAlice.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -182,7 +179,7 @@ func TestEcho(t *testing.T) { } }() - clientBob := client.NewClient(addr, "bob") + clientBob := client.NewClient(addr, idBob) err = clientBob.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -194,12 +191,12 @@ func TestEcho(t *testing.T) { } }() - connAliceToBob, err := clientAlice.BindChannel("bob") + connAliceToBob, err := clientAlice.OpenConn(idBob) if err != nil { t.Fatalf("failed to bind channel: %s", err) } - connBobToAlice, err := clientBob.BindChannel("alice") + connBobToAlice, err := clientBob.OpenConn(idAlice) if err != nil { t.Fatalf("failed to bind channel: %s", err) } @@ -262,7 +259,7 @@ func TestBindToUnavailabePeer(t *testing.T) { } }() - _, err = clientAlice.BindChannel("bob") + _, err = clientAlice.OpenConn("bob") if err != nil { t.Errorf("failed to bind channel: %s", err) } @@ -292,7 +289,7 @@ func TestBindReconnect(t *testing.T) { t.Errorf("failed to connect to server: %s", err) } - _, err = clientAlice.BindChannel("bob") + _, err = clientAlice.OpenConn("bob") if err != nil { t.Errorf("failed to bind channel: %s", err) } @@ -303,7 +300,7 @@ func TestBindReconnect(t *testing.T) { t.Errorf("failed to connect to server: %s", err) } - chBob, err := clientBob.BindChannel("alice") + chBob, err := clientBob.OpenConn("alice") if err != nil { t.Errorf("failed to bind channel: %s", err) } @@ -320,7 +317,7 @@ func TestBindReconnect(t *testing.T) { t.Errorf("failed to connect to server: %s", err) } - chAlice, err := clientAlice.BindChannel("bob") + chAlice, err := clientAlice.OpenConn("bob") if err != nil { t.Errorf("failed to bind channel: %s", err) }