Compare commits

...

4 Commits

Author SHA1 Message Date
b179cd3cf4 Hashing the WireGuard interface
Hashing the interface and using ephmeral ports so that the admin doesn't
choose an interface and port combination. An administrator can alteranatively
decide to provide port but this isn't critical.
2023-11-20 13:03:42 +00:00
8f211aa116 Merge pull request #18 from tim-beatham/26-performance-testing
Stubbing out WireGuard components
2023-11-20 11:29:37 +00:00
388153e706 Stubbing out WireGuard components
Stubbing our WireGuard components so that I can use docker/podman
network_mode=host. This is much more efficient than the docker/podman
userspace network.
2023-11-20 11:28:12 +00:00
023565d985 Merge pull request #17 from tim-beatham/25-ability-to-aliases
25 ability to aliases
2023-11-17 22:20:57 +00:00
18 changed files with 270 additions and 142 deletions

View File

@ -16,7 +16,6 @@ const SockAddr = "/tmp/wgmesh_ipc.sock"
type CreateMeshParams struct { type CreateMeshParams struct {
Client *ipcRpc.Client Client *ipcRpc.Client
IfName string
WgPort int WgPort int
Endpoint string Endpoint string
} }
@ -24,7 +23,6 @@ type CreateMeshParams struct {
func createMesh(args *CreateMeshParams) string { func createMesh(args *CreateMeshParams) string {
var reply string var reply string
newMeshParams := ipc.NewMeshArgs{ newMeshParams := ipc.NewMeshArgs{
IfName: args.IfName,
WgPort: args.WgPort, WgPort: args.WgPort,
Endpoint: args.Endpoint, Endpoint: args.Endpoint,
} }
@ -68,7 +66,6 @@ func joinMesh(params *JoinMeshParams) string {
args := ipc.JoinMeshArgs{ args := ipc.JoinMeshArgs{
MeshId: params.MeshId, MeshId: params.MeshId,
IpAdress: params.IpAddress, IpAdress: params.IpAddress,
IfName: params.IfName,
Port: params.WgPort, Port: params.WgPort,
} }
@ -216,6 +213,23 @@ func deleteService(client *ipcRpc.Client, service string) {
fmt.Println(reply) fmt.Println(reply)
} }
func getNode(client *ipcRpc.Client, nodeId, meshId string) {
var reply string
args := &ipc.GetNodeArgs{
NodeId: nodeId,
MeshId: meshId,
}
err := client.Call("IpcHandler.GetNode", &args, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func main() { func main() {
parser := argparse.NewParser("wg-mesh", parser := argparse.NewParser("wg-mesh",
"wg-mesh Manipulate WireGuard meshes") "wg-mesh Manipulate WireGuard meshes")
@ -232,15 +246,14 @@ func main() {
putAliasCmd := parser.NewCommand("put-alias", "Place an alias for the node") putAliasCmd := parser.NewCommand("put-alias", "Place an alias for the node")
setServiceCmd := parser.NewCommand("set-service", "Place a service into your advertisements") setServiceCmd := parser.NewCommand("set-service", "Place a service into your advertisements")
deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements") deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements")
getNodeCmd := parser.NewCommand("get-node", "Get a specific node from the mesh")
var newMeshIfName *string = newMeshCmd.String("f", "ifname", &argparse.Options{Required: true}) var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{})
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{}) var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{})
var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true}) var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true}) var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true})
var joinMeshIfName *string = joinMeshCmd.String("f", "ifname", &argparse.Options{Required: true}) var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{})
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{}) var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{})
var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true}) var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true})
@ -261,6 +274,9 @@ func main() {
var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{Required: true}) var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{Required: true})
var getNodeNodeId *string = getNodeCmd.String("n", "nodeid", &argparse.Options{Required: true})
var getNodeMeshId *string = getNodeCmd.String("m", "meshid", &argparse.Options{Required: true})
err := parser.Parse(os.Args) err := parser.Parse(os.Args)
if err != nil { if err != nil {
@ -277,7 +293,6 @@ func main() {
if newMeshCmd.Happened() { if newMeshCmd.Happened() {
fmt.Println(createMesh(&CreateMeshParams{ fmt.Println(createMesh(&CreateMeshParams{
Client: client, Client: client,
IfName: *newMeshIfName,
WgPort: *newMeshPort, WgPort: *newMeshPort,
Endpoint: *newMeshEndpoint, Endpoint: *newMeshEndpoint,
})) }))
@ -290,7 +305,6 @@ func main() {
if joinMeshCmd.Happened() { if joinMeshCmd.Happened() {
fmt.Println(joinMesh(&JoinMeshParams{ fmt.Println(joinMesh(&JoinMeshParams{
Client: client, Client: client,
IfName: *joinMeshIfName,
WgPort: *joinMeshPort, WgPort: *joinMeshPort,
IpAddress: *joinMeshIpAddress, IpAddress: *joinMeshIpAddress,
MeshId: *joinMeshId, MeshId: *joinMeshId,
@ -329,4 +343,8 @@ func main() {
if deleteServiceCmd.Happened() { if deleteServiceCmd.Happened() {
deleteService(client, *deleteServiceKey) deleteService(client, *deleteServiceKey)
} }
if getNodeCmd.Happened() {
getNode(client, *getNodeNodeId, *getNodeMeshId)
}
} }

View File

@ -1,7 +1,8 @@
package main package main
import ( import (
"log" "net/http"
_ "net/http/pprof"
"os" "os"
"os/signal" "os/signal"
@ -35,6 +36,12 @@ func main() {
return return
} }
if conf.Profile {
go func() {
http.ListenAndServe("localhost:6060", nil)
}()
}
var robinRpc robin.WgRpc var robinRpc robin.WgRpc
var robinIpc robin.IpcHandler var robinIpc robin.IpcHandler
var syncProvider sync.SyncServiceImpl var syncProvider sync.SyncServiceImpl
@ -65,7 +72,7 @@ func main() {
return return
} }
log.Println("Running IPC Handler") logging.Log.WriteInfof("Running IPC Handler")
go ipc.RunIpcHandler(&robinIpc) go ipc.RunIpcHandler(&robinIpc)
go syncScheduler.Run() go syncScheduler.Run()

7
go.mod
View File

@ -5,9 +5,12 @@ go 1.21.3
require ( require (
github.com/akamensky/argparse v1.4.0 github.com/akamensky/argparse v1.4.0
github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9 github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9
github.com/gin-gonic/gin v1.9.1
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/jmespath/go-jmespath v0.4.0 github.com/jmespath/go-jmespath v0.4.0
github.com/jsimonetti/rtnetlink v1.3.5
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
golang.org/x/sys v0.14.0
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
google.golang.org/grpc v1.58.1 google.golang.org/grpc v1.58.1
google.golang.org/protobuf v1.31.0 google.golang.org/protobuf v1.31.0
@ -19,7 +22,6 @@ require (
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect github.com/gin-contrib/sse v0.1.0 // indirect
github.com/gin-gonic/gin v1.9.1 // indirect
github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect github.com/go-playground/validator/v10 v10.14.0 // indirect
@ -27,7 +29,6 @@ require (
github.com/golang/protobuf v1.5.3 // indirect github.com/golang/protobuf v1.5.3 // indirect
github.com/google/go-cmp v0.5.9 // indirect github.com/google/go-cmp v0.5.9 // indirect
github.com/josharian/native v1.1.0 // indirect github.com/josharian/native v1.1.0 // indirect
github.com/jsimonetti/rtnetlink v1.3.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect github.com/leodido/go-urn v1.2.4 // indirect
@ -40,12 +41,10 @@ require (
github.com/pelletier/go-toml/v2 v2.0.8 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect github.com/ugorji/go/codec v1.2.11 // indirect
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect
golang.org/x/arch v0.3.0 // indirect golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.13.0 // indirect golang.org/x/crypto v0.13.0 // indirect
golang.org/x/net v0.15.0 // indirect golang.org/x/net v0.15.0 // indirect
golang.org/x/sync v0.3.0 // indirect golang.org/x/sync v0.3.0 // indirect
golang.org/x/sys v0.12.0 // indirect
golang.org/x/text v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 // indirect golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect

View File

@ -62,11 +62,11 @@ func (s *SmegServer) CreateMesh(c *gin.Context) {
c.JSON(http.StatusBadRequest, &gin.H{ c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(), "error": err.Error(),
}) })
return return
} }
ipcRequest := ipc.NewMeshArgs{ ipcRequest := ipc.NewMeshArgs{
IfName: createMesh.IfName,
WgPort: createMesh.WgPort, WgPort: createMesh.WgPort,
} }
@ -100,7 +100,6 @@ func (s *SmegServer) JoinMesh(c *gin.Context) {
ipcRequest := ipc.JoinMeshArgs{ ipcRequest := ipc.JoinMeshArgs{
MeshId: joinMesh.MeshId, MeshId: joinMesh.MeshId,
IpAdress: joinMesh.Bootstrap, IpAdress: joinMesh.Bootstrap,
IfName: joinMesh.IfName,
Port: joinMesh.WgPort, Port: joinMesh.WgPort,
} }

View File

@ -18,13 +18,11 @@ type SmegMesh struct {
} }
type CreateMeshRequest struct { type CreateMeshRequest struct {
IfName string `json:"ifName" binding:"required"` WgPort int `json:"port" binding:"gte=1024,lt=65535"`
WgPort int `json:"port" binding:"required,gte=1024,lt=65535"`
} }
type JoinMeshRequest struct { type JoinMeshRequest struct {
IfName string `json:"ifName" binding:"required"` WgPort int `json:"port" binding:"gte=1024,lt=65535"`
WgPort int `json:"port" binding:"required,gte=1024,lt=65535"`
Bootstrap string `json:"bootstrap" binding:"required"` Bootstrap string `json:"bootstrap" binding:"required"`
MeshId string `json:"meshid" binding:"required"` MeshId string `json:"meshid" binding:"required"`
} }

View File

@ -18,12 +18,14 @@ import (
// CrdtMeshManager manages nodes in the crdt mesh // CrdtMeshManager manages nodes in the crdt mesh
type CrdtMeshManager struct { type CrdtMeshManager struct {
MeshId string MeshId string
IfName string IfName string
Client *wgctrl.Client Client *wgctrl.Client
doc *automerge.Doc doc *automerge.Doc
LastHash automerge.ChangeHash LastHash automerge.ChangeHash
conf *conf.WgMeshConfiguration conf *conf.WgMeshConfiguration
cache *MeshCrdt
lastCacheHash automerge.ChangeHash
} }
func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) { func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
@ -40,9 +42,31 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt) c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt)
} }
func (c *CrdtMeshManager) GetNodeIds() []string {
keys, _ := c.doc.Path("nodes").Map().Keys()
return keys
}
// GetMesh(): Converts the document into a struct // GetMesh(): Converts the document into a struct
func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) { func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
return automerge.As[*MeshCrdt](c.doc.Root()) changes, err := c.doc.Changes(c.lastCacheHash)
if err != nil {
return nil, err
}
if c.cache == nil || len(changes) > 3 {
c.lastCacheHash = c.LastHash
cache, err := automerge.As[*MeshCrdt](c.doc.Root())
if err != nil {
return nil, err
}
c.cache = cache
}
return c.cache, nil
} }
// GetMeshId returns the meshid of the mesh // GetMeshId returns the meshid of the mesh
@ -82,13 +106,23 @@ func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, erro
manager.IfName = params.DevName manager.IfName = params.DevName
manager.Client = params.Client manager.Client = params.Client
manager.conf = &params.Conf manager.conf = &params.Conf
manager.cache = nil
return &manager, nil return &manager, nil
} }
// GetNode: returns a mesh node crdt.Close releases resources used by a Client. // NodeExists: returns true if the node exists. Returns false
func (m *CrdtMeshManager) GetNode(endpoint string) (*MeshNodeCrdt, error) { func (m *CrdtMeshManager) NodeExists(key string) bool {
node, err := m.doc.Path("nodes").Map().Get(key)
return node.Kind() == automerge.KindMap && err != nil
}
func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
node, err := m.doc.Path("nodes").Map().Get(endpoint) node, err := m.doc.Path("nodes").Map().Get(endpoint)
if node.Kind() != automerge.KindMap {
return nil, fmt.Errorf("GetNode: something went wrong %s is not a map type")
}
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -49,6 +49,10 @@ type WgMeshConfiguration struct {
Timeout int `yaml:"timeout"` Timeout int `yaml:"timeout"`
// PruneTime number of seconds before we consider the 'node' as dead // PruneTime number of seconds before we consider the 'node' as dead
PruneTime int `yaml:"pruneTime"` PruneTime int `yaml:"pruneTime"`
// Profile whether or not to include a http server that profiles the code
Profile bool `yaml:"profile"`
// StubWg whether or not to stub the WireGuard types
StubWg bool `yaml:"stubWg"`
} }
func ValidateConfiguration(c *WgMeshConfiguration) error { func ValidateConfiguration(c *WgMeshConfiguration) error {

View File

@ -11,8 +11,6 @@ import (
) )
type NewMeshArgs struct { type NewMeshArgs struct {
// IfName is the interface that the mesh instance will run on
IfName string
// WgPort is the WireGuard port to expose // WgPort is the WireGuard port to expose
WgPort int WgPort int
// Endpoint is the routable alias of the machine. Can be an IP // Endpoint is the routable alias of the machine. Can be an IP
@ -25,8 +23,6 @@ type JoinMeshArgs struct {
MeshId string MeshId string
// IpAddress is a routable IP in another mesh // IpAddress is a routable IP in another mesh
IpAdress string IpAdress string
// IfName is the interface name of the mesh
IfName string
// Port is the WireGuard port to expose // Port is the WireGuard port to expose
Port int Port int
// Endpoint is the routable address of this machine. If not provided // Endpoint is the routable address of this machine. If not provided
@ -52,6 +48,11 @@ type QueryMesh struct {
Query string Query string
} }
type GetNodeArgs struct {
NodeId string
MeshId string
}
type MeshIpc interface { type MeshIpc interface {
CreateMesh(args *NewMeshArgs, reply *string) error CreateMesh(args *NewMeshArgs, reply *string) error
ListMeshes(name string, reply *ListMeshReply) error ListMeshes(name string, reply *ListMeshReply) error
@ -64,6 +65,7 @@ type MeshIpc interface {
PutDescription(description string, reply *string) error PutDescription(description string, reply *string) error
PutAlias(alias string, reply *string) error PutAlias(alias string, reply *string) error
PutService(args PutServiceArgs, reply *string) error PutService(args PutServiceArgs, reply *string) error
GetNode(args GetNodeArgs, reply *string) error
DeleteService(service string, reply *string) error DeleteService(service string, reply *string) error
} }

View File

@ -14,7 +14,7 @@ import (
) )
type MeshManager interface { type MeshManager interface {
CreateMesh(devName string, port int) (string, error) CreateMesh(port int) (string, error)
AddMesh(params *AddMeshParams) error AddMesh(params *AddMeshParams) error
HasChanges(meshid string) bool HasChanges(meshid string) bool
GetMesh(meshId string) MeshProvider GetMesh(meshId string) MeshProvider
@ -34,6 +34,7 @@ type MeshManager interface {
Prune() error Prune() error
Close() error Close() error
GetMonitor() MeshMonitor GetMonitor() MeshMonitor
GetNode(string, string) MeshNode
} }
type MeshManagerImpl struct { type MeshManagerImpl struct {
@ -79,6 +80,22 @@ func (m *MeshManagerImpl) SetService(service string, value string) error {
return nil return nil
} }
func (m *MeshManagerImpl) GetNode(meshid, nodeId string) MeshNode {
mesh, ok := m.Meshes[meshid]
if !ok {
return nil
}
node, err := mesh.GetNode(nodeId)
if err != nil {
return nil
}
return node
}
// GetMonitor implements MeshManager. // GetMonitor implements MeshManager.
func (m *MeshManagerImpl) GetMonitor() MeshMonitor { func (m *MeshManagerImpl) GetMonitor() MeshMonitor {
return m.Monitor return m.Monitor
@ -98,15 +115,25 @@ func (m *MeshManagerImpl) Prune() error {
} }
// CreateMesh: Creates a new mesh, stores it and returns the mesh id // CreateMesh: Creates a new mesh, stores it and returns the mesh id
func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) { func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
meshId, err := m.idGenerator.GetId() meshId, err := m.idGenerator.GetId()
var ifName string = ""
if err != nil { if err != nil {
return "", err return "", err
} }
if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(port)
if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err)
}
}
nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{ nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: devName, DevName: ifName,
Port: port, Port: port,
Conf: m.conf, Conf: m.conf,
Client: m.Client, Client: m.Client,
@ -117,30 +144,31 @@ func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) {
return "", fmt.Errorf("error creating mesh: %w", err) return "", fmt.Errorf("error creating mesh: %w", err)
} }
err = m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: devName,
Port: port,
})
if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err)
}
m.Meshes[meshId] = nodeManager m.Meshes[meshId] = nodeManager
return meshId, nil return meshId, nil
} }
type AddMeshParams struct { type AddMeshParams struct {
MeshId string MeshId string
DevName string
WgPort int WgPort int
MeshBytes []byte MeshBytes []byte
} }
// AddMesh: Add the mesh to the list of meshes // AddMesh: Add the mesh to the list of meshes
func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error { func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
var ifName string
var err error
if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(params.WgPort)
if err != nil {
return err
}
}
meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{ meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: params.DevName, DevName: ifName,
Port: params.WgPort, Port: params.WgPort,
Conf: m.conf, Conf: m.conf,
Client: m.Client, Client: m.Client,
@ -158,11 +186,7 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
} }
m.Meshes[params.MeshId] = meshProvider m.Meshes[params.MeshId] = meshProvider
return nil
return m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: params.DevName,
Port: params.WgPort,
})
} }
// HasChanges returns true if the mesh has changes // HasChanges returns true if the mesh has changes
@ -195,6 +219,11 @@ func (s *MeshManagerImpl) EnableInterface(meshId string) error {
// GetPublicKey: Gets the public key of the WireGuard mesh // GetPublicKey: Gets the public key of the WireGuard mesh
func (s *MeshManagerImpl) GetPublicKey(meshId string) (*wgtypes.Key, error) { func (s *MeshManagerImpl) GetPublicKey(meshId string) (*wgtypes.Key, error) {
if s.conf.StubWg {
zeroedKey := make([]byte, wgtypes.KeyLen)
return (*wgtypes.Key)(zeroedKey), nil
}
mesh, ok := s.Meshes[meshId] mesh, ok := s.Meshes[meshId]
if !ok { if !ok {
@ -216,7 +245,6 @@ type AddSelfParams struct {
// WgPort is the WireGuard port to advertise // WgPort is the WireGuard port to advertise
WgPort int WgPort int
// Endpoint is the alias of the machine to send routable packets // Endpoint is the alias of the machine to send routable packets
// to
Endpoint string Endpoint string
} }
@ -247,16 +275,18 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
Endpoint: params.Endpoint, Endpoint: params.Endpoint,
}) })
device, err := mesh.GetDevice() if !s.conf.StubWg {
device, err := mesh.GetDevice()
if err != nil { if err != nil {
return fmt.Errorf("failed to get device %w", err) return fmt.Errorf("failed to get device %w", err)
} }
err = s.interfaceManipulator.AddAddress(device.Name, fmt.Sprintf("%s/64", nodeIP)) err = s.interfaceManipulator.AddAddress(device.Name, fmt.Sprintf("%s/64", nodeIP))
if err != nil { if err != nil {
return fmt.Errorf("addSelf: failed to add address to dev %w", err) return fmt.Errorf("addSelf: failed to add address to dev %w", err)
}
} }
s.Meshes[params.MeshId].AddNode(node) s.Meshes[params.MeshId].AddNode(node)
@ -277,13 +307,16 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
return err return err
} }
device, err := mesh.GetDevice() if !s.conf.StubWg {
device, e := mesh.GetDevice()
if err != nil { if e != nil {
return err return err
}
err = s.interfaceManipulator.RemoveInterface(device.Name)
} }
err = s.interfaceManipulator.RemoveInterface(device.Name)
delete(s.Meshes, meshId) delete(s.Meshes, meshId)
return err return err
} }
@ -295,15 +328,9 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
return nil, fmt.Errorf("mesh %s does not exist", meshId) return nil, fmt.Errorf("mesh %s does not exist", meshId)
} }
snapshot, err := meshInstance.GetMesh() node, err := meshInstance.GetNode(s.HostParameters.HostEndpoint)
if err != nil { if err != nil {
return nil, err
}
node, ok := snapshot.GetNodes()[s.HostParameters.HostEndpoint]
if !ok {
return nil, errors.New("the node doesn't exist in the mesh") return nil, errors.New("the node doesn't exist in the mesh")
} }
@ -317,15 +344,17 @@ func (s *MeshManagerImpl) ApplyConfig() error {
return err return err
} }
return s.RouteManager.InstallRoutes() return nil
} }
func (s *MeshManagerImpl) SetDescription(description string) error { func (s *MeshManagerImpl) SetDescription(description string) error {
for _, mesh := range s.Meshes { for _, mesh := range s.Meshes {
err := mesh.SetDescription(s.HostParameters.HostEndpoint, description) if mesh.NodeExists(s.HostParameters.HostEndpoint) {
err := mesh.SetDescription(s.HostParameters.HostEndpoint, description)
if err != nil { if err != nil {
return err return err
}
} }
} }
@ -335,10 +364,12 @@ func (s *MeshManagerImpl) SetDescription(description string) error {
// SetAlias implements MeshManager. // SetAlias implements MeshManager.
func (s *MeshManagerImpl) SetAlias(alias string) error { func (s *MeshManagerImpl) SetAlias(alias string) error {
for _, mesh := range s.Meshes { for _, mesh := range s.Meshes {
err := mesh.SetAlias(s.HostParameters.HostEndpoint, alias) if mesh.NodeExists(s.HostParameters.HostEndpoint) {
err := mesh.SetAlias(s.HostParameters.HostEndpoint, alias)
if err != nil { if err != nil {
return err return err
}
} }
} }
return nil return nil
@ -347,16 +378,8 @@ func (s *MeshManagerImpl) SetAlias(alias string) error {
// UpdateTimeStamp updates the timestamp of this node in all meshes // UpdateTimeStamp updates the timestamp of this node in all meshes
func (s *MeshManagerImpl) UpdateTimeStamp() error { func (s *MeshManagerImpl) UpdateTimeStamp() error {
for _, mesh := range s.Meshes { for _, mesh := range s.Meshes {
snapshot, err := mesh.GetMesh() if mesh.NodeExists(s.HostParameters.HostEndpoint) {
err := mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint)
if err != nil {
return err
}
_, exists := snapshot.GetNodes()[s.HostParameters.HostEndpoint]
if exists {
err = mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint)
if err != nil { if err != nil {
return err return err
@ -375,7 +398,12 @@ func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider {
return s.Meshes return s.Meshes
} }
// Close the mesh manager
func (s *MeshManagerImpl) Close() error { func (s *MeshManagerImpl) Close() error {
if s.conf.StubWg {
return nil
}
for _, mesh := range s.Meshes { for _, mesh := range s.Meshes {
dev, err := mesh.GetDevice() dev, err := mesh.GetDevice()

View File

@ -76,6 +76,21 @@ type MeshProviderStub struct {
snapshot *MeshSnapshotStub snapshot *MeshSnapshotStub
} }
// GetNodeIds implements MeshProvider.
func (*MeshProviderStub) GetNodeIds() []string {
panic("unimplemented")
}
// GetNode implements MeshProvider.
func (*MeshProviderStub) GetNode(string) (MeshNode, error) {
panic("unimplemented")
}
// NodeExists implements MeshProvider.
func (*MeshProviderStub) NodeExists(string) bool {
panic("unimplemented")
}
// AddService implements MeshProvider. // AddService implements MeshProvider.
func (*MeshProviderStub) AddService(nodeId string, key string, value string) error { func (*MeshProviderStub) AddService(nodeId string, key string, value string) error {
panic("unimplemented") panic("unimplemented")
@ -196,6 +211,11 @@ type MeshManagerStub struct {
meshes map[string]MeshProvider meshes map[string]MeshProvider
} }
// GetNode implements MeshManager.
func (*MeshManagerStub) GetNode(string, string) MeshNode {
panic("unimplemented")
}
// RemoveService implements MeshManager. // RemoveService implements MeshManager.
func (*MeshManagerStub) RemoveService(service string) error { func (*MeshManagerStub) RemoveService(service string) error {
panic("unimplemented") panic("unimplemented")
@ -230,7 +250,7 @@ func NewMeshManagerStub() MeshManager {
return &MeshManagerStub{meshes: make(map[string]MeshProvider)} return &MeshManagerStub{meshes: make(map[string]MeshProvider)}
} }
func (m *MeshManagerStub) CreateMesh(devName string, port int) (string, error) { func (m *MeshManagerStub) CreateMesh(port int) (string, error) {
return "tim123", nil return "tim123", nil
} }

View File

@ -92,7 +92,7 @@ type MeshSyncer interface {
type MeshProvider interface { type MeshProvider interface {
// AddNode() adds a node to the mesh // AddNode() adds a node to the mesh
AddNode(node MeshNode) AddNode(node MeshNode)
// GetMesh() returns a snapshot of the mesh provided by the mesh provider // GetMesh() returns a snapshot of the mesh provided by the mesh provider.
GetMesh() (MeshSnapshot, error) GetMesh() (MeshSnapshot, error)
// GetMeshId() returns the ID of the mesh network // GetMeshId() returns the ID of the mesh network
GetMeshId() string GetMeshId() string
@ -114,6 +114,10 @@ type MeshProvider interface {
RemoveRoutes(nodeId string, route ...string) error RemoveRoutes(nodeId string, route ...string) error
// GetSyncer: returns the automerge syncer for sync // GetSyncer: returns the automerge syncer for sync
GetSyncer() MeshSyncer GetSyncer() MeshSyncer
// GetNode get a particular not within the mesh
GetNode(string) (MeshNode, error)
// NodeExists: returns true if a particular node exists false otherwise
NodeExists(string) bool
// SetDescription: sets the description of this automerge data type // SetDescription: sets the description of this automerge data type
SetDescription(nodeId string, description string) error SetDescription(nodeId string, description string) error
// SetAlias: set the alias of the nodeId // SetAlias: set the alias of the nodeId
@ -125,6 +129,7 @@ type MeshProvider interface {
// Prune: prunes all nodes that have not updated their timestamp in // Prune: prunes all nodes that have not updated their timestamp in
// pruneAmount seconds // pruneAmount seconds
Prune(pruneAmount int) error Prune(pruneAmount int) error
GetNodeIds() []string
} }
// HostParameters contains the IDs of a node // HostParameters contains the IDs of a node

View File

@ -53,7 +53,7 @@ func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
return nil, err return nil, err
} }
nodes := lib.Map(lib.MapValues(snapshot.GetNodes()), meshNodeToQueryNode) nodes := lib.Map(lib.MapValues(snapshot.GetNodes()), MeshNodeToQueryNode)
result, err := jmespath.Search(queryParams, nodes) result, err := jmespath.Search(queryParams, nodes)
@ -65,7 +65,7 @@ func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
return bytes, err return bytes, err
} }
func meshNodeToQueryNode(node mesh.MeshNode) *QueryNode { func MeshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode := new(QueryNode) queryNode := new(QueryNode)
queryNode.HostEndpoint = node.GetHostEndpoint() queryNode.HostEndpoint = node.GetHostEndpoint()
pubKey, _ := node.GetPublicKey() pubKey, _ := node.GetPublicKey()

View File

@ -2,6 +2,7 @@ package robin
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
@ -10,6 +11,7 @@ import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/wgmesh/pkg/rpc"
) )
@ -18,7 +20,7 @@ type IpcHandler struct {
} }
func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error { func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
meshId, err := n.Server.GetMeshManager().CreateMesh(args.IfName, args.WgPort) meshId, err := n.Server.GetMeshManager().CreateMesh(args.WgPort)
if err != nil { if err != nil {
return err return err
@ -81,7 +83,6 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{ err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{
MeshId: args.MeshId, MeshId: args.MeshId,
DevName: args.IfName,
WgPort: args.Port, WgPort: args.Port,
MeshBytes: meshReply.Mesh, MeshBytes: meshReply.Mesh,
}) })
@ -242,6 +243,27 @@ func (n *IpcHandler) DeleteService(service string, reply *string) error {
return nil return nil
} }
func (n *IpcHandler) GetNode(args ipc.GetNodeArgs, reply *string) error {
node := n.Server.GetMeshManager().GetNode(args.MeshId, args.NodeId)
if node == nil {
*reply = "nil"
return nil
}
queryNode := query.MeshNodeToQueryNode(node)
bytes, err := json.Marshal(queryNode)
if err != nil {
*reply = err.Error()
return nil
}
*reply = string(bytes)
return nil
}
type RobinIpcParams struct { type RobinIpcParams struct {
CtrlServer ctrlserver.CtrlServer CtrlServer ctrlserver.CtrlServer
} }

View File

@ -1,7 +1,6 @@
package sync package sync
import ( import (
"errors"
"math/rand" "math/rand"
"sync" "sync"
"time" "time"
@ -36,24 +35,16 @@ func (s *SyncerImpl) Sync(meshId string) error {
} }
logging.Log.WriteInfof("UPDATING WG CONF") logging.Log.WriteInfof("UPDATING WG CONF")
err := s.manager.ApplyConfig()
if err != nil { if s.manager.HasChanges(meshId) {
logging.Log.WriteInfof("Failed to update config %w", err) err := s.manager.ApplyConfig()
if err != nil {
logging.Log.WriteInfof("Failed to update config %w", err)
}
} }
theMesh := s.manager.GetMesh(meshId) nodeNames := s.manager.GetMesh(meshId).GetNodeIds()
if theMesh == nil {
return errors.New("the provided mesh does not exist")
}
snapshot, _ := theMesh.GetMesh()
nodes := snapshot.GetNodes()
if len(nodes) <= 1 {
return nil
}
self, err := s.manager.GetSelf(meshId) self, err := s.manager.GetSelf(meshId)
@ -61,17 +52,6 @@ func (s *SyncerImpl) Sync(meshId string) error {
return err return err
} }
excludedNodes := map[string]struct{}{
self.GetHostEndpoint(): {},
}
meshNodes := lib.MapValuesWithExclude(nodes, excludedNodes)
getNames := func(node mesh.MeshNode) string {
return node.GetHostEndpoint()
}
nodeNames := lib.Map(meshNodes, getNames)
neighbours := s.cluster.GetNeighbours(nodeNames, self.GetHostEndpoint()) neighbours := s.cluster.GetNeighbours(nodeNames, self.GetHostEndpoint())
randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate) randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
@ -81,7 +61,7 @@ func (s *SyncerImpl) Sync(meshId string) error {
before := time.Now() before := time.Now()
if len(meshNodes) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance { if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
logging.Log.WriteInfof("Sending to random cluster") logging.Log.WriteInfof("Sending to random cluster")
interCluster := s.cluster.GetInterCluster(nodeNames, self.GetHostEndpoint()) interCluster := s.cluster.GetInterCluster(nodeNames, self.GetHostEndpoint())
randomSubset = append(randomSubset, interCluster) randomSubset = append(randomSubset, interCluster)
@ -109,7 +89,8 @@ func (s *SyncerImpl) Sync(meshId string) error {
// Check if any changes have occurred and trigger callbacks // Check if any changes have occurred and trigger callbacks
// if changes have occurred. // if changes have occurred.
return s.manager.GetMonitor().Trigger() // return s.manager.GetMonitor().Trigger()
return nil
} }
// SyncMeshes: Sync all meshes // SyncMeshes: Sync all meshes

View File

@ -50,7 +50,6 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, ifName string, port int, endP
err = s.server.MeshManager.AddMesh(&mesh.AddMeshParams{ err = s.server.MeshManager.AddMesh(&mesh.AddMeshParams{
MeshId: meshId, MeshId: meshId,
DevName: ifName,
WgPort: port, WgPort: port,
MeshBytes: reply.Mesh, MeshBytes: reply.Mesh,
}) })

View File

@ -2,7 +2,7 @@ package wg
type WgInterfaceManipulatorStub struct{} type WgInterfaceManipulatorStub struct{}
func (i *WgInterfaceManipulatorStub) CreateInterface(params *CreateInterfaceParams) error { func (i *WgInterfaceManipulatorStub) CreateInterface(port int) error {
return nil return nil
} }

View File

@ -8,14 +8,9 @@ func (m *WgError) Error() string {
return m.msg return m.msg
} }
type CreateInterfaceParams struct {
IfName string
Port int
}
type WgInterfaceManipulator interface { type WgInterfaceManipulator interface {
// CreateInterface creates a WireGuard interface // CreateInterface creates a WireGuard interface
CreateInterface(params *CreateInterfaceParams) error CreateInterface(port int) (string, error)
// AddAddress adds an address to the given interface name // AddAddress adds an address to the given interface name
AddAddress(ifName string, addr string) error AddAddress(ifName string, addr string) error
// RemoveInterface removes the specified interface // RemoveInterface removes the specified interface

View File

@ -1,6 +1,9 @@
package wg package wg
import ( import (
"crypto"
"crypto/rand"
"encoding/base64"
"fmt" "fmt"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/lib"
@ -13,40 +16,54 @@ type WgInterfaceManipulatorImpl struct {
client *wgctrl.Client client *wgctrl.Client
} }
const hashLength = 6
// CreateInterface creates a WireGuard interface // CreateInterface creates a WireGuard interface
func (m *WgInterfaceManipulatorImpl) CreateInterface(params *CreateInterfaceParams) error { func (m *WgInterfaceManipulatorImpl) CreateInterface(port int) (string, error) {
rtnl, err := lib.NewRtNetlinkConfig() rtnl, err := lib.NewRtNetlinkConfig()
if err != nil { if err != nil {
return fmt.Errorf("failed to access link: %w", err) return "", fmt.Errorf("failed to access link: %w", err)
} }
defer rtnl.Close() defer rtnl.Close()
err = rtnl.CreateLink(params.IfName) randomBuf := make([]byte, 32)
_, err = rand.Read(randomBuf)
if err != nil { if err != nil {
return fmt.Errorf("failed to create link: %w", err) return "", err
}
md5 := crypto.MD5.New().Sum(randomBuf)
md5Str := fmt.Sprintf("wg%s", base64.StdEncoding.EncodeToString(md5)[:hashLength])
err = rtnl.CreateLink(md5Str)
if err != nil {
return "", fmt.Errorf("failed to create link: %w", err)
} }
privateKey, err := wgtypes.GeneratePrivateKey() privateKey, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
return fmt.Errorf("failed to create private key: %w", err) return "", fmt.Errorf("failed to create private key: %w", err)
} }
var cfg wgtypes.Config = wgtypes.Config{ var cfg wgtypes.Config = wgtypes.Config{
PrivateKey: &privateKey, PrivateKey: &privateKey,
ListenPort: &params.Port, ListenPort: &port,
} }
err = m.client.ConfigureDevice(params.IfName, cfg) err = m.client.ConfigureDevice(md5Str, cfg)
if err != nil { if err != nil {
return fmt.Errorf("failed to configure dev: %w", err) m.RemoveInterface(md5Str)
return "", fmt.Errorf("failed to configure dev: %w", err)
} }
logging.Log.WriteInfof("ip link set up dev %s type wireguard", params.IfName) logging.Log.WriteInfof("ip link set up dev %s type wireguard", md5Str)
return nil return md5Str, nil
} }
// Add an address to the given interface // Add an address to the given interface