diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 956fcd170..d302dc89f 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -11,4 +11,6 @@ jobs: steps: - uses: actions/checkout@v2 - name: golangci-lint - uses: golangci/golangci-lint-action@v2 \ No newline at end of file + uses: golangci/golangci-lint-action@v2 + + diff --git a/client/cmd/testutil.go b/client/cmd/testutil.go index 436217002..0cb7eb7fd 100644 --- a/client/cmd/testutil.go +++ b/client/cmd/testutil.go @@ -38,7 +38,8 @@ func startManagement(config *mgmt.Config, t *testing.T) (*grpc.Server, net.Liste } accountManager := mgmt.NewManager(store) - mgmtServer, err := mgmt.NewServer(config, accountManager) + peersUpdateManager := mgmt.NewPeersUpdateManager() + mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager) if err != nil { t.Fatal(err) } diff --git a/iface/iface.go b/iface/iface.go index e6c9a3bcd..f1da4bdd3 100644 --- a/iface/iface.go +++ b/iface/iface.go @@ -13,10 +13,13 @@ import ( const ( defaultMTU = 1280 - WgPort = 51820 ) -var tunIface tun.Device +var ( + tunIface tun.Device + // todo check after move the WgPort constant to the client + WgPort = 51820 +) // CreateWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation func CreateWithUserspace(iface string, address string) error { diff --git a/iface/iface_darwin.go b/iface/iface_darwin.go index 85a8825e2..e93db7759 100644 --- a/iface/iface_darwin.go +++ b/iface/iface_darwin.go @@ -3,6 +3,7 @@ package iface import ( log "github.com/sirupsen/logrus" "net" + "os" "os/exec" "strings" ) @@ -40,5 +41,23 @@ func addRoute(iface string, ipNet *net.IPNet) error { // Closes the tunnel interface func Close() error { - return CloseWithUserspace() + name, err := tunIface.Name() + if err != nil { + return err + } + + sockPath := "/var/run/wireguard/" + name + ".sock" + + err = CloseWithUserspace() + if err != nil { + return err + } + + if _, err := os.Stat(sockPath); err == nil { + err = os.Remove(sockPath) + if err != nil { + return err + } + } + return nil } diff --git a/iface/iface_linux.go b/iface/iface_linux.go index f2c0300e3..d2ffe88b6 100644 --- a/iface/iface_linux.go +++ b/iface/iface_linux.go @@ -31,7 +31,7 @@ func CreateWithKernel(iface string, address string) error { } // check if interface exists - l, err := netlink.LinkByName(WgInterfaceDefault) + l, err := netlink.LinkByName(iface) if err != nil { switch err.(type) { case netlink.LinkNotFoundError: @@ -148,6 +148,7 @@ func Close() error { return err } for _, wgDev := range devList { + // todo check after move the WgPort constant to the client if wgDev.ListenPort == WgPort { iface = wgDev.Name break diff --git a/iface/iface_test.go b/iface/iface_test.go index b21f41e57..f54ecccc6 100644 --- a/iface/iface_test.go +++ b/iface/iface_test.go @@ -12,33 +12,61 @@ import ( // keep darwin compability const ( - ifaceName = "utun999" key = "0PMI6OkB5JmB+Jj/iWWHekuQRx+bipZirWCWKFXexHc=" peerPubKey = "Ok0mC0qlJyXEPKh2UFIpsI2jG0L7LRpC3sLAusSJ5CQ=" ) +func init() { + log.SetLevel(log.DebugLevel) +} + +// func Test_CreateInterface(t *testing.T) { - level, _ := log.ParseLevel("Debug") - log.SetLevel(level) + ifaceName := "utun999" wgIP := "10.99.99.1/24" err := Create(ifaceName, wgIP) if err != nil { t.Fatal(err) } + defer func() { + err = Close() + if err != nil { + t.Error(err) + } + }() wg, err := wgctrl.New() if err != nil { t.Fatal(err) } - defer wg.Close() + defer func() { + err = wg.Close() + if err != nil { + t.Error(err) + } + }() - _, err = wg.Device(ifaceName) + d, err := wg.Device(ifaceName) if err != nil { t.Fatal(err) } + // todo move the WgPort constant to the client + WgPort = d.ListenPort } - func Test_ConfigureInterface(t *testing.T) { - err := Configure(ifaceName, key) + ifaceName := "utun1000" + wgIP := "10.99.99.10/24" + err := Create(ifaceName, wgIP) + if err != nil { + t.Fatal(err) + } + defer func() { + err = Close() + if err != nil { + t.Error(err) + } + }() + + err = Configure(ifaceName, key) if err != nil { t.Fatal(err) } @@ -47,7 +75,12 @@ func Test_ConfigureInterface(t *testing.T) { if err != nil { t.Fatal(err) } - defer wg.Close() + defer func() { + err = wg.Close() + if err != nil { + t.Error(err) + } + }() wgDevice, err := wg.Device(ifaceName) if err != nil { @@ -59,14 +92,30 @@ func Test_ConfigureInterface(t *testing.T) { } func Test_UpdatePeer(t *testing.T) { - keepAlive := 15 * time.Second - allowedIP := "10.99.99.2/32" - endpoint := "127.0.0.1:9900" - err := UpdatePeer(ifaceName, peerPubKey, allowedIP, keepAlive, endpoint) + ifaceName := "utun1001" + wgIP := "10.99.99.20/24" + err := Create(ifaceName, wgIP) if err != nil { t.Fatal(err) } - peer, err := getPeer() + defer func() { + err = Close() + if err != nil { + t.Error(err) + } + }() + err = Configure(ifaceName, key) + if err != nil { + t.Fatal(err) + } + keepAlive := 15 * time.Second + allowedIP := "10.99.99.2/32" + endpoint := "127.0.0.1:9900" + err = UpdatePeer(ifaceName, peerPubKey, allowedIP, keepAlive, endpoint) + if err != nil { + t.Fatal(err) + } + peer, err := getPeer(ifaceName, t) if err != nil { t.Fatal(err) } @@ -95,13 +144,37 @@ func Test_UpdatePeer(t *testing.T) { } func Test_UpdatePeerEndpoint(t *testing.T) { - newEndpoint := "127.0.0.1:9999" - err := UpdatePeerEndpoint(ifaceName, peerPubKey, newEndpoint) + ifaceName := "utun1002" + wgIP := "10.99.99.30/24" + err := Create(ifaceName, wgIP) + if err != nil { + t.Fatal(err) + } + defer func() { + err = Close() + if err != nil { + t.Error(err) + } + }() + err = Configure(ifaceName, key) + if err != nil { + t.Fatal(err) + } + keepAlive := 15 * time.Second + allowedIP := "10.99.99.2/32" + endpoint := "127.0.0.1:9900" + err = UpdatePeer(ifaceName, peerPubKey, allowedIP, keepAlive, endpoint) if err != nil { t.Fatal(err) } - peer, err := getPeer() + newEndpoint := "127.0.0.1:9999" + err = UpdatePeerEndpoint(ifaceName, peerPubKey, newEndpoint) + if err != nil { + t.Fatal(err) + } + + peer, err := getPeer(ifaceName, t) if err != nil { t.Fatal(err) } @@ -112,28 +185,79 @@ func Test_UpdatePeerEndpoint(t *testing.T) { } func Test_RemovePeer(t *testing.T) { - err := RemovePeer(ifaceName, peerPubKey) + ifaceName := "utun1003" + wgIP := "10.99.99.40/24" + err := Create(ifaceName, wgIP) if err != nil { t.Fatal(err) } - _, err = getPeer() + defer func() { + err = Close() + if err != nil { + t.Error(err) + } + }() + err = Configure(ifaceName, key) + if err != nil { + t.Fatal(err) + } + keepAlive := 15 * time.Second + allowedIP := "10.99.99.2/32" + endpoint := "127.0.0.1:9900" + err = UpdatePeer(ifaceName, peerPubKey, allowedIP, keepAlive, endpoint) + if err != nil { + t.Fatal(err) + } + err = RemovePeer(ifaceName, peerPubKey) + if err != nil { + t.Fatal(err) + } + _, err = getPeer(ifaceName, t) if err.Error() != "peer not found" { t.Fatal(err) } } func Test_Close(t *testing.T) { - err := Close() + ifaceName := "utun1004" + wgIP := "10.99.99.50/24" + err := Create(ifaceName, wgIP) + if err != nil { + t.Fatal(err) + } + wg, err := wgctrl.New() + if err != nil { + t.Fatal(err) + } + defer func() { + err = wg.Close() + if err != nil { + t.Error(err) + } + }() + + d, err := wg.Device(ifaceName) + if err != nil { + t.Fatal(err) + } + // todo move the WgPort constant to the client + WgPort = d.ListenPort + err = Close() if err != nil { t.Fatal(err) } } -func getPeer() (wgtypes.Peer, error) { +func getPeer(ifaceName string, t *testing.T) (wgtypes.Peer, error) { emptyPeer := wgtypes.Peer{} wg, err := wgctrl.New() if err != nil { return emptyPeer, err } - defer wg.Close() + defer func() { + err = wg.Close() + if err != nil { + t.Error(err) + } + }() wgDevice, err := wg.Device(ifaceName) if err != nil { diff --git a/management/client/client_test.go b/management/client/client_test.go index 9b75182ef..80bf7967b 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -61,7 +61,8 @@ func startManagement(config *mgmt.Config, t *testing.T) (*grpc.Server, net.Liste } accountManager := mgmt.NewManager(store) - mgmtServer, err := mgmt.NewServer(config, accountManager) + peersUpdateManager := mgmt.NewPeersUpdateManager() + mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager) if err != nil { t.Fatal(err) } diff --git a/management/cmd/management.go b/management/cmd/management.go index 223d542a3..dfc222d53 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -43,6 +43,7 @@ var ( Short: "start Wiretrustee Management Server", Run: func(cmd *cobra.Command, args []string) { flag.Parse() + InitLog(logLevel) config, err := loadConfig() if err != nil { @@ -77,8 +78,8 @@ var ( opts = append(opts, grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) grpcServer := grpc.NewServer(opts...) - - server, err := server.NewServer(config, accountManager) + peersUpdateManager := server.NewPeersUpdateManager() + server, err := server.NewServer(config, accountManager, peersUpdateManager) if err != nil { log.Fatalf("failed creating new server: %v", err) } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 86ed47622..1c80bc5f3 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -3,7 +3,6 @@ package server import ( "context" "fmt" - "sync" "time" "github.com/golang/protobuf/ptypes/timestamp" @@ -20,31 +19,26 @@ type Server struct { accountManager *AccountManager wgKey wgtypes.Key proto.UnimplementedManagementServiceServer - peerChannels map[string]chan *UpdateChannelMessage - channelsMux *sync.Mutex - config *Config + peersUpdateManager *PeersUpdateManager + config *Config } // AllowedIPsFormat generates Wireguard AllowedIPs format (e.g. 100.30.30.1/32) const AllowedIPsFormat = "%s/32" -type UpdateChannelMessage struct { - Update *proto.SyncResponse -} - // NewServer creates a new Management server -func NewServer(config *Config, accountManager *AccountManager) (*Server, error) { +func NewServer(config *Config, accountManager *AccountManager, peersUpdateManager *PeersUpdateManager) (*Server, error) { key, err := wgtypes.GeneratePrivateKey() if err != nil { return nil, err } + return &Server{ wgKey: key, // peerKey -> event channel - peerChannels: make(map[string]chan *UpdateChannelMessage), - channelsMux: &sync.Mutex{}, - accountManager: accountManager, - config: config, + peersUpdateManager: peersUpdateManager, + accountManager: accountManager, + config: config, }, nil } @@ -90,8 +84,12 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S return err } - updates := s.openUpdatesChannel(peerKey.String()) - + updates := s.peersUpdateManager.CreateChannel(peerKey.String()) + err = s.accountManager.MarkPeerConnected(peerKey.String(), true) + if err != nil { + log.Warnf("failed marking peer as connected %s %v", peerKey, err) + } + // Todo start turn credentials goroutine // keep a connection to the peer and send updates when available for { select { @@ -119,15 +117,18 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S case <-srv.Context().Done(): // happens when connection drops, e.g. client disconnects log.Debugf("stream of peer %s has been closed", peerKey.String()) - s.closeUpdatesChannel(peerKey.String()) + s.peersUpdateManager.CloseChannel(peerKey.String()) + err := s.accountManager.MarkPeerConnected(peerKey.String(), false) + if err != nil { + log.Warnf("failed marking peer as disconnected %s %v", peerKey, err) + } + // todo stop turn goroutine return srv.Context().Err() } } } func (s *Server) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Peer, error) { - s.channelsMux.Lock() - defer s.channelsMux.Unlock() meta := req.GetMeta() if meta == nil { @@ -157,16 +158,18 @@ func (s *Server) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Pe // notify other peers of our registration for _, remotePeer := range peers { - if channel, ok := s.peerChannels[remotePeer.Key]; ok { - // exclude notified peer and add ourselves - peersToSend := []*Peer{peer} - for _, p := range peers { - if remotePeer.Key != p.Key { - peersToSend = append(peersToSend, p) - } + // exclude notified peer and add ourselves + peersToSend := []*Peer{peer} + for _, p := range peers { + if remotePeer.Key != p.Key { + peersToSend = append(peersToSend, p) } - update := toSyncResponse(s.config, peer, peersToSend) - channel <- &UpdateChannelMessage{Update: update} + } + update := toSyncResponse(s.config, peer, peersToSend) + err = s.peersUpdateManager.SendUpdate(remotePeer.Key, &UpdateMessage{Update: update}) + if err != nil { + // todo rethink if we should keep this return + return nil, err } } @@ -212,6 +215,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto return nil, status.Error(codes.Internal, "internal server error") } } + // Todo fill up turn credentials // if peer has reached this point then it has logged in loginResp := &proto.LoginResponse{ @@ -310,44 +314,6 @@ func (s *Server) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, return &proto.Empty{}, nil } -// openUpdatesChannel creates a go channel for a given peer used to deliver updates relevant to the peer. -func (s *Server) openUpdatesChannel(peerKey string) chan *UpdateChannelMessage { - s.channelsMux.Lock() - defer s.channelsMux.Unlock() - if channel, ok := s.peerChannels[peerKey]; ok { - delete(s.peerChannels, peerKey) - close(channel) - } - //mbragin: todo shouldn't it be more? or configurable? - channel := make(chan *UpdateChannelMessage, 100) - s.peerChannels[peerKey] = channel - - err := s.accountManager.MarkPeerConnected(peerKey, true) - if err != nil { - log.Warnf("failed marking peer as connected %s %v", peerKey, err) - } - - log.Debugf("opened updates channel for a peer %s", peerKey) - return channel -} - -// closeUpdatesChannel closes updates channel of a given peer -func (s *Server) closeUpdatesChannel(peerKey string) { - s.channelsMux.Lock() - defer s.channelsMux.Unlock() - if channel, ok := s.peerChannels[peerKey]; ok { - delete(s.peerChannels, peerKey) - close(channel) - } - - err := s.accountManager.MarkPeerConnected(peerKey, false) - if err != nil { - log.Warnf("failed marking peer as disconnected %s %v", peerKey, err) - } - - log.Debugf("closed updates channel of a peer %s", peerKey) -} - // sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization func (s *Server) sendInitialSync(peerKey wgtypes.Key, peer *Peer, srv proto.ManagementService_SyncServer) error { @@ -362,7 +328,7 @@ func (s *Server) sendInitialSync(peerKey wgtypes.Key, peer *Peer, srv proto.Mana if err != nil { return status.Errorf(codes.Internal, "error handling request") } - + // Todo fill up the turn credentials err = srv.Send(&proto.EncryptedMessage{ WgPubKey: s.wgKey.PublicKey().String(), Body: encryptedResp, diff --git a/management/server/management_test.go b/management/server/management_test.go index 4a89c290e..33a5b506c 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -491,7 +491,8 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) { log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) } accountManager := server.NewManager(store) - mgmtServer, err := server.NewServer(config, accountManager) + peersUpdateManager := server.NewPeersUpdateManager() + mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager) Expect(err).NotTo(HaveOccurred()) mgmtProto.RegisterManagementServiceServer(s, mgmtServer) go func() { diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go new file mode 100644 index 000000000..9095f4b1f --- /dev/null +++ b/management/server/updatechannel.go @@ -0,0 +1,64 @@ +package server + +import ( + log "github.com/sirupsen/logrus" + "github.com/wiretrustee/wiretrustee/management/proto" + "sync" +) + +type UpdateMessage struct { + Update *proto.SyncResponse +} +type PeersUpdateManager struct { + peerChannels map[string]chan *UpdateMessage + channelsMux *sync.Mutex +} + +// NewPeersUpdateManager returns a new instance of PeersUpdateManager +func NewPeersUpdateManager() *PeersUpdateManager { + return &PeersUpdateManager{ + peerChannels: make(map[string]chan *UpdateMessage), + channelsMux: &sync.Mutex{}, + } +} + +// SendUpdate sends update message to the peer's channel +func (p *PeersUpdateManager) SendUpdate(peer string, update *UpdateMessage) error { + p.channelsMux.Lock() + defer p.channelsMux.Unlock() + if channel, ok := p.peerChannels[peer]; ok { + channel <- update + return nil + } + log.Debugf("peer %s has no channel", peer) + return nil +} + +// CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer. +func (p *PeersUpdateManager) CreateChannel(peerKey string) chan *UpdateMessage { + p.channelsMux.Lock() + defer p.channelsMux.Unlock() + + if channel, ok := p.peerChannels[peerKey]; ok { + delete(p.peerChannels, peerKey) + close(channel) + } + //mbragin: todo shouldn't it be more? or configurable? + channel := make(chan *UpdateMessage, 100) + p.peerChannels[peerKey] = channel + + log.Debugf("opened updates channel for a peer %s", peerKey) + return channel +} + +// CloseChannel closes updates channel of a given peer +func (p *PeersUpdateManager) CloseChannel(peerKey string) { + p.channelsMux.Lock() + defer p.channelsMux.Unlock() + if channel, ok := p.peerChannels[peerKey]; ok { + delete(p.peerChannels, peerKey) + close(channel) + } + + log.Debugf("closed updates channel of a peer %s", peerKey) +} diff --git a/management/server/updatechannel_test.go b/management/server/updatechannel_test.go new file mode 100644 index 000000000..e0087f17c --- /dev/null +++ b/management/server/updatechannel_test.go @@ -0,0 +1,49 @@ +package server + +import ( + "github.com/wiretrustee/wiretrustee/management/proto" + "testing" +) + +var peersUpdater *PeersUpdateManager + +func TestCreateChannel(t *testing.T) { + peer := "test-create" + peersUpdater = NewPeersUpdateManager() + defer peersUpdater.CloseChannel(peer) + + _ = peersUpdater.CreateChannel(peer) + if _, ok := peersUpdater.peerChannels[peer]; !ok { + t.Error("Error creating the channel") + } +} + +func TestSendUpdate(t *testing.T) { + peer := "test-sendupdate" + update := &UpdateMessage{Update: &proto.SyncResponse{}} + _ = peersUpdater.CreateChannel(peer) + if _, ok := peersUpdater.peerChannels[peer]; !ok { + t.Error("Error creating the channel") + } + err := peersUpdater.SendUpdate(peer, update) + if err != nil { + t.Error("Error sending update: ", err) + } + select { + case <-peersUpdater.peerChannels[peer]: + default: + t.Error("Update wasn't send") + } +} + +func TestCloseChannel(t *testing.T) { + peer := "test-close" + _ = peersUpdater.CreateChannel(peer) + if _, ok := peersUpdater.peerChannels[peer]; !ok { + t.Error("Error creating the channel") + } + peersUpdater.CloseChannel(peer) + if _, ok := peersUpdater.peerChannels[peer]; ok { + t.Error("Error closing the channel") + } +}