Compare commits

..

9 Commits

Author SHA1 Message Date
b179cd3cf4 Hashing the WireGuard interface
Hashing the interface and using ephmeral ports so that the admin doesn't
choose an interface and port combination. An administrator can alteranatively
decide to provide port but this isn't critical.
2023-11-20 13:03:42 +00:00
8f211aa116 Merge pull request #18 from tim-beatham/26-performance-testing
Stubbing out WireGuard components
2023-11-20 11:29:37 +00:00
388153e706 Stubbing out WireGuard components
Stubbing our WireGuard components so that I can use docker/podman
network_mode=host. This is much more efficient than the docker/podman
userspace network.
2023-11-20 11:28:12 +00:00
023565d985 Merge pull request #17 from tim-beatham/25-ability-to-aliases
25 ability to aliases
2023-11-17 22:20:57 +00:00
36c264b38e 25-ability-aliases
Fixed unit tests failing
2023-11-17 22:18:53 +00:00
68db795f47 Ability to specify aliases
Ability to specify aliases that automatically append to /etc/hosts
2023-11-17 22:13:51 +00:00
f6160fe138 Adding aliases that automatically gets added 2023-11-17 19:13:20 +00:00
2c5289afb0 Merge pull request #16 from tim-beatham/15-add-rest-api
Developed a rest API
2023-11-15 12:57:05 +00:00
7199d07a76 Added smegmesh submodule 2023-11-13 10:46:52 +00:00
28 changed files with 921 additions and 175 deletions

3
.gitmodules vendored Normal file
View File

@ -0,0 +1,3 @@
[submodule "smegmesh-web"]
path = smegmesh-web
url = git@github.com:tim-beatham/smegmesh-web.git

View File

@ -13,5 +13,5 @@ func main() {
log.Fatal(err.Error())
}
apiServer.Run(":8080")
apiServer.Run(":40000")
}

View File

