mirror of
https://github.com/netbirdio/netbird.git
synced 2025-04-10 01:59:01 +02:00
abstract peer channel (#101)
* abstract peer channel * remove wip code * refactor NewServer with Peer updates channel * add PeersUpdateManager tests * adding documentation * using older version of linter * verbose lint * skip cache * setup go version * extra output * configure fetch-depth * exit 0 * skip-build-cache: true * disabling failure for lint for now * fix: darwin issue * enable lint failure * remove sock file for macOS * refactor: remove tests interdependence * fixed linux native iface Co-authored-by: braginini <bangvalo@gmail.com>
This commit is contained in:
parent
4f4edf8442
commit
a31cbb1f5b
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@ -12,3 +12,5 @@ jobs:
|
|||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
- name: golangci-lint
|
- name: golangci-lint
|
||||||
uses: golangci/golangci-lint-action@v2
|
uses: golangci/golangci-lint-action@v2
|
||||||
|
|
||||||
|
|
||||||
|
@ -38,7 +38,8 @@ func startManagement(config *mgmt.Config, t *testing.T) (*grpc.Server, net.Liste
|
|||||||
}
|
}
|
||||||
|
|
||||||
accountManager := mgmt.NewManager(store)
|
accountManager := mgmt.NewManager(store)
|
||||||
mgmtServer, err := mgmt.NewServer(config, accountManager)
|
peersUpdateManager := mgmt.NewPeersUpdateManager()
|
||||||
|
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -13,10 +13,13 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
defaultMTU = 1280
|
defaultMTU = 1280
|
||||||
WgPort = 51820
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var tunIface tun.Device
|
var (
|
||||||
|
tunIface tun.Device
|
||||||
|
// todo check after move the WgPort constant to the client
|
||||||
|
WgPort = 51820
|
||||||
|
)
|
||||||
|
|
||||||
// CreateWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation
|
// CreateWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation
|
||||||
func CreateWithUserspace(iface string, address string) error {
|
func CreateWithUserspace(iface string, address string) error {
|
||||||
|
@ -3,6 +3,7 @@ package iface
|
|||||||
import (
|
import (
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@ -40,5 +41,23 @@ func addRoute(iface string, ipNet *net.IPNet) error {
|
|||||||
|
|
||||||
// Closes the tunnel interface
|
// Closes the tunnel interface
|
||||||
func Close() error {
|
func Close() error {
|
||||||
return CloseWithUserspace()
|
name, err := tunIface.Name()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
sockPath := "/var/run/wireguard/" + name + ".sock"
|
||||||
|
|
||||||
|
err = CloseWithUserspace()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(sockPath); err == nil {
|
||||||
|
err = os.Remove(sockPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -31,7 +31,7 @@ func CreateWithKernel(iface string, address string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check if interface exists
|
// check if interface exists
|
||||||
l, err := netlink.LinkByName(WgInterfaceDefault)
|
l, err := netlink.LinkByName(iface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch err.(type) {
|
switch err.(type) {
|
||||||
case netlink.LinkNotFoundError:
|
case netlink.LinkNotFoundError:
|
||||||
@ -148,6 +148,7 @@ func Close() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, wgDev := range devList {
|
for _, wgDev := range devList {
|
||||||
|
// todo check after move the WgPort constant to the client
|
||||||
if wgDev.ListenPort == WgPort {
|
if wgDev.ListenPort == WgPort {
|
||||||
iface = wgDev.Name
|
iface = wgDev.Name
|
||||||
break
|
break
|
||||||
|
@ -12,33 +12,61 @@ import (
|
|||||||
|
|
||||||
// keep darwin compability
|
// keep darwin compability
|
||||||
const (
|
const (
|
||||||
ifaceName = "utun999"
|
|
||||||
key = "0PMI6OkB5JmB+Jj/iWWHekuQRx+bipZirWCWKFXexHc="
|
key = "0PMI6OkB5JmB+Jj/iWWHekuQRx+bipZirWCWKFXexHc="
|
||||||
peerPubKey = "Ok0mC0qlJyXEPKh2UFIpsI2jG0L7LRpC3sLAusSJ5CQ="
|
peerPubKey = "Ok0mC0qlJyXEPKh2UFIpsI2jG0L7LRpC3sLAusSJ5CQ="
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
log.SetLevel(log.DebugLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
func Test_CreateInterface(t *testing.T) {
|
func Test_CreateInterface(t *testing.T) {
|
||||||
level, _ := log.ParseLevel("Debug")
|
ifaceName := "utun999"
|
||||||
log.SetLevel(level)
|
|
||||||
wgIP := "10.99.99.1/24"
|
wgIP := "10.99.99.1/24"
|
||||||
err := Create(ifaceName, wgIP)
|
err := Create(ifaceName, wgIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
defer func() {
|
||||||
|
err = Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
wg, err := wgctrl.New()
|
wg, err := wgctrl.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer wg.Close()
|
defer func() {
|
||||||
|
err = wg.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
_, err = wg.Device(ifaceName)
|
d, err := wg.Device(ifaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
// todo move the WgPort constant to the client
|
||||||
|
WgPort = d.ListenPort
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_ConfigureInterface(t *testing.T) {
|
func Test_ConfigureInterface(t *testing.T) {
|
||||||
err := Configure(ifaceName, key)
|
ifaceName := "utun1000"
|
||||||
|
wgIP := "10.99.99.10/24"
|
||||||
|
err := Create(ifaceName, wgIP)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err = Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = Configure(ifaceName, key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -47,7 +75,12 @@ func Test_ConfigureInterface(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer wg.Close()
|
defer func() {
|
||||||
|
err = wg.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
wgDevice, err := wg.Device(ifaceName)
|
wgDevice, err := wg.Device(ifaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -59,14 +92,30 @@ func Test_ConfigureInterface(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Test_UpdatePeer(t *testing.T) {
|
func Test_UpdatePeer(t *testing.T) {
|
||||||
keepAlive := 15 * time.Second
|
ifaceName := "utun1001"
|
||||||
allowedIP := "10.99.99.2/32"
|
wgIP := "10.99.99.20/24"
|
||||||
endpoint := "127.0.0.1:9900"
|
err := Create(ifaceName, wgIP)
|
||||||
err := UpdatePeer(ifaceName, peerPubKey, allowedIP, keepAlive, endpoint)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
peer, err := getPeer()
|
defer func() {
|
||||||
|
err = Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
err = Configure(ifaceName, key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
keepAlive := 15 * time.Second
|
||||||
|
allowedIP := "10.99.99.2/32"
|
||||||
|
endpoint := "127.0.0.1:9900"
|
||||||
|
err = UpdatePeer(ifaceName, peerPubKey, allowedIP, keepAlive, endpoint)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
peer, err := getPeer(ifaceName, t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -95,13 +144,37 @@ func Test_UpdatePeer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Test_UpdatePeerEndpoint(t *testing.T) {
|
func Test_UpdatePeerEndpoint(t *testing.T) {
|
||||||
newEndpoint := "127.0.0.1:9999"
|
ifaceName := "utun1002"
|
||||||
err := UpdatePeerEndpoint(ifaceName, peerPubKey, newEndpoint)
|
wgIP := "10.99.99.30/24"
|
||||||
|
err := Create(ifaceName, wgIP)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err = Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
err = Configure(ifaceName, key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
keepAlive := 15 * time.Second
|
||||||
|
allowedIP := "10.99.99.2/32"
|
||||||
|
endpoint := "127.0.0.1:9900"
|
||||||
|
err = UpdatePeer(ifaceName, peerPubKey, allowedIP, keepAlive, endpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
peer, err := getPeer()
|
newEndpoint := "127.0.0.1:9999"
|
||||||
|
err = UpdatePeerEndpoint(ifaceName, peerPubKey, newEndpoint)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
peer, err := getPeer(ifaceName, t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -112,28 +185,79 @@ func Test_UpdatePeerEndpoint(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Test_RemovePeer(t *testing.T) {
|
func Test_RemovePeer(t *testing.T) {
|
||||||
err := RemovePeer(ifaceName, peerPubKey)
|
ifaceName := "utun1003"
|
||||||
|
wgIP := "10.99.99.40/24"
|
||||||
|
err := Create(ifaceName, wgIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
_, err = getPeer()
|
defer func() {
|
||||||
|
err = Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
err = Configure(ifaceName, key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
keepAlive := 15 * time.Second
|
||||||
|
allowedIP := "10.99.99.2/32"
|
||||||
|
endpoint := "127.0.0.1:9900"
|
||||||
|
err = UpdatePeer(ifaceName, peerPubKey, allowedIP, keepAlive, endpoint)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
err = RemovePeer(ifaceName, peerPubKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
_, err = getPeer(ifaceName, t)
|
||||||
if err.Error() != "peer not found" {
|
if err.Error() != "peer not found" {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func Test_Close(t *testing.T) {
|
func Test_Close(t *testing.T) {
|
||||||
err := Close()
|
ifaceName := "utun1004"
|
||||||
|
wgIP := "10.99.99.50/24"
|
||||||
|
err := Create(ifaceName, wgIP)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
wg, err := wgctrl.New()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err = wg.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
d, err := wg.Device(ifaceName)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// todo move the WgPort constant to the client
|
||||||
|
WgPort = d.ListenPort
|
||||||
|
err = Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func getPeer() (wgtypes.Peer, error) {
|
func getPeer(ifaceName string, t *testing.T) (wgtypes.Peer, error) {
|
||||||
emptyPeer := wgtypes.Peer{}
|
emptyPeer := wgtypes.Peer{}
|
||||||
wg, err := wgctrl.New()
|
wg, err := wgctrl.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return emptyPeer, err
|
return emptyPeer, err
|
||||||
}
|
}
|
||||||
defer wg.Close()
|
defer func() {
|
||||||
|
err = wg.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
wgDevice, err := wg.Device(ifaceName)
|
wgDevice, err := wg.Device(ifaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -61,7 +61,8 @@ func startManagement(config *mgmt.Config, t *testing.T) (*grpc.Server, net.Liste
|
|||||||
}
|
}
|
||||||
|
|
||||||
accountManager := mgmt.NewManager(store)
|
accountManager := mgmt.NewManager(store)
|
||||||
mgmtServer, err := mgmt.NewServer(config, accountManager)
|
peersUpdateManager := mgmt.NewPeersUpdateManager()
|
||||||
|
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -43,6 +43,7 @@ var (
|
|||||||
Short: "start Wiretrustee Management Server",
|
Short: "start Wiretrustee Management Server",
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
InitLog(logLevel)
|
||||||
|
|
||||||
config, err := loadConfig()
|
config, err := loadConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -77,8 +78,8 @@ var (
|
|||||||
|
|
||||||
opts = append(opts, grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
opts = append(opts, grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||||
grpcServer := grpc.NewServer(opts...)
|
grpcServer := grpc.NewServer(opts...)
|
||||||
|
peersUpdateManager := server.NewPeersUpdateManager()
|
||||||
server, err := server.NewServer(config, accountManager)
|
server, err := server.NewServer(config, accountManager, peersUpdateManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed creating new server: %v", err)
|
log.Fatalf("failed creating new server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,6 @@ package server
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/protobuf/ptypes/timestamp"
|
"github.com/golang/protobuf/ptypes/timestamp"
|
||||||
@ -20,31 +19,26 @@ type Server struct {
|
|||||||
accountManager *AccountManager
|
accountManager *AccountManager
|
||||||
wgKey wgtypes.Key
|
wgKey wgtypes.Key
|
||||||
proto.UnimplementedManagementServiceServer
|
proto.UnimplementedManagementServiceServer
|
||||||
peerChannels map[string]chan *UpdateChannelMessage
|
peersUpdateManager *PeersUpdateManager
|
||||||
channelsMux *sync.Mutex
|
config *Config
|
||||||
config *Config
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllowedIPsFormat generates Wireguard AllowedIPs format (e.g. 100.30.30.1/32)
|
// AllowedIPsFormat generates Wireguard AllowedIPs format (e.g. 100.30.30.1/32)
|
||||||
const AllowedIPsFormat = "%s/32"
|
const AllowedIPsFormat = "%s/32"
|
||||||
|
|
||||||
type UpdateChannelMessage struct {
|
|
||||||
Update *proto.SyncResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewServer creates a new Management server
|
// NewServer creates a new Management server
|
||||||
func NewServer(config *Config, accountManager *AccountManager) (*Server, error) {
|
func NewServer(config *Config, accountManager *AccountManager, peersUpdateManager *PeersUpdateManager) (*Server, error) {
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Server{
|
return &Server{
|
||||||
wgKey: key,
|
wgKey: key,
|
||||||
// peerKey -> event channel
|
// peerKey -> event channel
|
||||||
peerChannels: make(map[string]chan *UpdateChannelMessage),
|
peersUpdateManager: peersUpdateManager,
|
||||||
channelsMux: &sync.Mutex{},
|
accountManager: accountManager,
|
||||||
accountManager: accountManager,
|
config: config,
|
||||||
config: config,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -90,8 +84,12 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
updates := s.openUpdatesChannel(peerKey.String())
|
updates := s.peersUpdateManager.CreateChannel(peerKey.String())
|
||||||
|
err = s.accountManager.MarkPeerConnected(peerKey.String(), true)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed marking peer as connected %s %v", peerKey, err)
|
||||||
|
}
|
||||||
|
// Todo start turn credentials goroutine
|
||||||
// keep a connection to the peer and send updates when available
|
// keep a connection to the peer and send updates when available
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@ -119,15 +117,18 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
case <-srv.Context().Done():
|
case <-srv.Context().Done():
|
||||||
// happens when connection drops, e.g. client disconnects
|
// happens when connection drops, e.g. client disconnects
|
||||||
log.Debugf("stream of peer %s has been closed", peerKey.String())
|
log.Debugf("stream of peer %s has been closed", peerKey.String())
|
||||||
s.closeUpdatesChannel(peerKey.String())
|
s.peersUpdateManager.CloseChannel(peerKey.String())
|
||||||
|
err := s.accountManager.MarkPeerConnected(peerKey.String(), false)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed marking peer as disconnected %s %v", peerKey, err)
|
||||||
|
}
|
||||||
|
// todo stop turn goroutine
|
||||||
return srv.Context().Err()
|
return srv.Context().Err()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Peer, error) {
|
func (s *Server) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Peer, error) {
|
||||||
s.channelsMux.Lock()
|
|
||||||
defer s.channelsMux.Unlock()
|
|
||||||
|
|
||||||
meta := req.GetMeta()
|
meta := req.GetMeta()
|
||||||
if meta == nil {
|
if meta == nil {
|
||||||
@ -157,16 +158,18 @@ func (s *Server) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Pe
|
|||||||
|
|
||||||
// notify other peers of our registration
|
// notify other peers of our registration
|
||||||
for _, remotePeer := range peers {
|
for _, remotePeer := range peers {
|
||||||
if channel, ok := s.peerChannels[remotePeer.Key]; ok {
|
// exclude notified peer and add ourselves
|
||||||
// exclude notified peer and add ourselves
|
peersToSend := []*Peer{peer}
|
||||||
peersToSend := []*Peer{peer}
|
for _, p := range peers {
|
||||||
for _, p := range peers {
|
if remotePeer.Key != p.Key {
|
||||||
if remotePeer.Key != p.Key {
|
peersToSend = append(peersToSend, p)
|
||||||
peersToSend = append(peersToSend, p)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
update := toSyncResponse(s.config, peer, peersToSend)
|
}
|
||||||
channel <- &UpdateChannelMessage{Update: update}
|
update := toSyncResponse(s.config, peer, peersToSend)
|
||||||
|
err = s.peersUpdateManager.SendUpdate(remotePeer.Key, &UpdateMessage{Update: update})
|
||||||
|
if err != nil {
|
||||||
|
// todo rethink if we should keep this return
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -212,6 +215,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
|||||||
return nil, status.Error(codes.Internal, "internal server error")
|
return nil, status.Error(codes.Internal, "internal server error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Todo fill up turn credentials
|
||||||
|
|
||||||
// if peer has reached this point then it has logged in
|
// if peer has reached this point then it has logged in
|
||||||
loginResp := &proto.LoginResponse{
|
loginResp := &proto.LoginResponse{
|
||||||
@ -310,44 +314,6 @@ func (s *Server) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty,
|
|||||||
return &proto.Empty{}, nil
|
return &proto.Empty{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// openUpdatesChannel creates a go channel for a given peer used to deliver updates relevant to the peer.
|
|
||||||
func (s *Server) openUpdatesChannel(peerKey string) chan *UpdateChannelMessage {
|
|
||||||
s.channelsMux.Lock()
|
|
||||||
defer s.channelsMux.Unlock()
|
|
||||||
if channel, ok := s.peerChannels[peerKey]; ok {
|
|
||||||
delete(s.peerChannels, peerKey)
|
|
||||||
close(channel)
|
|
||||||
}
|
|
||||||
//mbragin: todo shouldn't it be more? or configurable?
|
|
||||||
channel := make(chan *UpdateChannelMessage, 100)
|
|
||||||
s.peerChannels[peerKey] = channel
|
|
||||||
|
|
||||||
err := s.accountManager.MarkPeerConnected(peerKey, true)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed marking peer as connected %s %v", peerKey, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("opened updates channel for a peer %s", peerKey)
|
|
||||||
return channel
|
|
||||||
}
|
|
||||||
|
|
||||||
// closeUpdatesChannel closes updates channel of a given peer
|
|
||||||
func (s *Server) closeUpdatesChannel(peerKey string) {
|
|
||||||
s.channelsMux.Lock()
|
|
||||||
defer s.channelsMux.Unlock()
|
|
||||||
if channel, ok := s.peerChannels[peerKey]; ok {
|
|
||||||
delete(s.peerChannels, peerKey)
|
|
||||||
close(channel)
|
|
||||||
}
|
|
||||||
|
|
||||||
err := s.accountManager.MarkPeerConnected(peerKey, false)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed marking peer as disconnected %s %v", peerKey, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("closed updates channel of a peer %s", peerKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
|
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
|
||||||
func (s *Server) sendInitialSync(peerKey wgtypes.Key, peer *Peer, srv proto.ManagementService_SyncServer) error {
|
func (s *Server) sendInitialSync(peerKey wgtypes.Key, peer *Peer, srv proto.ManagementService_SyncServer) error {
|
||||||
|
|
||||||
@ -362,7 +328,7 @@ func (s *Server) sendInitialSync(peerKey wgtypes.Key, peer *Peer, srv proto.Mana
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(codes.Internal, "error handling request")
|
return status.Errorf(codes.Internal, "error handling request")
|
||||||
}
|
}
|
||||||
|
// Todo fill up the turn credentials
|
||||||
err = srv.Send(&proto.EncryptedMessage{
|
err = srv.Send(&proto.EncryptedMessage{
|
||||||
WgPubKey: s.wgKey.PublicKey().String(),
|
WgPubKey: s.wgKey.PublicKey().String(),
|
||||||
Body: encryptedResp,
|
Body: encryptedResp,
|
||||||
|
@ -491,7 +491,8 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
|
|||||||
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
|
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
|
||||||
}
|
}
|
||||||
accountManager := server.NewManager(store)
|
accountManager := server.NewManager(store)
|
||||||
mgmtServer, err := server.NewServer(config, accountManager)
|
peersUpdateManager := server.NewPeersUpdateManager()
|
||||||
|
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||||
go func() {
|
go func() {
|
||||||
|
64
management/server/updatechannel.go
Normal file
64
management/server/updatechannel.go
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/wiretrustee/wiretrustee/management/proto"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UpdateMessage struct {
|
||||||
|
Update *proto.SyncResponse
|
||||||
|
}
|
||||||
|
type PeersUpdateManager struct {
|
||||||
|
peerChannels map[string]chan *UpdateMessage
|
||||||
|
channelsMux *sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
|
||||||
|
func NewPeersUpdateManager() *PeersUpdateManager {
|
||||||
|
return &PeersUpdateManager{
|
||||||
|
peerChannels: make(map[string]chan *UpdateMessage),
|
||||||
|
channelsMux: &sync.Mutex{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendUpdate sends update message to the peer's channel
|
||||||
|
func (p *PeersUpdateManager) SendUpdate(peer string, update *UpdateMessage) error {
|
||||||
|
p.channelsMux.Lock()
|
||||||
|
defer p.channelsMux.Unlock()
|
||||||
|
if channel, ok := p.peerChannels[peer]; ok {
|
||||||
|
channel <- update
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
log.Debugf("peer %s has no channel", peer)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer.
|
||||||
|
func (p *PeersUpdateManager) CreateChannel(peerKey string) chan *UpdateMessage {
|
||||||
|
p.channelsMux.Lock()
|
||||||
|
defer p.channelsMux.Unlock()
|
||||||
|
|
||||||
|
if channel, ok := p.peerChannels[peerKey]; ok {
|
||||||
|
delete(p.peerChannels, peerKey)
|
||||||
|
close(channel)
|
||||||
|
}
|
||||||
|
//mbragin: todo shouldn't it be more? or configurable?
|
||||||
|
channel := make(chan *UpdateMessage, 100)
|
||||||
|
p.peerChannels[peerKey] = channel
|
||||||
|
|
||||||
|
log.Debugf("opened updates channel for a peer %s", peerKey)
|
||||||
|
return channel
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseChannel closes updates channel of a given peer
|
||||||
|
func (p *PeersUpdateManager) CloseChannel(peerKey string) {
|
||||||
|
p.channelsMux.Lock()
|
||||||
|
defer p.channelsMux.Unlock()
|
||||||
|
if channel, ok := p.peerChannels[peerKey]; ok {
|
||||||
|
delete(p.peerChannels, peerKey)
|
||||||
|
close(channel)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("closed updates channel of a peer %s", peerKey)
|
||||||
|
}
|
49
management/server/updatechannel_test.go
Normal file
49
management/server/updatechannel_test.go
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/wiretrustee/wiretrustee/management/proto"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var peersUpdater *PeersUpdateManager
|
||||||
|
|
||||||
|
func TestCreateChannel(t *testing.T) {
|
||||||
|
peer := "test-create"
|
||||||
|
peersUpdater = NewPeersUpdateManager()
|
||||||
|
defer peersUpdater.CloseChannel(peer)
|
||||||
|
|
||||||
|
_ = peersUpdater.CreateChannel(peer)
|
||||||
|
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
||||||
|
t.Error("Error creating the channel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendUpdate(t *testing.T) {
|
||||||
|
peer := "test-sendupdate"
|
||||||
|
update := &UpdateMessage{Update: &proto.SyncResponse{}}
|
||||||
|
_ = peersUpdater.CreateChannel(peer)
|
||||||
|
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
||||||
|
t.Error("Error creating the channel")
|
||||||
|
}
|
||||||
|
err := peersUpdater.SendUpdate(peer, update)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Error sending update: ", err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-peersUpdater.peerChannels[peer]:
|
||||||
|
default:
|
||||||
|
t.Error("Update wasn't send")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloseChannel(t *testing.T) {
|
||||||
|
peer := "test-close"
|
||||||
|
_ = peersUpdater.CreateChannel(peer)
|
||||||
|
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
||||||
|
t.Error("Error creating the channel")
|
||||||
|
}
|
||||||
|
peersUpdater.CloseChannel(peer)
|
||||||
|
if _, ok := peersUpdater.peerChannels[peer]; ok {
|
||||||
|
t.Error("Error closing the channel")
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user