diff --git a/pkg/dns/dns.go b/pkg/dns/dns.go index dab9143..8847ee0 100644 --- a/pkg/dns/dns.go +++ b/pkg/dns/dns.go @@ -29,7 +29,7 @@ func (d *DNSHandler) queryMesh(meshId, alias string) net.IP { err := d.client.Call("IpcHandler.Query", &ipc.QueryMesh{ MeshId: meshId, - Query: "[?alias == 'tim'] | [0]", + Query: fmt.Sprintf("[?alias == '%s'] | [0]", alias), }, &reply) if err != nil { diff --git a/pkg/mesh/graph.go b/pkg/mesh/graph.go index 81e592b..33891f3 100644 --- a/pkg/mesh/graph.go +++ b/pkg/mesh/graph.go @@ -61,7 +61,7 @@ func (c *MeshDOTConverter) graphNode(g *graph.Graph, node MeshNode, meshId strin self, _ := c.manager.GetSelf(meshId) - if node.GetHostEndpoint() == self.GetHostEndpoint() { + if NodeEquals(self, node) { return } diff --git a/pkg/mesh/manager.go b/pkg/mesh/manager.go index 961ecb0..f36d6ea 100644 --- a/pkg/mesh/manager.go +++ b/pkg/mesh/manager.go @@ -266,20 +266,16 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error { params.WgPort = device.ListenPort } - pubKey, err := s.GetPublicKey(params.MeshId) + pubKey := s.HostParameters.PrivateKey.PublicKey() - if err != nil { - return err - } - - nodeIP, err := s.ipAllocator.GetIP(*pubKey, params.MeshId) + nodeIP, err := s.ipAllocator.GetIP(pubKey, params.MeshId) if err != nil { return err } node := s.nodeFactory.Build(&MeshNodeFactoryParams{ - PublicKey: pubKey, + PublicKey: &pubKey, NodeIP: nodeIP, WgPort: params.WgPort, Endpoint: params.Endpoint, diff --git a/pkg/mesh/route.go b/pkg/mesh/route.go index e8d8f4f..9b9edd8 100644 --- a/pkg/mesh/route.go +++ b/pkg/mesh/route.go @@ -45,7 +45,7 @@ func (r *RouteManagerImpl) UpdateRoutes() error { return err } - err = mesh1.AddRoutes(self.GetHostEndpoint(), ipNet.String()) + err = mesh1.AddRoutes(NodeID(self), ipNet.String()) if err != nil { return err @@ -74,7 +74,7 @@ func (r *RouteManagerImpl) RemoveRoutes(meshId string) error { return err } - mesh1.RemoveRoutes(self.GetHostEndpoint(), ipNet.String()) + mesh1.RemoveRoutes(NodeID(self), ipNet.String()) } return nil } @@ -152,7 +152,7 @@ func (m *RouteManagerImpl) installRoutes(meshProvider MeshProvider) error { } for _, node := range mesh.GetNodes() { - if self.GetHostEndpoint() == node.GetHostEndpoint() { + if NodeEquals(self, node) { continue } diff --git a/pkg/mesh/stub_types.go b/pkg/mesh/stub_types.go index 7fe53e2..0a007c3 100644 --- a/pkg/mesh/stub_types.go +++ b/pkg/mesh/stub_types.go @@ -23,7 +23,7 @@ type MeshNodeStub struct { // GetType implements MeshNode. func (*MeshNodeStub) GetType() conf.NodeType { - return PEER + return conf.PEER_ROLE } // GetServices implements MeshNode. diff --git a/pkg/mesh/types.go b/pkg/mesh/types.go index cd5d0ed..39b7e67 100644 --- a/pkg/mesh/types.go +++ b/pkg/mesh/types.go @@ -10,13 +10,6 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -const ( - // Data Exchanged Between Peers - PEER conf.NodeType = "peer" - // Data Exchanged Between Clients - CLIENT conf.NodeType = "client" -) - // MeshNode represents an implementation of a node in a mesh type MeshNode interface { // GetHostEndpoint: gets the gRPC endpoint of the node @@ -45,7 +38,15 @@ type MeshNode interface { // NodeEquals: determines if two mesh nodes are equivalent to one another func NodeEquals(node1, node2 MeshNode) bool { - return node1.GetHostEndpoint() == node2.GetHostEndpoint() + key1, _ := node1.GetPublicKey() + key2, _ := node2.GetPublicKey() + + return key1.String() == key2.String() +} + +func NodeID(node MeshNode) string { + key, _ := node.GetPublicKey() + return key.String() } type MeshSnapshot interface {