Stubbing out WireGuard components

Stubbing our WireGuard components so that I can use docker/podman
network_mode=host. This is much more efficient than the docker/podman
userspace network.
This commit is contained in:
Tim Beatham 2023-11-20 11:28:12 +00:00
parent 023565d985
commit 388153e706
12 changed files with 224 additions and 93 deletions

View File

@ -216,6 +216,23 @@ func deleteService(client *ipcRpc.Client, service string) {
fmt.Println(reply) fmt.Println(reply)
} }
func getNode(client *ipcRpc.Client, nodeId, meshId string) {
var reply string
args := &ipc.GetNodeArgs{
NodeId: nodeId,
MeshId: meshId,
}
err := client.Call("IpcHandler.GetNode", &args, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func main() { func main() {
parser := argparse.NewParser("wg-mesh", parser := argparse.NewParser("wg-mesh",
"wg-mesh Manipulate WireGuard meshes") "wg-mesh Manipulate WireGuard meshes")
@ -232,6 +249,7 @@ func main() {
putAliasCmd := parser.NewCommand("put-alias", "Place an alias for the node") putAliasCmd := parser.NewCommand("put-alias", "Place an alias for the node")
setServiceCmd := parser.NewCommand("set-service", "Place a service into your advertisements") setServiceCmd := parser.NewCommand("set-service", "Place a service into your advertisements")
deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements") deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements")
getNodeCmd := parser.NewCommand("get-node", "Get a specific node from the mesh")
var newMeshIfName *string = newMeshCmd.String("f", "ifname", &argparse.Options{Required: true}) var newMeshIfName *string = newMeshCmd.String("f", "ifname", &argparse.Options{Required: true})
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{Required: true}) var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
@ -261,6 +279,9 @@ func main() {
var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{Required: true}) var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{Required: true})
var getNodeNodeId *string = getNodeCmd.String("n", "nodeid", &argparse.Options{Required: true})
var getNodeMeshId *string = getNodeCmd.String("m", "meshid", &argparse.Options{Required: true})
err := parser.Parse(os.Args) err := parser.Parse(os.Args)
if err != nil { if err != nil {
@ -329,4 +350,8 @@ func main() {
if deleteServiceCmd.Happened() { if deleteServiceCmd.Happened() {
deleteService(client, *deleteServiceKey) deleteService(client, *deleteServiceKey)
} }
if getNodeCmd.Happened() {
getNode(client, *getNodeNodeId, *getNodeMeshId)
}
} }

View File

@ -1,7 +1,8 @@
package main package main
import ( import (
"log" "net/http"
_ "net/http/pprof"
"os" "os"
"os/signal" "os/signal"
@ -35,6 +36,12 @@ func main() {
return return
} }
if conf.Profile {
go func() {
http.ListenAndServe("localhost:6060", nil)
}()
}
var robinRpc robin.WgRpc var robinRpc robin.WgRpc
var robinIpc robin.IpcHandler var robinIpc robin.IpcHandler
var syncProvider sync.SyncServiceImpl var syncProvider sync.SyncServiceImpl
@ -65,7 +72,7 @@ func main() {
return return
} }
log.Println("Running IPC Handler") logging.Log.WriteInfof("Running IPC Handler")
go ipc.RunIpcHandler(&robinIpc) go ipc.RunIpcHandler(&robinIpc)
go syncScheduler.Run() go syncScheduler.Run()

7
go.mod
View File

@ -5,9 +5,12 @@ go 1.21.3
require ( require (
github.com/akamensky/argparse v1.4.0 github.com/akamensky/argparse v1.4.0
github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9 github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9
github.com/gin-gonic/gin v1.9.1
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/jmespath/go-jmespath v0.4.0 github.com/jmespath/go-jmespath v0.4.0
github.com/jsimonetti/rtnetlink v1.3.5
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
golang.org/x/sys v0.14.0
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
google.golang.org/grpc v1.58.1 google.golang.org/grpc v1.58.1
google.golang.org/protobuf v1.31.0 google.golang.org/protobuf v1.31.0
@ -19,7 +22,6 @@ require (
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect github.com/gin-contrib/sse v0.1.0 // indirect
github.com/gin-gonic/gin v1.9.1 // indirect
github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect github.com/go-playground/validator/v10 v10.14.0 // indirect
@ -27,7 +29,6 @@ require (
github.com/golang/protobuf v1.5.3 // indirect github.com/golang/protobuf v1.5.3 // indirect
github.com/google/go-cmp v0.5.9 // indirect github.com/google/go-cmp v0.5.9 // indirect
github.com/josharian/native v1.1.0 // indirect github.com/josharian/native v1.1.0 // indirect
github.com/jsimonetti/rtnetlink v1.3.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect github.com/leodido/go-urn v1.2.4 // indirect
@ -40,12 +41,10 @@ require (
github.com/pelletier/go-toml/v2 v2.0.8 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect github.com/ugorji/go/codec v1.2.11 // indirect
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect
golang.org/x/arch v0.3.0 // indirect golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.13.0 // indirect golang.org/x/crypto v0.13.0 // indirect
golang.org/x/net v0.15.0 // indirect golang.org/x/net v0.15.0 // indirect
golang.org/x/sync v0.3.0 // indirect golang.org/x/sync v0.3.0 // indirect
golang.org/x/sys v0.12.0 // indirect
golang.org/x/text v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 // indirect golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect

View File

@ -24,6 +24,8 @@ type CrdtMeshManager struct {
doc *automerge.Doc doc *automerge.Doc
LastHash automerge.ChangeHash LastHash automerge.ChangeHash
conf *conf.WgMeshConfiguration conf *conf.WgMeshConfiguration
cache *MeshCrdt
lastCacheHash automerge.ChangeHash
} }
func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) { func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
@ -40,9 +42,31 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt) c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt)
} }
func (c *CrdtMeshManager) GetNodeIds() []string {
keys, _ := c.doc.Path("nodes").Map().Keys()
return keys
}
// GetMesh(): Converts the document into a struct // GetMesh(): Converts the document into a struct
func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) { func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
return automerge.As[*MeshCrdt](c.doc.Root()) changes, err := c.doc.Changes(c.lastCacheHash)
if err != nil {
return nil, err
}
if c.cache == nil || len(changes) > 3 {
c.lastCacheHash = c.LastHash
cache, err := automerge.As[*MeshCrdt](c.doc.Root())
if err != nil {
return nil, err
}
c.cache = cache
}
return c.cache, nil
} }
// GetMeshId returns the meshid of the mesh // GetMeshId returns the meshid of the mesh
@ -82,13 +106,23 @@ func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, erro
manager.IfName = params.DevName manager.IfName = params.DevName
manager.Client = params.Client manager.Client = params.Client
manager.conf = &params.Conf manager.conf = &params.Conf
manager.cache = nil
return &manager, nil return &manager, nil
} }
// GetNode: returns a mesh node crdt.Close releases resources used by a Client. // NodeExists: returns true if the node exists. Returns false
func (m *CrdtMeshManager) GetNode(endpoint string) (*MeshNodeCrdt, error) { func (m *CrdtMeshManager) NodeExists(key string) bool {
node, err := m.doc.Path("nodes").Map().Get(key)
return node.Kind() == automerge.KindMap && err != nil
}
func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
node, err := m.doc.Path("nodes").Map().Get(endpoint) node, err := m.doc.Path("nodes").Map().Get(endpoint)
if node.Kind() != automerge.KindMap {
return nil, fmt.Errorf("GetNode: something went wrong %s is not a map type")
}
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -49,6 +49,10 @@ type WgMeshConfiguration struct {
Timeout int `yaml:"timeout"` Timeout int `yaml:"timeout"`
// PruneTime number of seconds before we consider the 'node' as dead // PruneTime number of seconds before we consider the 'node' as dead
PruneTime int `yaml:"pruneTime"` PruneTime int `yaml:"pruneTime"`
// Profile whether or not to include a http server that profiles the code
Profile bool `yaml:"profile"`
// StubWg whether or not to stub the WireGuard types
StubWg bool `yaml:"stubWg"`
} }
func ValidateConfiguration(c *WgMeshConfiguration) error { func ValidateConfiguration(c *WgMeshConfiguration) error {

View File

@ -52,6 +52,11 @@ type QueryMesh struct {
Query string Query string
} }
type GetNodeArgs struct {
NodeId string
MeshId string
}
type MeshIpc interface { type MeshIpc interface {
CreateMesh(args *NewMeshArgs, reply *string) error CreateMesh(args *NewMeshArgs, reply *string) error
ListMeshes(name string, reply *ListMeshReply) error ListMeshes(name string, reply *ListMeshReply) error
@ -64,6 +69,7 @@ type MeshIpc interface {
PutDescription(description string, reply *string) error PutDescription(description string, reply *string) error
PutAlias(alias string, reply *string) error PutAlias(alias string, reply *string) error
PutService(args PutServiceArgs, reply *string) error PutService(args PutServiceArgs, reply *string) error
GetNode(args GetNodeArgs, reply *string) error
DeleteService(service string, reply *string) error DeleteService(service string, reply *string) error
} }

View File

@ -34,6 +34,7 @@ type MeshManager interface {
Prune() error Prune() error
Close() error Close() error
GetMonitor() MeshMonitor GetMonitor() MeshMonitor
GetNode(string, string) MeshNode
} }
type MeshManagerImpl struct { type MeshManagerImpl struct {
@ -79,6 +80,22 @@ func (m *MeshManagerImpl) SetService(service string, value string) error {
return nil return nil
} }
func (m *MeshManagerImpl) GetNode(meshid, nodeId string) MeshNode {
mesh, ok := m.Meshes[meshid]
if !ok {
return nil
}
node, err := mesh.GetNode(nodeId)
if err != nil {
return nil
}
return node
}
// GetMonitor implements MeshManager. // GetMonitor implements MeshManager.
func (m *MeshManagerImpl) GetMonitor() MeshMonitor { func (m *MeshManagerImpl) GetMonitor() MeshMonitor {
return m.Monitor return m.Monitor
@ -117,6 +134,7 @@ func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) {
return "", fmt.Errorf("error creating mesh: %w", err) return "", fmt.Errorf("error creating mesh: %w", err)
} }
if !m.conf.StubWg {
err = m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{ err = m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: devName, IfName: devName,
Port: port, Port: port,
@ -125,6 +143,7 @@ func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) {
if err != nil { if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err) return "", fmt.Errorf("error creating mesh: %w", err)
} }
}
m.Meshes[meshId] = nodeManager m.Meshes[meshId] = nodeManager
return meshId, nil return meshId, nil
@ -159,10 +178,14 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
m.Meshes[params.MeshId] = meshProvider m.Meshes[params.MeshId] = meshProvider
if !m.conf.StubWg {
return m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{ return m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: params.DevName, IfName: params.DevName,
Port: params.WgPort, Port: params.WgPort,
}) })
}
return nil
} }
// HasChanges returns true if the mesh has changes // HasChanges returns true if the mesh has changes
@ -195,6 +218,11 @@ func (s *MeshManagerImpl) EnableInterface(meshId string) error {
// GetPublicKey: Gets the public key of the WireGuard mesh // GetPublicKey: Gets the public key of the WireGuard mesh
func (s *MeshManagerImpl) GetPublicKey(meshId string) (*wgtypes.Key, error) { func (s *MeshManagerImpl) GetPublicKey(meshId string) (*wgtypes.Key, error) {
if s.conf.StubWg {
zeroedKey := make([]byte, wgtypes.KeyLen)
return (*wgtypes.Key)(zeroedKey), nil
}
mesh, ok := s.Meshes[meshId] mesh, ok := s.Meshes[meshId]
if !ok { if !ok {
@ -216,7 +244,6 @@ type AddSelfParams struct {
// WgPort is the WireGuard port to advertise // WgPort is the WireGuard port to advertise
WgPort int WgPort int
// Endpoint is the alias of the machine to send routable packets // Endpoint is the alias of the machine to send routable packets
// to
Endpoint string Endpoint string
} }
@ -247,6 +274,7 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
Endpoint: params.Endpoint, Endpoint: params.Endpoint,
}) })
if !s.conf.StubWg {
device, err := mesh.GetDevice() device, err := mesh.GetDevice()
if err != nil { if err != nil {
@ -258,6 +286,7 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
if err != nil { if err != nil {
return fmt.Errorf("addSelf: failed to add address to dev %w", err) return fmt.Errorf("addSelf: failed to add address to dev %w", err)
} }
}
s.Meshes[params.MeshId].AddNode(node) s.Meshes[params.MeshId].AddNode(node)
return s.RouteManager.UpdateRoutes() return s.RouteManager.UpdateRoutes()
@ -277,13 +306,16 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
return err return err
} }
device, err := mesh.GetDevice() if !s.conf.StubWg {
device, e := mesh.GetDevice()
if err != nil { if e != nil {
return err return err
} }
err = s.interfaceManipulator.RemoveInterface(device.Name) err = s.interfaceManipulator.RemoveInterface(device.Name)
}
delete(s.Meshes, meshId) delete(s.Meshes, meshId)
return err return err
} }
@ -295,15 +327,9 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
return nil, fmt.Errorf("mesh %s does not exist", meshId) return nil, fmt.Errorf("mesh %s does not exist", meshId)
} }
snapshot, err := meshInstance.GetMesh() node, err := meshInstance.GetNode(s.HostParameters.HostEndpoint)
if err != nil { if err != nil {
return nil, err
}
node, ok := snapshot.GetNodes()[s.HostParameters.HostEndpoint]
if !ok {
return nil, errors.New("the node doesn't exist in the mesh") return nil, errors.New("the node doesn't exist in the mesh")
} }
@ -317,17 +343,19 @@ func (s *MeshManagerImpl) ApplyConfig() error {
return err return err
} }
return s.RouteManager.InstallRoutes() return nil
} }
func (s *MeshManagerImpl) SetDescription(description string) error { func (s *MeshManagerImpl) SetDescription(description string) error {
for _, mesh := range s.Meshes { for _, mesh := range s.Meshes {
if mesh.NodeExists(s.HostParameters.HostEndpoint) {
err := mesh.SetDescription(s.HostParameters.HostEndpoint, description) err := mesh.SetDescription(s.HostParameters.HostEndpoint, description)
if err != nil { if err != nil {
return err return err
} }
} }
}
return nil return nil
} }
@ -335,28 +363,22 @@ func (s *MeshManagerImpl) SetDescription(description string) error {
// SetAlias implements MeshManager. // SetAlias implements MeshManager.
func (s *MeshManagerImpl) SetAlias(alias string) error { func (s *MeshManagerImpl) SetAlias(alias string) error {
for _, mesh := range s.Meshes { for _, mesh := range s.Meshes {
if mesh.NodeExists(s.HostParameters.HostEndpoint) {
err := mesh.SetAlias(s.HostParameters.HostEndpoint, alias) err := mesh.SetAlias(s.HostParameters.HostEndpoint, alias)
if err != nil { if err != nil {
return err return err
} }
} }
}
return nil return nil
} }
// UpdateTimeStamp updates the timestamp of this node in all meshes // UpdateTimeStamp updates the timestamp of this node in all meshes
func (s *MeshManagerImpl) UpdateTimeStamp() error { func (s *MeshManagerImpl) UpdateTimeStamp() error {
for _, mesh := range s.Meshes { for _, mesh := range s.Meshes {
snapshot, err := mesh.GetMesh() if mesh.NodeExists(s.HostParameters.HostEndpoint) {
err := mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint)
if err != nil {
return err
}
_, exists := snapshot.GetNodes()[s.HostParameters.HostEndpoint]
if exists {
err = mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint)
if err != nil { if err != nil {
return err return err
@ -375,7 +397,12 @@ func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider {
return s.Meshes return s.Meshes
} }
// Close the mesh manager
func (s *MeshManagerImpl) Close() error { func (s *MeshManagerImpl) Close() error {
if s.conf.StubWg {
return nil
}
for _, mesh := range s.Meshes { for _, mesh := range s.Meshes {
dev, err := mesh.GetDevice() dev, err := mesh.GetDevice()

View File

@ -76,6 +76,21 @@ type MeshProviderStub struct {
snapshot *MeshSnapshotStub snapshot *MeshSnapshotStub
} }
// GetNodeIds implements MeshProvider.
func (*MeshProviderStub) GetNodeIds() []string {
panic("unimplemented")
}
// GetNode implements MeshProvider.
func (*MeshProviderStub) GetNode(string) (MeshNode, error) {
panic("unimplemented")
}
// NodeExists implements MeshProvider.
func (*MeshProviderStub) NodeExists(string) bool {
panic("unimplemented")
}
// AddService implements MeshProvider. // AddService implements MeshProvider.
func (*MeshProviderStub) AddService(nodeId string, key string, value string) error { func (*MeshProviderStub) AddService(nodeId string, key string, value string) error {
panic("unimplemented") panic("unimplemented")
@ -196,6 +211,11 @@ type MeshManagerStub struct {
meshes map[string]MeshProvider meshes map[string]MeshProvider
} }
// GetNode implements MeshManager.
func (*MeshManagerStub) GetNode(string, string) MeshNode {
panic("unimplemented")
}
// RemoveService implements MeshManager. // RemoveService implements MeshManager.
func (*MeshManagerStub) RemoveService(service string) error { func (*MeshManagerStub) RemoveService(service string) error {
panic("unimplemented") panic("unimplemented")

View File

@ -92,7 +92,7 @@ type MeshSyncer interface {
type MeshProvider interface { type MeshProvider interface {
// AddNode() adds a node to the mesh // AddNode() adds a node to the mesh
AddNode(node MeshNode) AddNode(node MeshNode)
// GetMesh() returns a snapshot of the mesh provided by the mesh provider // GetMesh() returns a snapshot of the mesh provided by the mesh provider.
GetMesh() (MeshSnapshot, error) GetMesh() (MeshSnapshot, error)
// GetMeshId() returns the ID of the mesh network // GetMeshId() returns the ID of the mesh network
GetMeshId() string GetMeshId() string
@ -114,6 +114,10 @@ type MeshProvider interface {
RemoveRoutes(nodeId string, route ...string) error RemoveRoutes(nodeId string, route ...string) error
// GetSyncer: returns the automerge syncer for sync // GetSyncer: returns the automerge syncer for sync
GetSyncer() MeshSyncer GetSyncer() MeshSyncer
// GetNode get a particular not within the mesh
GetNode(string) (MeshNode, error)
// NodeExists: returns true if a particular node exists false otherwise
NodeExists(string) bool
// SetDescription: sets the description of this automerge data type // SetDescription: sets the description of this automerge data type
SetDescription(nodeId string, description string) error SetDescription(nodeId string, description string) error
// SetAlias: set the alias of the nodeId // SetAlias: set the alias of the nodeId
@ -125,6 +129,7 @@ type MeshProvider interface {
// Prune: prunes all nodes that have not updated their timestamp in // Prune: prunes all nodes that have not updated their timestamp in
// pruneAmount seconds // pruneAmount seconds
Prune(pruneAmount int) error Prune(pruneAmount int) error
GetNodeIds() []string
} }
// HostParameters contains the IDs of a node // HostParameters contains the IDs of a node

View File

@ -53,7 +53,7 @@ func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
return nil, err return nil, err
} }
nodes := lib.Map(lib.MapValues(snapshot.GetNodes()), meshNodeToQueryNode) nodes := lib.Map(lib.MapValues(snapshot.GetNodes()), MeshNodeToQueryNode)
result, err := jmespath.Search(queryParams, nodes) result, err := jmespath.Search(queryParams, nodes)
@ -65,7 +65,7 @@ func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
return bytes, err return bytes, err
} }
func meshNodeToQueryNode(node mesh.MeshNode) *QueryNode { func MeshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode := new(QueryNode) queryNode := new(QueryNode)
queryNode.HostEndpoint = node.GetHostEndpoint() queryNode.HostEndpoint = node.GetHostEndpoint()
pubKey, _ := node.GetPublicKey() pubKey, _ := node.GetPublicKey()

View File

@ -2,6 +2,7 @@ package robin
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
@ -10,6 +11,7 @@ import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/wgmesh/pkg/rpc"
) )
@ -242,6 +244,27 @@ func (n *IpcHandler) DeleteService(service string, reply *string) error {
return nil return nil
} }
func (n *IpcHandler) GetNode(args ipc.GetNodeArgs, reply *string) error {
node := n.Server.GetMeshManager().GetNode(args.MeshId, args.NodeId)
if node == nil {
*reply = "nil"
return nil
}
queryNode := query.MeshNodeToQueryNode(node)
bytes, err := json.Marshal(queryNode)
if err != nil {
*reply = err.Error()
return nil
}
*reply = string(bytes)
return nil
}
type RobinIpcParams struct { type RobinIpcParams struct {
CtrlServer ctrlserver.CtrlServer CtrlServer ctrlserver.CtrlServer
} }

View File

@ -1,7 +1,6 @@
package sync package sync
import ( import (
"errors"
"math/rand" "math/rand"
"sync" "sync"
"time" "time"
@ -36,24 +35,16 @@ func (s *SyncerImpl) Sync(meshId string) error {
} }
logging.Log.WriteInfof("UPDATING WG CONF") logging.Log.WriteInfof("UPDATING WG CONF")
if s.manager.HasChanges(meshId) {
err := s.manager.ApplyConfig() err := s.manager.ApplyConfig()
if err != nil { if err != nil {
logging.Log.WriteInfof("Failed to update config %w", err) logging.Log.WriteInfof("Failed to update config %w", err)
} }
theMesh := s.manager.GetMesh(meshId)
if theMesh == nil {
return errors.New("the provided mesh does not exist")
} }
snapshot, _ := theMesh.GetMesh() nodeNames := s.manager.GetMesh(meshId).GetNodeIds()
nodes := snapshot.GetNodes()
if len(nodes) <= 1 {
return nil
}
self, err := s.manager.GetSelf(meshId) self, err := s.manager.GetSelf(meshId)
@ -61,17 +52,6 @@ func (s *SyncerImpl) Sync(meshId string) error {
return err return err
} }
excludedNodes := map[string]struct{}{
self.GetHostEndpoint(): {},
}
meshNodes := lib.MapValuesWithExclude(nodes, excludedNodes)
getNames := func(node mesh.MeshNode) string {
return node.GetHostEndpoint()
}
nodeNames := lib.Map(meshNodes, getNames)
neighbours := s.cluster.GetNeighbours(nodeNames, self.GetHostEndpoint()) neighbours := s.cluster.GetNeighbours(nodeNames, self.GetHostEndpoint())
randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate) randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
@ -81,7 +61,7 @@ func (s *SyncerImpl) Sync(meshId string) error {
before := time.Now() before := time.Now()
if len(meshNodes) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance { if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
logging.Log.WriteInfof("Sending to random cluster") logging.Log.WriteInfof("Sending to random cluster")
interCluster := s.cluster.GetInterCluster(nodeNames, self.GetHostEndpoint()) interCluster := s.cluster.GetInterCluster(nodeNames, self.GetHostEndpoint())
randomSubset = append(randomSubset, interCluster) randomSubset = append(randomSubset, interCluster)
@ -109,7 +89,8 @@ func (s *SyncerImpl) Sync(meshId string) error {
// Check if any changes have occurred and trigger callbacks // Check if any changes have occurred and trigger callbacks
// if changes have occurred. // if changes have occurred.
return s.manager.GetMonitor().Trigger() // return s.manager.GetMonitor().Trigger()
return nil
} }
// SyncMeshes: Sync all meshes // SyncMeshes: Sync all meshes