diff --git a/pkg/mesh/config.go b/pkg/mesh/config.go index 55c0035..a784748 100644 --- a/pkg/mesh/config.go +++ b/pkg/mesh/config.go @@ -25,7 +25,7 @@ type WgMeshConfigApplyer struct { routeInstaller route.RouteInstaller } -func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode) (*wgtypes.PeerConfig, error) { +func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, peerToClients map[string][]net.IPNet) (*wgtypes.PeerConfig, error) { endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint()) if err != nil { @@ -46,7 +46,13 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode) (*wgtypes.PeerConfi allowedips = append(allowedips, *ipnet) } - keepAlive := time.Duration(m.config.KeepAliveTime) * time.Second + clients, ok := peerToClients[node.GetWgHost().String()] + + if ok { + allowedips = append(allowedips, clients...) + } + + keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second peerConfig := wgtypes.PeerConfig{ PublicKey: pubKey, @@ -86,37 +92,45 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { return err } + peerToClients := make(map[string][]net.IPNet) + for _, n := range nodes { if NodeEquals(n, self) { continue } - if n.GetType() == conf.CLIENT_ROLE && len(peers) > 0 { + if n.GetType() == conf.CLIENT_ROLE && len(peers) > 0 && self.GetType() == conf.CLIENT_ROLE { peer := lib.ConsistentHash(peers, n, func(mn MeshNode) int { return lib.HashString(mn.GetWgHost().String()) }) - if !NodeEquals(peer, self) { - dev, err := mesh.GetDevice() + dev, err := mesh.GetDevice() - if err != nil { - return err - } - - rtnl.AddRoute(dev.Name, lib.Route{ - Gateway: peer.GetWgHost().IP, - Destination: *n.GetWgHost(), - }) - - if err != nil { - return err - } - - continue + if err != nil { + return err } + + rtnl.AddRoute(dev.Name, lib.Route{ + Gateway: peer.GetWgHost().IP, + Destination: *n.GetWgHost(), + }) + + if err != nil { + return err + } + + clients, ok := peerToClients[peer.GetWgHost().String()] + + if !ok { + clients = make([]net.IPNet, 0) + peerToClients[peer.GetWgHost().String()] = clients + } + + peerToClients[peer.GetWgHost().String()] = append(clients, *n.GetWgHost()) + continue } - peer, err := m.convertMeshNode(n) + peer, err := m.convertMeshNode(n, peerToClients) if err != nil { return err diff --git a/pkg/mesh/types.go b/pkg/mesh/types.go index 4613bfc..fe38b56 100644 --- a/pkg/mesh/types.go +++ b/pkg/mesh/types.go @@ -4,7 +4,6 @@ package mesh import ( "net" - "slices" "github.com/tim-beatham/wgmesh/pkg/conf" "golang.zx2c4.com/wireguard/wgctrl" @@ -46,42 +45,7 @@ type MeshNode interface { // NodeEquals: determines if two mesh nodes are equivalent to one another func NodeEquals(node1, node2 MeshNode) bool { - if node1.GetHostEndpoint() != node2.GetHostEndpoint() { - return false - } - - node1Pub, _ := node1.GetPublicKey() - node2Pub, _ := node2.GetPublicKey() - - if node1Pub != node2Pub { - return false - } - - if node1.GetWgEndpoint() != node2.GetWgEndpoint() { - return false - } - - if node1.GetWgHost() != node2.GetWgHost() { - return false - } - - if !slices.Equal(node1.GetRoutes(), node2.GetRoutes()) { - return false - } - - if node1.GetIdentifier() != node2.GetIdentifier() { - return false - } - - if node1.GetDescription() != node2.GetDescription() { - return false - } - - if node1.GetAlias() != node2.GetAlias() { - return false - } - - return true + return node1.GetHostEndpoint() == node2.GetHostEndpoint() } type MeshSnapshot interface {