@ -16,7 +16,6 @@ const SockAddr = "/tmp/wgmesh_ipc.sock"
type CreateMeshParams struct {
Client *ipcRpc.Client
IfName string
WgPort int
Endpoint string
}
@ -24,7 +23,6 @@ type CreateMeshParams struct {
func createMesh(args *CreateMeshParams) string {
var reply string
newMeshParams := ipc.NewMeshArgs{
IfName: args.IfName,
WgPort: args.WgPort,
Endpoint: args.Endpoint,
}
@ -68,7 +66,6 @@ func joinMesh(params *JoinMeshParams) string {
args := ipc.JoinMeshArgs{
MeshId: params.MeshId,
IpAdress: params.IpAddress,
IfName: params.IfName,
Port: params.WgPort,
}
@ -171,6 +168,68 @@ func putDescription(client *ipcRpc.Client, description string) {
fmt.Println(reply)
}
// putAlias: puts an alias for the node
func putAlias(client *ipcRpc.Client, alias string) {
var reply string
err := client.Call("IpcHandler.PutAlias", &alias, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func setService(client *ipcRpc.Client, service, value string) {
var reply string
serviceArgs := &ipc.PutServiceArgs{
Service: service,
Value: value,
}
err := client.Call("IpcHandler.PutService", serviceArgs, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func deleteService(client *ipcRpc.Client, service string) {
var reply string
err := client.Call("IpcHandler.PutService", &service, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func getNode(client *ipcRpc.Client, nodeId, meshId string) {
var reply string
args := &ipc.GetNodeArgs{
NodeId: nodeId,
MeshId: meshId,
}
err := client.Call("IpcHandler.GetNode", &args, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func main() {
parser := argparse.NewParser("wg-mesh",
"wg-mesh Manipulate WireGuard meshes")
@ -184,19 +243,19 @@ func main() {
leaveMeshCmd := parser.NewCommand("leave-mesh", "Leave a mesh network")
queryMeshCmd := parser.NewCommand("query-mesh", "Query a mesh network using JMESPath")
putDescriptionCmd := parser.NewCommand("put-description", "Place a description for the node")
putAliasCmd := parser.NewCommand("put-alias", "Place an alias for the node")
setServiceCmd := parser.NewCommand("set-service", "Place a service into your advertisements")
deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements")
getNodeCmd := parser.NewCommand("get-node", "Get a specific node from the mesh")
var newMeshIfName *string = newMeshCmd.String("f", "ifname", &argparse.Options{Required: true})
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{})
var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{})
var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true})
var joinMeshIfName *string = joinMeshCmd.String("f", "ifname", &argparse.Options{Required: true})
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{})
var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{})
// var getMeshId *string = getMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true})
var getGraphMeshId *string = getGraphCmd.String("m", "mesh", &argparse.Options{Required: true})
@ -208,6 +267,16 @@ func main() {
var description *string = putDescriptionCmd.String("d", "description", &argparse.Options{Required: true})
var alias *string = putAliasCmd.String("a", "alias", &argparse.Options{Required: true})
var serviceKey *string = setServiceCmd.String("s", "service", &argparse.Options{Required: true})
var serviceValue *string = setServiceCmd.String("v", "value", &argparse.Options{Required: true})
var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{Required: true})
var getNodeNodeId *string = getNodeCmd.String("n", "nodeid", &argparse.Options{Required: true})
var getNodeMeshId *string = getNodeCmd.String("m", "meshid", &argparse.Options{Required: true})
err := parser.Parse(os.Args)
if err != nil {
@ -224,7 +293,6 @@ func main() {
if newMeshCmd.Happened() {
fmt.Println(createMesh(&CreateMeshParams{
Client: client,
IfName: *newMeshIfName,
WgPort: *newMeshPort,
Endpoint: *newMeshEndpoint,
}))
@ -237,7 +305,6 @@ func main() {
if joinMeshCmd.Happened() {
fmt.Println(joinMesh(&JoinMeshParams{
Client: client,
IfName: *joinMeshIfName,
WgPort: *joinMeshPort,
IpAddress: *joinMeshIpAddress,
MeshId: *joinMeshId,
@ -245,10 +312,6 @@ func main() {
}))
}
// if getMeshCmd.Happened() {
// getMesh(client, *getMeshId)
// }
if getGraphCmd.Happened() {
getGraph(client, *getGraphMeshId)
}
@ -268,4 +331,20 @@ func main() {
if putDescriptionCmd.Happened() {
putDescription(client, *description)
}
if putAliasCmd.Happened() {
putAlias(client, *alias)
}
if setServiceCmd.Happened() {
setService(client, *serviceKey, *serviceValue)
}
if deleteServiceCmd.Happened() {
deleteService(client, *deleteServiceKey)
}
if getNodeCmd.Happened() {
getNode(client, *getNodeNodeId, *getNodeMeshId)
}
}

View File

@ -1,7 +1,8 @@
package main
import (
"log"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
@ -35,6 +36,12 @@ func main() {
return
}
if conf.Profile {
go func() {
http.ListenAndServe("localhost:6060", nil)
}()
}
var robinRpc robin.WgRpc
var robinIpc robin.IpcHandler
var syncProvider sync.SyncServiceImpl
@ -45,8 +52,8 @@ func main() {
SyncProvider: &syncProvider,
Client: client,
}
ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams)
ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams)
syncProvider.Server = ctrlServer
syncRequester := sync.NewSyncRequester(ctrlServer)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester)
@ -65,7 +72,7 @@ func main() {
return
}
log.Println("Running IPC Handler")
logging.Log.WriteInfof("Running IPC Handler")
go ipc.RunIpcHandler(&robinIpc)
go syncScheduler.Run()

7
go.mod
View File

@ -5,9 +5,12 @@ go 1.21.3
require (
github.com/akamensky/argparse v1.4.0
github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9
github.com/gin-gonic/gin v1.9.1
github.com/google/uuid v1.3.0
github.com/jmespath/go-jmespath v0.4.0
github.com/jsimonetti/rtnetlink v1.3.5
github.com/sirupsen/logrus v1.9.3
golang.org/x/sys v0.14.0
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
google.golang.org/grpc v1.58.1
google.golang.org/protobuf v1.31.0
@ -19,7 +22,6 @@ require (
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/gin-gonic/gin v1.9.1 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
@ -27,7 +29,6 @@ require (
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/jsimonetti/rtnetlink v1.3.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
@ -40,12 +41,10 @@ require (
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.13.0 // indirect
golang.org/x/net v0.15.0 // indirect
golang.org/x/sync v0.3.0 // indirect
golang.org/x/sys v0.12.0 // indirect
golang.org/x/text v0.13.0 // indirect
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect

View File

@ -37,6 +37,8 @@ func meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNode {
Description: meshNode.Description,
Routes: meshNode.Routes,
PublicKey: meshNode.PublicKey,
Alias: meshNode.Alias,
Services: meshNode.Services,
}
}
@ -60,11 +62,11 @@ func (s *SmegServer) CreateMesh(c *gin.Context) {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
ipcRequest := ipc.NewMeshArgs{
IfName: createMesh.IfName,
WgPort: createMesh.WgPort,
}
@ -98,7 +100,6 @@ func (s *SmegServer) JoinMesh(c *gin.Context) {
ipcRequest := ipc.JoinMeshArgs{
MeshId: joinMesh.MeshId,
IpAdress: joinMesh.Bootstrap,
IfName: joinMesh.IfName,
Port: joinMesh.WgPort,
}

View File

@ -1,13 +1,15 @@
package api
type SmegNode struct {
WgHost string `json:"wgHost"`
WgEndpoint string `json:"wgEndpoint"`
Endpoint string `json:"endpoint"`
Timestamp int `json:"timestamp"`
Description string `json:"description"`
PublicKey string `json:"publicKey"`
Routes []string `json:"routes"`
Alias string `json:"alias"`
WgHost string `json:"wgHost"`
WgEndpoint string `json:"wgEndpoint"`
Endpoint string `json:"endpoint"`
Timestamp int `json:"timestamp"`
Description string `json:"description"`
PublicKey string `json:"publicKey"`
Routes []string `json:"routes"`
Services map[string]string `json:"services"`
}
type SmegMesh struct {
@ -16,13 +18,11 @@ type SmegMesh struct {
}
type CreateMeshRequest struct {
IfName string `json:"ifName" binding:"required"`
WgPort int `json:"port" binding:"required,gte=1024,lt=65535"`
WgPort int `json:"port" binding:"gte=1024,lt=65535"`
}
type JoinMeshRequest struct {
IfName string `json:"ifName" binding:"required"`
WgPort int `json:"port" binding:"required,gte=1024,lt=65535"`
WgPort int `json:"port" binding:"gte=1024,lt=65535"`
Bootstrap string `json:"bootstrap" binding:"required"`
MeshId string `json:"meshid" binding:"required"`
}

View File

@ -18,13 +18,14 @@ import (
// CrdtMeshManager manages nodes in the crdt mesh
type CrdtMeshManager struct {
MeshId string
IfName string
NodeId string
Client *wgctrl.Client
doc *automerge.Doc
LastHash automerge.ChangeHash
conf *conf.WgMeshConfiguration
MeshId string
IfName string
Client *wgctrl.Client
doc *automerge.Doc
LastHash automerge.ChangeHash
conf *conf.WgMeshConfiguration
cache *MeshCrdt
lastCacheHash automerge.ChangeHash
}
func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
@ -35,14 +36,37 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
}
crdt.Routes = make(map[string]interface{})
crdt.Services = make(map[string]string)
crdt.Timestamp = time.Now().Unix()
c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt)
}
func (c *CrdtMeshManager) GetNodeIds() []string {
keys, _ := c.doc.Path("nodes").Map().Keys()
return keys
}
// GetMesh(): Converts the document into a struct
func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
return automerge.As[*MeshCrdt](c.doc.Root())
changes, err := c.doc.Changes(c.lastCacheHash)
if err != nil {
return nil, err
}
if c.cache == nil || len(changes) > 3 {
c.lastCacheHash = c.LastHash
cache, err := automerge.As[*MeshCrdt](c.doc.Root())
if err != nil {
return nil, err
}
c.cache = cache
}
return c.cache, nil
}
// GetMeshId returns the meshid of the mesh
@ -82,13 +106,23 @@ func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, erro
manager.IfName = params.DevName
manager.Client = params.Client
manager.conf = &params.Conf
manager.cache = nil
return &manager, nil
}
// GetNode: returns a mesh node crdt.Close releases resources used by a Client.
func (m *CrdtMeshManager) GetNode(endpoint string) (*MeshNodeCrdt, error) {
// NodeExists: returns true if the node exists. Returns false
func (m *CrdtMeshManager) NodeExists(key string) bool {
node, err := m.doc.Path("nodes").Map().Get(key)
return node.Kind() == automerge.KindMap && err != nil
}
func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
node, err := m.doc.Path("nodes").Map().Get(endpoint)
if node.Kind() != automerge.KindMap {
return nil, fmt.Errorf("GetNode: something went wrong %s is not a map type")
}
if err != nil {
return nil, err
}
@ -178,6 +212,72 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro
return err
}
func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
return err
}
if node.Kind() != automerge.KindMap {
return fmt.Errorf("%s does not exist", nodeId)
}
err = node.Map().Set("alias", alias)
if err == nil {
logging.Log.WriteInfof("Updated Alias for %s to %s", nodeId, alias)
}
return err
}
func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return fmt.Errorf("AddService: node %s does not exist", nodeId)
}
service, err := node.Map().Get("services")
if err != nil {
return err
}
if service.Kind() != automerge.KindMap {
return fmt.Errorf("AddService: services property does not exist in node")
}
return service.Map().Set(key, value)
}
func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return fmt.Errorf("RemoveService: node %s does not exist", nodeId)
}
service, err := node.Map().Get("services")
if err != nil {
return err
}
if service.Kind() != automerge.KindMap {
return fmt.Errorf("services property does not exist")
}
err = service.Map().Delete(key)
if err != nil {
return fmt.Errorf("service %s does not exist", key)
}
return nil
}
// AddRoutes: adds routes to the specific nodeId
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -336,6 +436,20 @@ func (m *MeshNodeCrdt) GetIdentifier() string {
return strings.Join(constituents, ":")
}
func (m *MeshNodeCrdt) GetAlias() string {
return m.Alias
}
func (m *MeshNodeCrdt) GetServices() map[string]string {
services := make(map[string]string)
for key, service := range m.Services {
services[key] = service
}
return services
}
func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
nodes := make(map[string]mesh.MeshNode)
@ -348,6 +462,8 @@ func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
Timestamp: node.Timestamp,
Routes: node.Routes,
Description: node.Description,
Alias: node.Alias,
Services: node.GetServices(),
}
}

View File

@ -35,7 +35,9 @@ func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNod
WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()),
// Always set the routes as empty.
// Routes handled by external component
Routes: map[string]interface{}{},
Routes: map[string]interface{}{},
Description: "",
Alias: "",
}
}

View File

@ -8,7 +8,9 @@ type MeshNodeCrdt struct {
WgHost string `automerge:"wgHost"`
Timestamp int64 `automerge:"timestamp"`
Routes map[string]interface{} `automerge:"routes"`
Alias string `automerge:"alias"`
Description string `automerge:"description"`
Services map[string]string `automerge:"services"`
}
// MeshCrdt: Represents the mesh network as a whole

View File

@ -49,6 +49,10 @@ type WgMeshConfiguration struct {
Timeout int `yaml:"timeout"`
// PruneTime number of seconds before we consider the 'node' as dead
PruneTime int `yaml:"pruneTime"`
// Profile whether or not to include a http server that profiles the code
Profile bool `yaml:"profile"`
// StubWg whether or not to stub the WireGuard types
StubWg bool `yaml:"stubWg"`
}
func ValidateConfiguration(c *WgMeshConfiguration) error {

View File

@ -18,6 +18,8 @@ type MeshNode struct {
Timestamp int64
Routes []string
Description string
Alias string
Services map[string]string
}
// Represents a WireGuard Mesh

132
pkg/hosts/hosts.go Normal file
View File

@ -0,0 +1,132 @@
// hosts: utility for modifying the /etc/hosts file
package hosts
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"os"
"strings"
)
// HOSTS_FILE is the hosts file location
const HOSTS_FILE = "/etc/hosts"
const DOMAIN_HEADER = "#WG AUTO GENERATED HOSTS"
const DOMAIN_TRAILER = "#WG AUTO GENERATED HOSTS END"
type HostsEntry struct {
Alias string
Ip net.IP
}
// Generic interface to manipulate /etc/hosts file
type HostsManipulator interface {
// AddrAddr associates an aliasd with a given IP address
AddAddr(hosts ...HostsEntry)
// Remove deletes the entry from /etc/hosts
Remove(hosts ...HostsEntry)
// Writes the changes to /etc/hosts file
Write() error
}
type HostsManipulatorImpl struct {
hosts map[string]HostsEntry
}
// AddAddr implements HostsManipulator.
func (m *HostsManipulatorImpl) AddAddr(hosts ...HostsEntry) {
changed := false
for _, host := range hosts {
prev, ok := m.hosts[host.Ip.String()]
if !ok || prev.Alias != host.Alias {
changed = true
}
m.hosts[host.Ip.String()] = host
}
if changed {
m.Write()
}
}
// Remove implements HostsManipulator.
func (m *HostsManipulatorImpl) Remove(hosts ...HostsEntry) {
lenBefore := len(m.hosts)
for _, host := range hosts {
delete(m.hosts, host.Alias)
}
if lenBefore != len(m.hosts) {
m.Write()
}
}
func (m *HostsManipulatorImpl) removeHosts() string {
hostsFile, err := os.ReadFile(HOSTS_FILE)
if err != nil {
return ""
}
var contents strings.Builder
scanner := bufio.NewScanner(bytes.NewReader(hostsFile))
hostsSection := false
for scanner.Scan() {
line := scanner.Text()
if err == io.EOF {
break
} else if err != nil {
return ""
}
if !hostsSection && strings.Contains(line, DOMAIN_HEADER) {
hostsSection = true
}
if !hostsSection {
contents.WriteString(line + "\n")
}
if hostsSection && strings.Contains(line, DOMAIN_TRAILER) {
hostsSection = false
}
}
if scanner.Err() != nil && scanner.Err() != io.EOF {
return ""
}
return contents.String()
}
// Write implements HostsManipulator
func (m *HostsManipulatorImpl) Write() error {
contents := m.removeHosts()
var nextHosts strings.Builder
nextHosts.WriteString(contents)
nextHosts.WriteString(DOMAIN_HEADER + "\n")
for _, host := range m.hosts {
nextHosts.WriteString(fmt.Sprintf("%s\t%s\n", host.Ip.String(), host.Alias))
}
nextHosts.WriteString(DOMAIN_TRAILER + "\n")
return os.WriteFile(HOSTS_FILE, []byte(nextHosts.String()), 0644)
}
func NewHostsManipulator() HostsManipulator {
return &HostsManipulatorImpl{hosts: make(map[string]HostsEntry)}
}

View File

@ -11,8 +11,6 @@ import (
)
type NewMeshArgs struct {
// IfName is the interface that the mesh instance will run on
IfName string
// WgPort is the WireGuard port to expose
WgPort int
// Endpoint is the routable alias of the machine. Can be an IP
@ -25,8 +23,6 @@ type JoinMeshArgs struct {
MeshId string
// IpAddress is a routable IP in another mesh
IpAdress string
// IfName is the interface name of the mesh
IfName string
// Port is the WireGuard port to expose
Port int
// Endpoint is the routable address of this machine. If not provided
@ -34,6 +30,11 @@ type JoinMeshArgs struct {
Endpoint string
}
type PutServiceArgs struct {
Service string
Value string
}
type GetMeshReply struct {
Nodes []ctrlserver.MeshNode
}
@ -47,6 +48,11 @@ type QueryMesh struct {
Query string
}
type GetNodeArgs struct {
NodeId string
MeshId string
}
type MeshIpc interface {
CreateMesh(args *NewMeshArgs, reply *string) error
ListMeshes(name string, reply *ListMeshReply) error
@ -57,6 +63,10 @@ type MeshIpc interface {
GetDOT(meshId string, reply *string) error
Query(query QueryMesh, reply *string) error
PutDescription(description string, reply *string) error
PutAlias(alias string, reply *string) error
PutService(args PutServiceArgs, reply *string) error
GetNode(args GetNodeArgs, reply *string) error
DeleteService(service string, reply *string) error
}
const SockAddr = "/tmp/wgmesh_ipc.sock"

46
pkg/mesh/alias.go Normal file
View File

@ -0,0 +1,46 @@
package mesh
import (
"fmt"
"github.com/tim-beatham/wgmesh/pkg/hosts"
)
type MeshAliasManager interface {
AddAliases(nodes []MeshNode)
RemoveAliases(node []MeshNode)
}
type AliasManager struct {
hosts hosts.HostsManipulator
}
// AddAliases: on node update or change add aliases to the hosts file
func (a *AliasManager) AddAliases(nodes []MeshNode) {
for _, node := range nodes {
if node.GetAlias() != "" {
a.hosts.AddAddr(hosts.HostsEntry{
Alias: fmt.Sprintf("%s.smeg", node.GetAlias()),
Ip: node.GetWgHost().IP,
})
}
}
}
// RemoveAliases: on node remove remove aliases from the hosts file
func (a *AliasManager) RemoveAliases(nodes []MeshNode) {
for _, node := range nodes {
if node.GetAlias() != "" {
a.hosts.Remove(hosts.HostsEntry{
Alias: fmt.Sprintf("%s.smeg", node.GetAlias()),
Ip: node.GetWgHost().IP,
})
}
}
}
func NewAliasManager() MeshAliasManager {
return &AliasManager{
hosts: hosts.NewHostsManipulator(),
}
}

View File

@ -14,7 +14,7 @@ import (
)
type MeshManager interface {
CreateMesh(devName string, port int) (string, error)
CreateMesh(port int) (string, error)
AddMesh(params *AddMeshParams) error
HasChanges(meshid string) bool
GetMesh(meshId string) MeshProvider
@ -25,11 +25,16 @@ type MeshManager interface {
GetSelf(meshId string) (MeshNode, error)
ApplyConfig() error
SetDescription(description string) error
SetAlias(alias string) error
SetService(service string, value string) error
RemoveService(service string) error
UpdateTimeStamp() error
GetClient() *wgctrl.Client
GetMeshes() map[string]MeshProvider
Prune() error
Close() error
GetMonitor() MeshMonitor
GetNode(string, string) MeshNode
}
type MeshManagerImpl struct {
@ -46,6 +51,54 @@ type MeshManagerImpl struct {
idGenerator lib.IdGenerator
ipAllocator ip.IPAllocator
interfaceManipulator wg.WgInterfaceManipulator
Monitor MeshMonitor
}
// RemoveService implements MeshManager.
func (m *MeshManagerImpl) RemoveService(service string) error {
for _, mesh := range m.Meshes {
err := mesh.RemoveService(m.HostParameters.HostEndpoint, service)
if err != nil {
return err
}
}
return nil
}
// SetService implements MeshManager.
func (m *MeshManagerImpl) SetService(service string, value string) error {
for _, mesh := range m.Meshes {
err := mesh.AddService(m.HostParameters.HostEndpoint, service, value)
if err != nil {
return err
}
}
return nil
}
func (m *MeshManagerImpl) GetNode(meshid, nodeId string) MeshNode {
mesh, ok := m.Meshes[meshid]
if !ok {
return nil
}
node, err := mesh.GetNode(nodeId)
if err != nil {
return nil
}
return node
}
// GetMonitor implements MeshManager.
func (m *MeshManagerImpl) GetMonitor() MeshMonitor {
return m.Monitor
}
// Prune implements MeshManager.
@ -62,15 +115,25 @@ func (m *MeshManagerImpl) Prune() error {
}
// CreateMesh: Creates a new mesh, stores it and returns the mesh id
func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) {
func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
meshId, err := m.idGenerator.GetId()
var ifName string = ""
if err != nil {
return "", err
}
if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(port)
if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err)
}
}
nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: devName,
DevName: ifName,
Port: port,
Conf: m.conf,
Client: m.Client,
@ -81,30 +144,31 @@ func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) {
return "", fmt.Errorf("error creating mesh: %w", err)
}
err = m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: devName,
Port: port,
})
if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err)
}
m.Meshes[meshId] = nodeManager
return meshId, nil
}
type AddMeshParams struct {
MeshId string
DevName string
WgPort int
MeshBytes []byte
}
// AddMesh: Add the mesh to the list of meshes
func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
var ifName string
var err error
if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(params.WgPort)
if err != nil {
return err
}
}
meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: params.DevName,
DevName: ifName,
Port: params.WgPort,
Conf: m.conf,
Client: m.Client,
@ -122,11 +186,7 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
}
m.Meshes[params.MeshId] = meshProvider
return m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: params.DevName,
Port: params.WgPort,
})
return nil
}
// HasChanges returns true if the mesh has changes
@ -159,6 +219,11 @@ func (s *MeshManagerImpl) EnableInterface(meshId string) error {
// GetPublicKey: Gets the public key of the WireGuard mesh
func (s *MeshManagerImpl) GetPublicKey(meshId string) (*wgtypes.Key, error) {
if s.conf.StubWg {
zeroedKey := make([]byte, wgtypes.KeyLen)
return (*wgtypes.Key)(zeroedKey), nil
}
mesh, ok := s.Meshes[meshId]
if !ok {
@ -180,7 +245,6 @@ type AddSelfParams struct {
// WgPort is the WireGuard port to advertise
WgPort int
// Endpoint is the alias of the machine to send routable packets
// to
Endpoint string
}
@ -211,16 +275,18 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
Endpoint: params.Endpoint,
})
device, err := mesh.GetDevice()
if !s.conf.StubWg {
device, err := mesh.GetDevice()
if err != nil {
return fmt.Errorf("failed to get device %w", err)
}
if err != nil {
return fmt.Errorf("failed to get device %w", err)
}
err = s.interfaceManipulator.AddAddress(device.Name, fmt.Sprintf("%s/64", nodeIP))
err = s.interfaceManipulator.AddAddress(device.Name, fmt.Sprintf("%s/64", nodeIP))
if err != nil {
return fmt.Errorf("addSelf: failed to add address to dev %w", err)
if err != nil {
return fmt.Errorf("addSelf: failed to add address to dev %w", err)
}
}
s.Meshes[params.MeshId].AddNode(node)
@ -241,13 +307,16 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
return err
}
device, err := mesh.GetDevice()
if !s.conf.StubWg {
device, e := mesh.GetDevice()
if err != nil {
return err
if e != nil {
return err
}
err = s.interfaceManipulator.RemoveInterface(device.Name)
}
err = s.interfaceManipulator.RemoveInterface(device.Name)
delete(s.Meshes, meshId)
return err
}
@ -259,15 +328,9 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
return nil, fmt.Errorf("mesh %s does not exist", meshId)
}
snapshot, err := meshInstance.GetMesh()
node, err := meshInstance.GetNode(s.HostParameters.HostEndpoint)
if err != nil {
return nil, err
}
node, ok := snapshot.GetNodes()[s.HostParameters.HostEndpoint]
if !ok {
return nil, errors.New("the node doesn't exist in the mesh")
}
@ -281,34 +344,42 @@ func (s *MeshManagerImpl) ApplyConfig() error {
return err
}
return s.RouteManager.InstallRoutes()
return nil
}
func (s *MeshManagerImpl) SetDescription(description string) error {
for _, mesh := range s.Meshes {
err := mesh.SetDescription(s.HostParameters.HostEndpoint, description)
if mesh.NodeExists(s.HostParameters.HostEndpoint) {
err := mesh.SetDescription(s.HostParameters.HostEndpoint, description)
if err != nil {
return err
if err != nil {
return err
}
}
}
return nil
}
// SetAlias implements MeshManager.
func (s *MeshManagerImpl) SetAlias(alias string) error {
for _, mesh := range s.Meshes {
if mesh.NodeExists(s.HostParameters.HostEndpoint) {
err := mesh.SetAlias(s.HostParameters.HostEndpoint, alias)
if err != nil {
return err
}
}
}
return nil
}
// UpdateTimeStamp updates the timestamp of this node in all meshes
func (s *MeshManagerImpl) UpdateTimeStamp() error {
for _, mesh := range s.Meshes {
snapshot, err := mesh.GetMesh()
if err != nil {
return err
}
_, exists := snapshot.GetNodes()[s.HostParameters.HostEndpoint]
if exists {
err = mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint)
if mesh.NodeExists(s.HostParameters.HostEndpoint) {
err := mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint)
if err != nil {
return err
@ -327,7 +398,12 @@ func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider {
return s.Meshes
}
// Close the mesh manager
func (s *MeshManagerImpl) Close() error {
if s.conf.StubWg {
return nil
}
for _, mesh := range s.Meshes {
dev, err := mesh.GetDevice()
@ -359,7 +435,7 @@ type NewMeshManagerParams struct {
}
// Creates a new instance of a mesh manager with the given parameters
func NewMeshManager(params *NewMeshManagerParams) *MeshManagerImpl {
func NewMeshManager(params *NewMeshManagerParams) MeshManager {
hostParams := HostParameters{}
switch params.Conf.Endpoint {
@ -390,5 +466,11 @@ func NewMeshManager(params *NewMeshManagerParams) *MeshManagerImpl {
m.idGenerator = params.IdGenerator
m.ipAllocator = params.IPAllocator
m.interfaceManipulator = params.InterfaceManipulator
m.Monitor = NewMeshMonitor(m)
aliasManager := NewAliasManager()
m.Monitor.AddUpdateCallback(aliasManager.AddAliases)
m.Monitor.AddRemoveCallback(aliasManager.RemoveAliases)
return m
}

View File

@ -22,7 +22,7 @@ func getMeshConfiguration() *conf.WgMeshConfiguration {
}
}
func getMeshManager() *MeshManagerImpl {
func getMeshManager() MeshManager {
manager := NewMeshManager(&NewMeshManagerParams{
Conf: *getMeshConfiguration(),
Client: nil,
@ -51,7 +51,7 @@ func TestCreateMeshCreatesANewMeshProvider(t *testing.T) {
t.Fatal(`meshId should not be empty`)
}
_, exists := manager.Meshes[meshId]
_, exists := manager.GetMeshes()[meshId]
if !exists {
t.Fatal(`mesh was not created when it should be`)
@ -190,7 +190,7 @@ func TestLeaveMeshDeletesMesh(t *testing.T) {
t.Fatalf("%s", err.Error())
}
_, exists := manager.Meshes[meshId]
_, exists := manager.GetMeshes()[meshId]
if exists {
t.Fatalf(`expected mesh to have been deleted`)

81
pkg/mesh/monitor.go Normal file
View File

@ -0,0 +1,81 @@
package mesh
type OnChange = func([]MeshNode)
type MeshMonitor interface {
AddUpdateCallback(cb OnChange)
AddRemoveCallback(cb OnChange)
Trigger() error
}
type MeshMonitorImpl struct {
updateCbs []OnChange
removeCbs []OnChange
nodes map[string]MeshNode
manager MeshManager
}
// Trigger causes the mesh monitor to trigger all of
// the callbacks.
func (m *MeshMonitorImpl) Trigger() error {
changedNodes := make([]MeshNode, 0)
removedNodes := make([]MeshNode, 0)
nodes := make(map[string]MeshNode)
for _, mesh := range m.manager.GetMeshes() {
snapshot, err := mesh.GetMesh()
if err != nil {
return err
}
for _, node := range snapshot.GetNodes() {
previous, exists := m.nodes[node.GetWgHost().String()]
if !exists || !NodeEquals(previous, node) {
changedNodes = append(changedNodes, node)
}
nodes[node.GetWgHost().String()] = node
}
}
for _, previous := range m.nodes {
_, ok := nodes[previous.GetWgHost().String()]
if !ok {
removedNodes = append(removedNodes, previous)
}
}
if len(removedNodes) > 0 {
for _, cb := range m.removeCbs {
cb(removedNodes)
}
}
if len(changedNodes) > 0 {
for _, cb := range m.updateCbs {
cb(changedNodes)
}
}
return nil
}
func (m *MeshMonitorImpl) AddUpdateCallback(cb OnChange) {
m.updateCbs = append(m.updateCbs, cb)
}
func (m *MeshMonitorImpl) AddRemoveCallback(cb OnChange) {
m.removeCbs = append(m.removeCbs, cb)
}
func NewMeshMonitor(manager MeshManager) MeshMonitor {
return &MeshMonitorImpl{
updateCbs: make([]OnChange, 0),
nodes: make(map[string]MeshNode),
manager: manager,
}
}

View File

@ -21,6 +21,16 @@ type MeshNodeStub struct {
description string
}
// GetServices implements MeshNode.
func (*MeshNodeStub) GetServices() map[string]string {
return make(map[string]string)
}
// GetAlias implements MeshNode.
func (*MeshNodeStub) GetAlias() string {
return ""
}
func (m *MeshNodeStub) GetHostEndpoint() string {
return m.hostEndpoint
}
@ -66,6 +76,36 @@ type MeshProviderStub struct {
snapshot *MeshSnapshotStub
}
// GetNodeIds implements MeshProvider.
func (*MeshProviderStub) GetNodeIds() []string {
panic("unimplemented")
}
// GetNode implements MeshProvider.
func (*MeshProviderStub) GetNode(string) (MeshNode, error) {
panic("unimplemented")
}
// NodeExists implements MeshProvider.
func (*MeshProviderStub) NodeExists(string) bool {
panic("unimplemented")
}
// AddService implements MeshProvider.
func (*MeshProviderStub) AddService(nodeId string, key string, value string) error {
panic("unimplemented")
}
// RemoveService implements MeshProvider.
func (*MeshProviderStub) RemoveService(nodeId string, key string) error {
panic("unimplemented")
}
// SetAlias implements MeshProvider.
func (*MeshProviderStub) SetAlias(nodeId string, alias string) error {
panic("unimplemented")
}
// RemoveRoutes implements MeshProvider.
func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error {
panic("unimplemented")
@ -171,6 +211,31 @@ type MeshManagerStub struct {
meshes map[string]MeshProvider
}
// GetNode implements MeshManager.
func (*MeshManagerStub) GetNode(string, string) MeshNode {
panic("unimplemented")
}
// RemoveService implements MeshManager.
func (*MeshManagerStub) RemoveService(service string) error {
panic("unimplemented")
}
// SetService implements MeshManager.
func (*MeshManagerStub) SetService(service string, value string) error {
panic("unimplemented")
}
// GetMonitor implements MeshManager.
func (*MeshManagerStub) GetMonitor() MeshMonitor {
panic("unimplemented")
}
// SetAlias implements MeshManager.
func (*MeshManagerStub) SetAlias(alias string) error {
panic("unimplemented")
}
// Close implements MeshManager.
func (*MeshManagerStub) Close() error {
panic("unimplemented")
@ -185,7 +250,7 @@ func NewMeshManagerStub() MeshManager {
return &MeshManagerStub{meshes: make(map[string]MeshProvider)}
}
func (m *MeshManagerStub) CreateMesh(devName string, port int) (string, error) {
func (m *MeshManagerStub) CreateMesh(port int) (string, error) {
return "tim123", nil
}

View File

@ -4,6 +4,7 @@ package mesh
import (
"net"
"slices"
"github.com/tim-beatham/wgmesh/pkg/conf"
"golang.zx2c4.com/wireguard/wgctrl"
@ -28,6 +29,51 @@ type MeshNode interface {
GetIdentifier() string
// GetDescription: returns the description for this node
GetDescription() string
// GetAlias: associates the node with an alias. Potentially used
// for DNS and so forth.
GetAlias() string
// GetServices: returns a list of services offered by the node
GetServices() map[string]string
}
// NodeEquals: determines if two mesh nodes are equivalent to one another
func NodeEquals(node1, node2 MeshNode) bool {
if node1.GetHostEndpoint() != node2.GetHostEndpoint() {
return false
}
node1Pub, _ := node1.GetPublicKey()
node2Pub, _ := node2.GetPublicKey()
if node1Pub != node2Pub {
return false
}
if node1.GetWgEndpoint() != node2.GetWgEndpoint() {
return false
}
if node1.GetWgHost() != node2.GetWgHost() {
return false
}
if !slices.Equal(node1.GetRoutes(), node2.GetRoutes()) {
return false
}
if node1.GetIdentifier() != node2.GetIdentifier() {
return false
}
if node1.GetDescription() != node2.GetDescription() {
return false
}
if node1.GetAlias() != node2.GetAlias() {
return false
}
return true
}
type MeshSnapshot interface {
@ -46,7 +92,7 @@ type MeshSyncer interface {
type MeshProvider interface {
// AddNode() adds a node to the mesh
AddNode(node MeshNode)
// GetMesh() returns a snapshot of the mesh provided by the mesh provider
// GetMesh() returns a snapshot of the mesh provided by the mesh provider.
GetMesh() (MeshSnapshot, error)
// GetMeshId() returns the ID of the mesh network
GetMeshId() string
@ -68,11 +114,22 @@ type MeshProvider interface {
RemoveRoutes(nodeId string, route ...string) error
// GetSyncer: returns the automerge syncer for sync
GetSyncer() MeshSyncer
// GetNode get a particular not within the mesh
GetNode(string) (MeshNode, error)
// NodeExists: returns true if a particular node exists false otherwise
NodeExists(string) bool
// SetDescription: sets the description of this automerge data type
SetDescription(nodeId string, description string) error
// SetAlias: set the alias of the nodeId
SetAlias(nodeId string, alias string) error
// AddService: adds the service to the given node
AddService(nodeId, key, value string) error
// RemoveService: removes the service form the node. throws an error if the service does not exist
RemoveService(nodeId, key string) error
// Prune: prunes all nodes that have not updated their timestamp in
// pruneAmount seconds
Prune(pruneAmount int) error
GetNodeIds() []string
}
// HostParameters contains the IDs of a node

View File

@ -24,13 +24,15 @@ type QueryError struct {
}
type QueryNode struct {
HostEndpoint string `json:"hostEndpoint"`
PublicKey string `json:"publicKey"`
WgEndpoint string `json:"wgEndpoint"`
WgHost string `json:"wgHost"`
Timestamp int64 `json:"timestmap"`
Description string `json:"description"`
Routes []string `json:"routes"`
HostEndpoint string `json:"hostEndpoint"`
PublicKey string `json:"publicKey"`
WgEndpoint string `json:"wgEndpoint"`
WgHost string `json:"wgHost"`
Timestamp int64 `json:"timestmap"`
Description string `json:"description"`
Routes []string `json:"routes"`
Alias string `json:"alias"`
Services map[string]string `json:"services"`
}
func (m *QueryError) Error() string {
@ -51,7 +53,7 @@ func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
return nil, err
}
nodes := lib.Map(lib.MapValues(snapshot.GetNodes()), meshNodeToQueryNode)
nodes := lib.Map(lib.MapValues(snapshot.GetNodes()), MeshNodeToQueryNode)
result, err := jmespath.Search(queryParams, nodes)
@ -63,7 +65,7 @@ func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
return bytes, err
}
func meshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
func MeshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode := new(QueryNode)
queryNode.HostEndpoint = node.GetHostEndpoint()
pubKey, _ := node.GetPublicKey()
@ -76,6 +78,9 @@ func meshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode.Timestamp = node.GetTimeStamp()
queryNode.Routes = node.GetRoutes()
queryNode.Description = node.GetDescription()
queryNode.Alias = node.GetAlias()
queryNode.Services = node.GetServices()
return queryNode
}

View File

@ -2,6 +2,7 @@ package robin
import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
@ -10,6 +11,7 @@ import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query"
"github.com/tim-beatham/wgmesh/pkg/rpc"
)
@ -18,7 +20,7 @@ type IpcHandler struct {
}
func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
meshId, err := n.Server.GetMeshManager().CreateMesh(args.IfName, args.WgPort)
meshId, err := n.Server.GetMeshManager().CreateMesh(args.WgPort)
if err != nil {
return err
@ -81,7 +83,6 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{
MeshId: args.MeshId,
DevName: args.IfName,
WgPort: args.Port,
MeshBytes: meshReply.Mesh,
})
@ -117,6 +118,11 @@ func (n *IpcHandler) LeaveMesh(meshId string, reply *string) error {
func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
mesh := n.Server.GetMeshManager().GetMesh(meshId)
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
meshSnapshot, err := mesh.GetMesh()
if err != nil {
@ -145,6 +151,8 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
Timestamp: node.GetTimeStamp(),
Routes: node.GetRoutes(),
Description: node.GetDescription(),
Alias: node.GetAlias(),
Services: node.GetServices(),
}
nodes[i] = node
@ -202,6 +210,60 @@ func (n *IpcHandler) PutDescription(description string, reply *string) error {
return nil
}
func (n *IpcHandler) PutAlias(alias string, reply *string) error {
err := n.Server.GetMeshManager().SetAlias(alias)
if err != nil {
return err
}
*reply = fmt.Sprintf("Set alias to %s", alias)
return nil
}
func (n *IpcHandler) PutService(service ipc.PutServiceArgs, reply *string) error {
err := n.Server.GetMeshManager().SetService(service.Service, service.Value)
if err != nil {
return err
}
*reply = "success"
return nil
}
func (n *IpcHandler) DeleteService(service string, reply *string) error {
err := n.Server.GetMeshManager().RemoveService(service)
if err != nil {
return err
}
*reply = "success"
return nil
}
func (n *IpcHandler) GetNode(args ipc.GetNodeArgs, reply *string) error {
node := n.Server.GetMeshManager().GetNode(args.MeshId, args.NodeId)
if node == nil {
*reply = "nil"
return nil
}
queryNode := query.MeshNodeToQueryNode(node)
bytes, err := json.Marshal(queryNode)
if err != nil {
*reply = err.Error()
return nil
}
*reply = string(bytes)
return nil
}
type RobinIpcParams struct {
CtrlServer ctrlserver.CtrlServer
}

View File

@ -1,7 +1,6 @@
package sync
import (
"errors"
"math/rand"
"sync"
"time"
@ -30,35 +29,22 @@ type SyncerImpl struct {
// Sync: Sync random nodes
func (s *SyncerImpl) Sync(meshId string) error {
logging.Log.WriteInfof("UPDATING WG CONF")
err := s.manager.ApplyConfig()
if err != nil {
logging.Log.WriteInfof("Failed to update config %w", err)
}
if !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
logging.Log.WriteInfof("No changes for %s", meshId)
return nil
}
theMesh := s.manager.GetMesh(meshId)
logging.Log.WriteInfof("UPDATING WG CONF")
if theMesh == nil {
return errors.New("the provided mesh does not exist")
if s.manager.HasChanges(meshId) {
err := s.manager.ApplyConfig()
if err != nil {
logging.Log.WriteInfof("Failed to update config %w", err)
}
}
snapshot, err := theMesh.GetMesh()
if err != nil {
return err
}
nodes := snapshot.GetNodes()
if len(nodes) <= 1 {
return nil
}
nodeNames := s.manager.GetMesh(meshId).GetNodeIds()
self, err := s.manager.GetSelf(meshId)
@ -66,17 +52,6 @@ func (s *SyncerImpl) Sync(meshId string) error {
return err
}
excludedNodes := map[string]struct{}{
self.GetHostEndpoint(): {},
}
meshNodes := lib.MapValuesWithExclude(nodes, excludedNodes)
getNames := func(node mesh.MeshNode) string {
return node.GetHostEndpoint()
}
nodeNames := lib.Map(meshNodes, getNames)
neighbours := s.cluster.GetNeighbours(nodeNames, self.GetHostEndpoint())
randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
@ -86,7 +61,7 @@ func (s *SyncerImpl) Sync(meshId string) error {
before := time.Now()
if len(meshNodes) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
logging.Log.WriteInfof("Sending to random cluster")
interCluster := s.cluster.GetInterCluster(nodeNames, self.GetHostEndpoint())
randomSubset = append(randomSubset, interCluster)
@ -107,16 +82,20 @@ func (s *SyncerImpl) Sync(meshId string) error {
waitGroup.Wait()
s.syncCount++
logging.Log.WriteInfof("SYNC TIME: %v", time.Now().Sub(before))
logging.Log.WriteInfof("SYNC TIME: %v", time.Since(before))
logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount)
s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount)
// Check if any changes have occurred and trigger callbacks
// if changes have occurred.
// return s.manager.GetMonitor().Trigger()
return nil
}
// SyncMeshes: Sync all meshes
func (s *SyncerImpl) SyncMeshes() error {
for meshId, _ := range s.manager.GetMeshes() {
for meshId := range s.manager.GetMeshes() {
err := s.Sync(meshId)
if err != nil {

View File

@ -50,7 +50,6 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, ifName string, port int, endP
err = s.server.MeshManager.AddMesh(&mesh.AddMeshParams{
MeshId: meshId,
DevName: ifName,
WgPort: port,
MeshBytes: reply.Mesh,
})

View File

@ -2,7 +2,7 @@ package wg
type WgInterfaceManipulatorStub struct{}
func (i *WgInterfaceManipulatorStub) CreateInterface(params *CreateInterfaceParams) error {
func (i *WgInterfaceManipulatorStub) CreateInterface(port int) error {
return nil
}

View File

@ -8,14 +8,9 @@ func (m *WgError) Error() string {
return m.msg
}
type CreateInterfaceParams struct {
IfName string
Port int
}
type WgInterfaceManipulator interface {
// CreateInterface creates a WireGuard interface
CreateInterface(params *CreateInterfaceParams) error
CreateInterface(port int) (string, error)
// AddAddress adds an address to the given interface name
AddAddress(ifName string, addr string) error
// RemoveInterface removes the specified interface

View File

@ -1,6 +1,9 @@
package wg
import (
"crypto"
"crypto/rand"
"encoding/base64"
"fmt"
"github.com/tim-beatham/wgmesh/pkg/lib"
@ -13,40 +16,54 @@ type WgInterfaceManipulatorImpl struct {
client *wgctrl.Client
}
const hashLength = 6
// CreateInterface creates a WireGuard interface
func (m *WgInterfaceManipulatorImpl) CreateInterface(params *CreateInterfaceParams) error {
func (m *WgInterfaceManipulatorImpl) CreateInterface(port int) (string, error) {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return fmt.Errorf("failed to access link: %w", err)
return "", fmt.Errorf("failed to access link: %w", err)
}
defer rtnl.Close()
err = rtnl.CreateLink(params.IfName)
randomBuf := make([]byte, 32)
_, err = rand.Read(randomBuf)
if err != nil {
return fmt.Errorf("failed to create link: %w", err)
return "", err
}
md5 := crypto.MD5.New().Sum(randomBuf)
md5Str := fmt.Sprintf("wg%s", base64.StdEncoding.EncodeToString(md5)[:hashLength])
err = rtnl.CreateLink(md5Str)
if err != nil {
return "", fmt.Errorf("failed to create link: %w", err)
}
privateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
return fmt.Errorf("failed to create private key: %w", err)
return "", fmt.Errorf("failed to create private key: %w", err)
}
var cfg wgtypes.Config = wgtypes.Config{
PrivateKey: &privateKey,
ListenPort: &params.Port,
ListenPort: &port,
}
err = m.client.ConfigureDevice(params.IfName, cfg)
err = m.client.ConfigureDevice(md5Str, cfg)
if err != nil {
return fmt.Errorf("failed to configure dev: %w", err)
m.RemoveInterface(md5Str)
return "", fmt.Errorf("failed to configure dev: %w", err)
}
logging.Log.WriteInfof("ip link set up dev %s type wireguard", params.IfName)
return nil
logging.Log.WriteInfof("ip link set up dev %s type wireguard", md5Str)
return md5Str, nil
}
// Add an address to the given interface

1
smegmesh-web Submodule

Submodule smegmesh-web added at c1128bcd98