diff --git a/pkg/mesh/config.go b/pkg/mesh/config.go index 328bacb..f156259 100644 --- a/pkg/mesh/config.go +++ b/pkg/mesh/config.go @@ -155,68 +155,70 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][] return routes } -func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { - snap, err := mesh.GetMesh() - - if err != nil { - return err +// getCorrespondignPeer: gets the peer corresponding to the client +func (m *WgMeshConfigApplyer) getCorrespondingPeer(peers []MeshNode, client MeshNode) MeshNode { + hashFunc := func(mn MeshNode) int { + pubKey, _ := mn.GetPublicKey() + return lib.HashString(pubKey.String()) } - nodes := lib.MapValues(snap.GetNodes()) - - slices.SortFunc(nodes, func(a, b MeshNode) int { - return strings.Compare(string(a.GetType()), string(b.GetType())) - }) - - peerConfigs := make([]wgtypes.PeerConfig, len(nodes)) - - peers := lib.Filter(nodes, func(mn MeshNode) bool { - return mn.GetType() == conf.PEER_ROLE - }) - - clients := lib.Filter(nodes, func(mn MeshNode) bool { - return mn.GetType() == conf.CLIENT_ROLE - }) - - var count int = 0 + peer := lib.ConsistentHash(peers, client, hashFunc, hashFunc) + return peer +} +func (m *WgMeshConfigApplyer) getClientConfig(mesh MeshProvider, peers []MeshNode, clients []MeshNode) (*wgtypes.Config, error) { self, err := m.meshManager.GetSelf(mesh.GetMeshId()) if err != nil { - return err + return nil, err } + peer := m.getCorrespondingPeer(peers, self) + + pubKey, _ := peer.GetPublicKey() + + keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second + endpoint, err := net.ResolveUDPAddr("udp", peer.GetWgEndpoint()) + + if err != nil { + return nil, err + } + + allowedips := make([]net.IPNet, 1) + _, ipnet, _ := net.ParseCIDR("::/0") + allowedips[0] = *ipnet + + peerCfgs := make([]wgtypes.PeerConfig, 1) + + peerCfgs[0] = wgtypes.PeerConfig{ + PublicKey: pubKey, + Endpoint: endpoint, + PersistentKeepaliveInterval: &keepAlive, + AllowedIPs: allowedips, + } + + cfg := wgtypes.Config{ + Peers: peerCfgs, + } + + return &cfg, err +} + +func (m *WgMeshConfigApplyer) getPeerConfig(mesh MeshProvider, peers []MeshNode, clients []MeshNode, dev *wgtypes.Device) (*wgtypes.Config, error) { peerToClients := make(map[string][]net.IPNet) routes := m.getRoutes(mesh) installedRoutes := make([]lib.Route, 0) + peerConfigs := make([]wgtypes.PeerConfig, 0) + self, err := m.meshManager.GetSelf(mesh.GetMeshId()) - dev, _ := mesh.GetDevice() + if err != nil { + return nil, err + } for _, n := range clients { - if NodeEquals(n, self) { - continue - } - - if self.GetType() == conf.PEER_ROLE { - client, err := m.convertMeshNode(n, dev, peerToClients, routes) - - if err != nil { - return err - } - - peerConfigs[count] = *client - count++ - - } else if len(peers) > 0 && self.GetType() == conf.CLIENT_ROLE { - hashFunc := func(mn MeshNode) int { - pubKey, _ := mn.GetPublicKey() - return lib.HashString(pubKey.String()) - } - - peer := lib.ConsistentHash(peers, n, hashFunc, hashFunc) - + if len(peers) > 0 { + peer := m.getCorrespondingPeer(peers, n) pubKey, _ := peer.GetPublicKey() - clients, ok := peerToClients[pubKey.String()] if !ok { @@ -225,6 +227,16 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { } peerToClients[pubKey.String()] = append(clients, *n.GetWgHost()) + + if NodeEquals(self, peer) { + cfg, err := m.convertMeshNode(n, dev, peerToClients, routes) + + if err != nil { + return nil, err + } + + peerConfigs = append(peerConfigs, *cfg) + } } } @@ -236,7 +248,7 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { peer, err := m.convertMeshNode(n, dev, peerToClients, routes) if err != nil { - return err + return nil, err } for _, route := range peer.AllowedIPs { @@ -251,21 +263,66 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { } } - peerConfigs[count] = *peer - count++ + peerConfigs = append(peerConfigs, *peer) } cfg := wgtypes.Config{ - Peers: peerConfigs, + Peers: peerConfigs, + ReplacePeers: true, } - err = m.meshManager.GetClient().ConfigureDevice(dev.Name, cfg) + err = m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...) + return &cfg, err +} + +func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { + snap, err := mesh.GetMesh() if err != nil { return err } - return m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...) + nodes := lib.MapValues(snap.GetNodes()) + dev, _ := mesh.GetDevice() + + slices.SortFunc(nodes, func(a, b MeshNode) int { + return strings.Compare(string(a.GetType()), string(b.GetType())) + }) + + peers := lib.Filter(nodes, func(mn MeshNode) bool { + return mn.GetType() == conf.PEER_ROLE + }) + + clients := lib.Filter(nodes, func(mn MeshNode) bool { + return mn.GetType() == conf.CLIENT_ROLE + }) + + self, err := m.meshManager.GetSelf(mesh.GetMeshId()) + + if err != nil { + return err + } + + var cfg *wgtypes.Config = nil + + switch self.GetType() { + case conf.PEER_ROLE: + cfg, err = m.getPeerConfig(mesh, peers, clients, dev) + case conf.CLIENT_ROLE: + cfg, err = m.getClientConfig(mesh, peers, clients) + } + + if err != nil { + return err + } + + err = m.meshManager.GetClient().ConfigureDevice(dev.Name, *cfg) + + if err != nil { + return err + } + + return nil } func (m *WgMeshConfigApplyer) ApplyConfig() error { @@ -294,7 +351,8 @@ func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error { } m.meshManager.GetClient().ConfigureDevice(dev.Name, wgtypes.Config{ - Peers: make([]wgtypes.PeerConfig, 0), + Peers: make([]wgtypes.PeerConfig, 0), + ReplacePeers: true, }) return nil