mirror of
https://github.com/netbirdio/netbird.git
synced 2025-03-28 08:37:57 +01:00
abstract peer channel (#101)
* abstract peer channel * remove wip code * refactor NewServer with Peer updates channel * add PeersUpdateManager tests * adding documentation * using older version of linter * verbose lint * skip cache * setup go version * extra output * configure fetch-depth * exit 0 * skip-build-cache: true * disabling failure for lint for now * fix: darwin issue * enable lint failure * remove sock file for macOS * refactor: remove tests interdependence * fixed linux native iface Co-authored-by: braginini <bangvalo@gmail.com>
This commit is contained in:
parent
4f4edf8442
commit
a31cbb1f5b
4
.github/workflows/golangci-lint.yml
vendored
4
.github/workflows/golangci-lint.yml
vendored
@ -11,4 +11,6 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v2
|
||||
uses: golangci/golangci-lint-action@v2
|
||||
|
||||
|
||||
|
@ -38,7 +38,8 @@ func startManagement(config *mgmt.Config, t *testing.T) (*grpc.Server, net.Liste
|
||||
}
|
||||
|
||||
accountManager := mgmt.NewManager(store)
|
||||
mgmtServer, err := mgmt.NewServer(config, accountManager)
|
||||
peersUpdateManager := mgmt.NewPeersUpdateManager()
|
||||
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -13,10 +13,13 @@ import (
|
||||
|
||||
const (
|
||||
defaultMTU = 1280
|
||||
WgPort = 51820
|
||||
)
|
||||
|
||||
var tunIface tun.Device
|
||||
var (
|
||||
tunIface tun.Device
|
||||
// todo check after move the WgPort constant to the client
|
||||
WgPort = 51820
|
||||
)
|
||||
|
||||
// CreateWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation
|
||||
func CreateWithUserspace(iface string, address string) error {
|
||||
|
@ -3,6 +3,7 @@ package iface
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
@ -40,5 +41,23 @@ func addRoute(iface string, ipNet *net.IPNet) error {
|
||||
|
||||
// Closes the tunnel interface
|
||||
func Close() error {
|
||||
return CloseWithUserspace()
|
||||
name, err := tunIface.Name()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sockPath := "/var/run/wireguard/" + name + ".sock"
|
||||
|
||||
err = CloseWithUserspace()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := os.Stat(sockPath); err == nil {
|
||||
err = os.Remove(sockPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ func CreateWithKernel(iface string, address string) error {
|
||||
}
|
||||
|
||||
// check if interface exists
|
||||
l, err := netlink.LinkByName(WgInterfaceDefault)
|
||||
l, err := netlink.LinkByName(iface)
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case netlink.LinkNotFoundError:
|
||||
@ -148,6 +148,7 @@ func Close() error {
|
||||
return err
|
||||
}
|
||||
for _, wgDev := range devList {
|
||||
// todo check after move the WgPort constant to the client
|
||||
if wgDev.ListenPort == WgPort {
|
||||
iface = wgDev.Name
|
||||
break
|
||||
|
@ -12,33 +12,61 @@ import (
|
||||
|
||||
// keep darwin compability
|
||||
const (
|
||||
ifaceName = "utun999"
|
||||
key = "0PMI6OkB5JmB+Jj/iWWHekuQRx+bipZirWCWKFXexHc="
|
||||
peerPubKey = "Ok0mC0qlJyXEPKh2UFIpsI2jG0L7LRpC3sLAusSJ5CQ="
|
||||
)
|
||||
|
||||
func init() {
|
||||
log.SetLevel(log.DebugLevel)
|
||||
}
|
||||
|
||||
//
|
||||
func Test_CreateInterface(t *testing.T) {
|
||||
level, _ := log.ParseLevel("Debug")
|
||||
log.SetLevel(level)
|
||||
ifaceName := "utun999"
|
||||
wgIP := "10.99.99.1/24"
|
||||
err := Create(ifaceName, wgIP)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
err = Close()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer wg.Close()
|
||||
defer func() {
|
||||
err = wg.Close()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = wg.Device(ifaceName)
|
||||
d, err := wg.Device(ifaceName)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// todo move the WgPort constant to the client
|
||||
WgPort = d.ListenPort
|
||||
}
|
||||
|
||||
func Test_ConfigureInterface(t *testing.T) {
|
||||
err := Configure(ifaceName, key)
|
||||
ifaceName := "utun1000"
|
||||
wgIP := "10.99.99.10/24"
|
||||
err := Create(ifaceName, wgIP)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
err = Close()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
err = Configure(ifaceName, key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -47,7 +75,12 @@ func Test_ConfigureInterface(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer wg.Close()
|
||||
defer func() {
|
||||
err = wg.Close()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
wgDevice, err := wg.Device(ifaceName)
|
||||
if err != nil {
|
||||
@ -59,14 +92,30 @@ func Test_ConfigureInterface(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_UpdatePeer(t *testing.T) {
|
||||
keepAlive := 15 * time.Second
|
||||
allowedIP := "10.99.99.2/32"
|
||||
endpoint := "127.0.0.1:9900"
|
||||
err := UpdatePeer(ifaceName, peerPubKey, allowedIP, keepAlive, endpoint)
|
||||
ifaceName := "utun1001"
|
||||
wgIP := "10.99.99.20/24"
|
||||
err := Create(ifaceName, wgIP)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
peer, err := getPeer()
|
||||
defer func() {
|
||||
err = Close()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
err = Configure(ifaceName, key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
keepAlive := 15 * time.Second
|
||||
allowedIP := "10.99.99.2/32"
|
||||
endpoint := "127.0.0.1:9900"
|
||||
err = UpdatePeer(ifaceName, peerPubKey, allowedIP, keepAlive, endpoint)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
peer, err := getPeer(ifaceName, t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -95,13 +144,37 @@ func Test_UpdatePeer(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_UpdatePeerEndpoint(t *testing.T) {
|
||||
newEndpoint := "127.0.0.1:9999"
|
||||
err := UpdatePeerEndpoint(ifaceName, peerPubKey, newEndpoint)
|
||||
ifaceName := "utun1002"
|
||||
wgIP := "10.99.99.30/24"
|
||||
err := Create(ifaceName, wgIP)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
err = Close()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
err = Configure(ifaceName, key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
keepAlive := 15 * time.Second
|
||||
allowedIP := "10.99.99.2/32"
|
||||
endpoint := "127.0.0.1:9900"
|
||||
err = UpdatePeer(ifaceName, peerPubKey, allowedIP, keepAlive, endpoint)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
peer, err := getPeer()
|
||||
newEndpoint := "127.0.0.1:9999"
|
||||
err = UpdatePeerEndpoint(ifaceName, peerPubKey, newEndpoint)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
peer, err := getPeer(ifaceName, t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -112,28 +185,79 @@ func Test_UpdatePeerEndpoint(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_RemovePeer(t *testing.T) {
|
||||
err := RemovePeer(ifaceName, peerPubKey)
|
||||
ifaceName := "utun1003"
|
||||
wgIP := "10.99.99.40/24"
|
||||
err := Create(ifaceName, wgIP)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = getPeer()
|
||||
defer func() {
|
||||
err = Close()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
err = Configure(ifaceName, key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
keepAlive := 15 * time.Second
|
||||
allowedIP := "10.99.99.2/32"
|
||||
endpoint := "127.0.0.1:9900"
|
||||
err = UpdatePeer(ifaceName, peerPubKey, allowedIP, keepAlive, endpoint)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = RemovePeer(ifaceName, peerPubKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = getPeer(ifaceName, t)
|
||||
if err.Error() != "peer not found" {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
func Test_Close(t *testing.T) {
|
||||
err := Close()
|
||||
ifaceName := "utun1004"
|
||||
wgIP := "10.99.99.50/24"
|
||||
err := Create(ifaceName, wgIP)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
err = wg.Close()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
d, err := wg.Device(ifaceName)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// todo move the WgPort constant to the client
|
||||
WgPort = d.ListenPort
|
||||
err = Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
func getPeer() (wgtypes.Peer, error) {
|
||||
func getPeer(ifaceName string, t *testing.T) (wgtypes.Peer, error) {
|
||||
emptyPeer := wgtypes.Peer{}
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return emptyPeer, err
|
||||
}
|
||||
defer wg.Close()
|
||||
defer func() {
|
||||
err = wg.Close()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
wgDevice, err := wg.Device(ifaceName)
|
||||
if err != nil {
|
||||
|
@ -61,7 +61,8 @@ func startManagement(config *mgmt.Config, t *testing.T) (*grpc.Server, net.Liste
|
||||
}
|
||||
|
||||
accountManager := mgmt.NewManager(store)
|
||||
mgmtServer, err := mgmt.NewServer(config, accountManager)
|
||||
peersUpdateManager := mgmt.NewPeersUpdateManager()
|
||||
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -43,6 +43,7 @@ var (
|
||||
Short: "start Wiretrustee Management Server",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
flag.Parse()
|
||||
InitLog(logLevel)
|
||||
|
||||
config, err := loadConfig()
|
||||
if err != nil {
|
||||
@ -77,8 +78,8 @@ var (
|
||||
|
||||
opts = append(opts, grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
grpcServer := grpc.NewServer(opts...)
|
||||
|
||||
server, err := server.NewServer(config, accountManager)
|
||||
peersUpdateManager := server.NewPeersUpdateManager()
|
||||
server, err := server.NewServer(config, accountManager, peersUpdateManager)
|
||||
if err != nil {
|
||||
log.Fatalf("failed creating new server: %v", err)
|
||||
}
|
||||
|
@ -3,7 +3,6 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/ptypes/timestamp"
|
||||
@ -20,31 +19,26 @@ type Server struct {
|
||||
accountManager *AccountManager
|
||||
wgKey wgtypes.Key
|
||||
proto.UnimplementedManagementServiceServer
|
||||
peerChannels map[string]chan *UpdateChannelMessage
|
||||
channelsMux *sync.Mutex
|
||||
config *Config
|
||||
peersUpdateManager *PeersUpdateManager
|
||||
config *Config
|
||||
}
|
||||
|
||||
// AllowedIPsFormat generates Wireguard AllowedIPs format (e.g. 100.30.30.1/32)
|
||||
const AllowedIPsFormat = "%s/32"
|
||||
|
||||
type UpdateChannelMessage struct {
|
||||
Update *proto.SyncResponse
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
func NewServer(config *Config, accountManager *AccountManager) (*Server, error) {
|
||||
func NewServer(config *Config, accountManager *AccountManager, peersUpdateManager *PeersUpdateManager) (*Server, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Server{
|
||||
wgKey: key,
|
||||
// peerKey -> event channel
|
||||
peerChannels: make(map[string]chan *UpdateChannelMessage),
|
||||
channelsMux: &sync.Mutex{},
|
||||
accountManager: accountManager,
|
||||
config: config,
|
||||
peersUpdateManager: peersUpdateManager,
|
||||
accountManager: accountManager,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -90,8 +84,12 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
return err
|
||||
}
|
||||
|
||||
updates := s.openUpdatesChannel(peerKey.String())
|
||||
|
||||
updates := s.peersUpdateManager.CreateChannel(peerKey.String())
|
||||
err = s.accountManager.MarkPeerConnected(peerKey.String(), true)
|
||||
if err != nil {
|
||||
log.Warnf("failed marking peer as connected %s %v", peerKey, err)
|
||||
}
|
||||
// Todo start turn credentials goroutine
|
||||
// keep a connection to the peer and send updates when available
|
||||
for {
|
||||
select {
|
||||
@ -119,15 +117,18 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
case <-srv.Context().Done():
|
||||
// happens when connection drops, e.g. client disconnects
|
||||
log.Debugf("stream of peer %s has been closed", peerKey.String())
|
||||
s.closeUpdatesChannel(peerKey.String())
|
||||
s.peersUpdateManager.CloseChannel(peerKey.String())
|
||||
err := s.accountManager.MarkPeerConnected(peerKey.String(), false)
|
||||
if err != nil {
|
||||
log.Warnf("failed marking peer as disconnected %s %v", peerKey, err)
|
||||
}
|
||||
// todo stop turn goroutine
|
||||
return srv.Context().Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Peer, error) {
|
||||
s.channelsMux.Lock()
|
||||
defer s.channelsMux.Unlock()
|
||||
|
||||
meta := req.GetMeta()
|
||||
if meta == nil {
|
||||
@ -157,16 +158,18 @@ func (s *Server) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Pe
|
||||
|
||||
// notify other peers of our registration
|
||||
for _, remotePeer := range peers {
|
||||
if channel, ok := s.peerChannels[remotePeer.Key]; ok {
|
||||
// exclude notified peer and add ourselves
|
||||
peersToSend := []*Peer{peer}
|
||||
for _, p := range peers {
|
||||
if remotePeer.Key != p.Key {
|
||||
peersToSend = append(peersToSend, p)
|
||||
}
|
||||
// exclude notified peer and add ourselves
|
||||
peersToSend := []*Peer{peer}
|
||||
for _, p := range peers {
|
||||
if remotePeer.Key != p.Key {
|
||||
peersToSend = append(peersToSend, p)
|
||||
}
|
||||
update := toSyncResponse(s.config, peer, peersToSend)
|
||||
channel <- &UpdateChannelMessage{Update: update}
|
||||
}
|
||||
update := toSyncResponse(s.config, peer, peersToSend)
|
||||
err = s.peersUpdateManager.SendUpdate(remotePeer.Key, &UpdateMessage{Update: update})
|
||||
if err != nil {
|
||||
// todo rethink if we should keep this return
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@ -212,6 +215,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
||||
return nil, status.Error(codes.Internal, "internal server error")
|
||||
}
|
||||
}
|
||||
// Todo fill up turn credentials
|
||||
|
||||
// if peer has reached this point then it has logged in
|
||||
loginResp := &proto.LoginResponse{
|
||||
@ -310,44 +314,6 @@ func (s *Server) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty,
|
||||
return &proto.Empty{}, nil
|
||||
}
|
||||
|
||||
// openUpdatesChannel creates a go channel for a given peer used to deliver updates relevant to the peer.
|
||||
func (s *Server) openUpdatesChannel(peerKey string) chan *UpdateChannelMessage {
|
||||
s.channelsMux.Lock()
|
||||
defer s.channelsMux.Unlock()
|
||||
if channel, ok := s.peerChannels[peerKey]; ok {
|
||||
delete(s.peerChannels, peerKey)
|
||||
close(channel)
|
||||
}
|
||||
//mbragin: todo shouldn't it be more? or configurable?
|
||||
channel := make(chan *UpdateChannelMessage, 100)
|
||||
s.peerChannels[peerKey] = channel
|
||||
|
||||
err := s.accountManager.MarkPeerConnected(peerKey, true)
|
||||
if err != nil {
|
||||
log.Warnf("failed marking peer as connected %s %v", peerKey, err)
|
||||
}
|
||||
|
||||
log.Debugf("opened updates channel for a peer %s", peerKey)
|
||||
return channel
|
||||
}
|
||||
|
||||
// closeUpdatesChannel closes updates channel of a given peer
|
||||
func (s *Server) closeUpdatesChannel(peerKey string) {
|
||||
s.channelsMux.Lock()
|
||||
defer s.channelsMux.Unlock()
|
||||
if channel, ok := s.peerChannels[peerKey]; ok {
|
||||
delete(s.peerChannels, peerKey)
|
||||
close(channel)
|
||||
}
|
||||
|
||||
err := s.accountManager.MarkPeerConnected(peerKey, false)
|
||||
if err != nil {
|
||||
log.Warnf("failed marking peer as disconnected %s %v", peerKey, err)
|
||||
}
|
||||
|
||||
log.Debugf("closed updates channel of a peer %s", peerKey)
|
||||
}
|
||||
|
||||
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
|
||||
func (s *Server) sendInitialSync(peerKey wgtypes.Key, peer *Peer, srv proto.ManagementService_SyncServer) error {
|
||||
|
||||
@ -362,7 +328,7 @@ func (s *Server) sendInitialSync(peerKey wgtypes.Key, peer *Peer, srv proto.Mana
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "error handling request")
|
||||
}
|
||||
|
||||
// Todo fill up the turn credentials
|
||||
err = srv.Send(&proto.EncryptedMessage{
|
||||
WgPubKey: s.wgKey.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
|
@ -491,7 +491,8 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
|
||||
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
|
||||
}
|
||||
accountManager := server.NewManager(store)
|
||||
mgmtServer, err := server.NewServer(config, accountManager)
|
||||
peersUpdateManager := server.NewPeersUpdateManager()
|
||||
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||
go func() {
|
||||
|
64
management/server/updatechannel.go
Normal file
64
management/server/updatechannel.go
Normal file
@ -0,0 +1,64 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/wiretrustee/wiretrustee/management/proto"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type UpdateMessage struct {
|
||||
Update *proto.SyncResponse
|
||||
}
|
||||
type PeersUpdateManager struct {
|
||||
peerChannels map[string]chan *UpdateMessage
|
||||
channelsMux *sync.Mutex
|
||||
}
|
||||
|
||||
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
|
||||
func NewPeersUpdateManager() *PeersUpdateManager {
|
||||
return &PeersUpdateManager{
|
||||
peerChannels: make(map[string]chan *UpdateMessage),
|
||||
channelsMux: &sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
// SendUpdate sends update message to the peer's channel
|
||||
func (p *PeersUpdateManager) SendUpdate(peer string, update *UpdateMessage) error {
|
||||
p.channelsMux.Lock()
|
||||
defer p.channelsMux.Unlock()
|
||||
if channel, ok := p.peerChannels[peer]; ok {
|
||||
channel <- update
|
||||
return nil
|
||||
}
|
||||
log.Debugf("peer %s has no channel", peer)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer.
|
||||
func (p *PeersUpdateManager) CreateChannel(peerKey string) chan *UpdateMessage {
|
||||
p.channelsMux.Lock()
|
||||
defer p.channelsMux.Unlock()
|
||||
|
||||
if channel, ok := p.peerChannels[peerKey]; ok {
|
||||
delete(p.peerChannels, peerKey)
|
||||
close(channel)
|
||||
}
|
||||
//mbragin: todo shouldn't it be more? or configurable?
|
||||
channel := make(chan *UpdateMessage, 100)
|
||||
p.peerChannels[peerKey] = channel
|
||||
|
||||
log.Debugf("opened updates channel for a peer %s", peerKey)
|
||||
return channel
|
||||
}
|
||||
|
||||
// CloseChannel closes updates channel of a given peer
|
||||
func (p *PeersUpdateManager) CloseChannel(peerKey string) {
|
||||
p.channelsMux.Lock()
|
||||
defer p.channelsMux.Unlock()
|
||||
if channel, ok := p.peerChannels[peerKey]; ok {
|
||||
delete(p.peerChannels, peerKey)
|
||||
close(channel)
|
||||
}
|
||||
|
||||
log.Debugf("closed updates channel of a peer %s", peerKey)
|
||||
}
|
49
management/server/updatechannel_test.go
Normal file
49
management/server/updatechannel_test.go
Normal file
@ -0,0 +1,49 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/wiretrustee/wiretrustee/management/proto"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var peersUpdater *PeersUpdateManager
|
||||
|
||||
func TestCreateChannel(t *testing.T) {
|
||||
peer := "test-create"
|
||||
peersUpdater = NewPeersUpdateManager()
|
||||
defer peersUpdater.CloseChannel(peer)
|
||||
|
||||
_ = peersUpdater.CreateChannel(peer)
|
||||
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
||||
t.Error("Error creating the channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendUpdate(t *testing.T) {
|
||||
peer := "test-sendupdate"
|
||||
update := &UpdateMessage{Update: &proto.SyncResponse{}}
|
||||
_ = peersUpdater.CreateChannel(peer)
|
||||
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
||||
t.Error("Error creating the channel")
|
||||
}
|
||||
err := peersUpdater.SendUpdate(peer, update)
|
||||
if err != nil {
|
||||
t.Error("Error sending update: ", err)
|
||||
}
|
||||
select {
|
||||
case <-peersUpdater.peerChannels[peer]:
|
||||
default:
|
||||
t.Error("Update wasn't send")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseChannel(t *testing.T) {
|
||||
peer := "test-close"
|
||||
_ = peersUpdater.CreateChannel(peer)
|
||||
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
||||
t.Error("Error creating the channel")
|
||||
}
|
||||
peersUpdater.CloseChannel(peer)
|
||||
if _, ok := peersUpdater.peerChannels[peer]; ok {
|
||||
t.Error("Error closing the channel")
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user