Compare commits

...

10 Commits

Author SHA1 Message Date
388153e706 Stubbing out WireGuard components
Stubbing our WireGuard components so that I can use docker/podman
network_mode=host. This is much more efficient than the docker/podman
userspace network.
2023-11-20 11:28:12 +00:00
023565d985 Merge pull request #17 from tim-beatham/25-ability-to-aliases
25 ability to aliases
2023-11-17 22:20:57 +00:00
36c264b38e 25-ability-aliases
Fixed unit tests failing
2023-11-17 22:18:53 +00:00
68db795f47 Ability to specify aliases
Ability to specify aliases that automatically append to /etc/hosts
2023-11-17 22:13:51 +00:00
f6160fe138 Adding aliases that automatically gets added 2023-11-17 19:13:20 +00:00
2c5289afb0 Merge pull request #16 from tim-beatham/15-add-rest-api
Developed a rest API
2023-11-15 12:57:05 +00:00
7199d07a76 Added smegmesh submodule 2023-11-13 10:46:52 +00:00
5f176e731f Developed a rest API 2023-11-13 10:44:14 +00:00
44f119b45c Updating examples 2023-11-08 09:19:24 +00:00
5215d5d54d Merge pull request #14 from tim-beatham/13-netlink-api
Removed interface manipulation via os.Exec into
2023-11-07 19:53:39 +00:00
32 changed files with 1320 additions and 119 deletions

3
.gitmodules vendored Normal file
View File

@ -0,0 +1,3 @@
[submodule "smegmesh-web"]
path = smegmesh-web
url = git@github.com:tim-beatham/smegmesh-web.git

11
Containerfile Normal file
View File

@ -0,0 +1,11 @@
FROM docker.io/library/golang:bookworm
COPY ./ /wgmesh
RUN apt-get update && apt-get install -y \
wireguard \
wireguard-tools \
iproute2 \
iputils-ping \
tmux \
vim
WORKDIR /wgmesh
RUN go build -o /usr/local/bin ./...

1
Dockerfile Symbolic link
View File

@ -0,0 +1 @@
Containerfile

17
cmd/api/main.go Normal file
View File

@ -0,0 +1,17 @@
package main
import (
"log"
"github.com/tim-beatham/wgmesh/pkg/api"
)
func main() {
apiServer, err := api.NewSmegServer()
if err != nil {
log.Fatal(err.Error())
}
apiServer.Run(":40000")
}

View File

@ -171,6 +171,68 @@ func putDescription(client *ipcRpc.Client, description string) {
fmt.Println(reply) fmt.Println(reply)
} }
// putAlias: puts an alias for the node
func putAlias(client *ipcRpc.Client, alias string) {
var reply string
err := client.Call("IpcHandler.PutAlias", &alias, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func setService(client *ipcRpc.Client, service, value string) {
var reply string
serviceArgs := &ipc.PutServiceArgs{
Service: service,
Value: value,
}
err := client.Call("IpcHandler.PutService", serviceArgs, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func deleteService(client *ipcRpc.Client, service string) {
var reply string
err := client.Call("IpcHandler.PutService", &service, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func getNode(client *ipcRpc.Client, nodeId, meshId string) {
var reply string
args := &ipc.GetNodeArgs{
NodeId: nodeId,
MeshId: meshId,
}
err := client.Call("IpcHandler.GetNode", &args, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func main() { func main() {
parser := argparse.NewParser("wg-mesh", parser := argparse.NewParser("wg-mesh",
"wg-mesh Manipulate WireGuard meshes") "wg-mesh Manipulate WireGuard meshes")
@ -184,6 +246,10 @@ func main() {
leaveMeshCmd := parser.NewCommand("leave-mesh", "Leave a mesh network") leaveMeshCmd := parser.NewCommand("leave-mesh", "Leave a mesh network")
queryMeshCmd := parser.NewCommand("query-mesh", "Query a mesh network using JMESPath") queryMeshCmd := parser.NewCommand("query-mesh", "Query a mesh network using JMESPath")
putDescriptionCmd := parser.NewCommand("put-description", "Place a description for the node") putDescriptionCmd := parser.NewCommand("put-description", "Place a description for the node")
putAliasCmd := parser.NewCommand("put-alias", "Place an alias for the node")
setServiceCmd := parser.NewCommand("set-service", "Place a service into your advertisements")
deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements")
getNodeCmd := parser.NewCommand("get-node", "Get a specific node from the mesh")
var newMeshIfName *string = newMeshCmd.String("f", "ifname", &argparse.Options{Required: true}) var newMeshIfName *string = newMeshCmd.String("f", "ifname", &argparse.Options{Required: true})
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{Required: true}) var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
@ -195,8 +261,6 @@ func main() {
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{Required: true}) var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{}) var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{})
// var getMeshId *string = getMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true}) var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true})
var getGraphMeshId *string = getGraphCmd.String("m", "mesh", &argparse.Options{Required: true}) var getGraphMeshId *string = getGraphCmd.String("m", "mesh", &argparse.Options{Required: true})
@ -208,6 +272,16 @@ func main() {
var description *string = putDescriptionCmd.String("d", "description", &argparse.Options{Required: true}) var description *string = putDescriptionCmd.String("d", "description", &argparse.Options{Required: true})
var alias *string = putAliasCmd.String("a", "alias", &argparse.Options{Required: true})
var serviceKey *string = setServiceCmd.String("s", "service", &argparse.Options{Required: true})
var serviceValue *string = setServiceCmd.String("v", "value", &argparse.Options{Required: true})
var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{Required: true})
var getNodeNodeId *string = getNodeCmd.String("n", "nodeid", &argparse.Options{Required: true})
var getNodeMeshId *string = getNodeCmd.String("m", "meshid", &argparse.Options{Required: true})
err := parser.Parse(os.Args) err := parser.Parse(os.Args)
if err != nil { if err != nil {
@ -245,10 +319,6 @@ func main() {
})) }))
} }
// if getMeshCmd.Happened() {
// getMesh(client, *getMeshId)
// }
if getGraphCmd.Happened() { if getGraphCmd.Happened() {
getGraph(client, *getGraphMeshId) getGraph(client, *getGraphMeshId)
} }
@ -268,4 +338,20 @@ func main() {
if putDescriptionCmd.Happened() { if putDescriptionCmd.Happened() {
putDescription(client, *description) putDescription(client, *description)
} }
if putAliasCmd.Happened() {
putAlias(client, *alias)
}
if setServiceCmd.Happened() {
setService(client, *serviceKey, *serviceValue)
}
if deleteServiceCmd.Happened() {
deleteService(client, *deleteServiceKey)
}
if getNodeCmd.Happened() {
getNode(client, *getNodeNodeId, *getNodeMeshId)
}
} }

View File

@ -1,7 +1,8 @@
package main package main
import ( import (
"log" "net/http"
_ "net/http/pprof"
"os" "os"
"os/signal" "os/signal"
@ -35,6 +36,12 @@ func main() {
return return
} }
if conf.Profile {
go func() {
http.ListenAndServe("localhost:6060", nil)
}()
}
var robinRpc robin.WgRpc var robinRpc robin.WgRpc
var robinIpc robin.IpcHandler var robinIpc robin.IpcHandler
var syncProvider sync.SyncServiceImpl var syncProvider sync.SyncServiceImpl
@ -45,8 +52,8 @@ func main() {
SyncProvider: &syncProvider, SyncProvider: &syncProvider,
Client: client, Client: client,
} }
ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams)
ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams)
syncProvider.Server = ctrlServer syncProvider.Server = ctrlServer
syncRequester := sync.NewSyncRequester(ctrlServer) syncRequester := sync.NewSyncRequester(ctrlServer)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester) syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester)
@ -65,7 +72,7 @@ func main() {
return return
} }
log.Println("Running IPC Handler") logging.Log.WriteInfof("Running IPC Handler")
go ipc.RunIpcHandler(&robinIpc) go ipc.RunIpcHandler(&robinIpc)
go syncScheduler.Run() go syncScheduler.Run()

View File

@ -0,0 +1,95 @@
version: '3'
networks:
net-1:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.0.0/17
net-2:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.155.0/17
services:
wg-1:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-2:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-3:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-4:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
sysctls:
- net.ipv6.conf.all.forwarding=1
networks:
- net-1
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-5:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-6:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-7:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"

View File

@ -0,0 +1,14 @@
certificatePath: "/wgmesh/cert/cert.pem"
privateKeyPath: "/wgmesh/cert/priv.pem"
caCertificatePath: "/wgmesh/cert/cacert.pem"
skipCertVerification: true
timeout: 5
gRPCPort: "21906"
advertiseRoutes: true
clusterSize: 32
syncRate: 1
interClusterChance: 0.15
branchRate: 3
infectionCount: 3
keepAliveTime: 10
pruneTime: 20

View File

@ -0,0 +1,42 @@
version: '3'
networks:
net-1:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.0.0/17
services:
wg-1:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-2:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-3:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"

View File

@ -0,0 +1,14 @@
certificatePath: "/wgmesh/cert/cert.pem"
privateKeyPath: "/wgmesh/cert/priv.pem"
caCertificatePath: "/wgmesh/cert/cacert.pem"
skipCertVerification: true
timeout: 5
gRPCPort: "21906"
advertiseRoutes: true
clusterSize: 32
syncRate: 1
interClusterChance: 0.15
branchRate: 3
infectionCount: 3
keepAliveTime: 10
pruneTime: 20

24
go.mod
View File

@ -5,9 +5,12 @@ go 1.21.3
require ( require (
github.com/akamensky/argparse v1.4.0 github.com/akamensky/argparse v1.4.0
github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9 github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9
github.com/gin-gonic/gin v1.9.1
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/jmespath/go-jmespath v0.4.0 github.com/jmespath/go-jmespath v0.4.0
github.com/jsimonetti/rtnetlink v1.3.5
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
golang.org/x/sys v0.14.0
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
google.golang.org/grpc v1.58.1 google.golang.org/grpc v1.58.1
google.golang.org/protobuf v1.31.0 google.golang.org/protobuf v1.31.0
@ -15,18 +18,33 @@ require (
) )
require ( require (
github.com/bytedance/sonic v1.9.1 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/golang/protobuf v1.5.3 // indirect github.com/golang/protobuf v1.5.3 // indirect
github.com/google/go-cmp v0.5.9 // indirect github.com/google/go-cmp v0.5.9 // indirect
github.com/josharian/native v1.1.0 // indirect github.com/josharian/native v1.1.0 // indirect
github.com/jsimonetti/rtnetlink v1.3.5 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.5.0 // indirect github.com/mdlayher/socket v0.5.0 // indirect
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.13.0 // indirect golang.org/x/crypto v0.13.0 // indirect
golang.org/x/net v0.15.0 // indirect golang.org/x/net v0.15.0 // indirect
golang.org/x/sync v0.3.0 // indirect golang.org/x/sync v0.3.0 // indirect
golang.org/x/sys v0.12.0 // indirect
golang.org/x/text v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 // indirect golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect

205
pkg/api/apiserver.go Normal file
View File

@ -0,0 +1,205 @@
package api
import (
"fmt"
"net/http"
ipcRpc "net/rpc"
"github.com/gin-gonic/gin"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log"
)
const SockAddr = "/tmp/wgmesh_ipc.sock"
type ApiServer interface {
GetMeshes(c *gin.Context)
Run(addr string) error
}
type SmegServer struct {
router *gin.Engine
client *ipcRpc.Client
}
func meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNode {
if meshNode.Routes == nil {
meshNode.Routes = make([]string, 0)
}
return &SmegNode{
WgHost: meshNode.WgHost,
WgEndpoint: meshNode.WgEndpoint,
Endpoint: meshNode.HostEndpoint,
Timestamp: int(meshNode.Timestamp),
Description: meshNode.Description,
Routes: meshNode.Routes,
PublicKey: meshNode.PublicKey,
Alias: meshNode.Alias,
Services: meshNode.Services,
}
}
func meshToAPIMesh(meshId string, nodes []ctrlserver.MeshNode) SmegMesh {
var smegMesh SmegMesh
smegMesh.MeshId = meshId
smegMesh.Nodes = make(map[string]SmegNode)
for _, node := range nodes {
smegMesh.Nodes[node.WgHost] = *meshNodeToAPIMeshNode(node)
}
return smegMesh
}
// CreateMesh: creates a mesh network
func (s *SmegServer) CreateMesh(c *gin.Context) {
var createMesh CreateMeshRequest
if err := c.ShouldBindJSON(&createMesh); err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
ipcRequest := ipc.NewMeshArgs{
IfName: createMesh.IfName,
WgPort: createMesh.WgPort,
}
var reply string
err := s.client.Call("IpcHandler.CreateMesh", &ipcRequest, &reply)
if err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
c.JSON(http.StatusOK, &gin.H{
"meshid": reply,
})
}
// JoinMesh: joins a mesh network
func (s *SmegServer) JoinMesh(c *gin.Context) {
var joinMesh JoinMeshRequest
if err := c.ShouldBindJSON(&joinMesh); err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
ipcRequest := ipc.JoinMeshArgs{
MeshId: joinMesh.MeshId,
IpAdress: joinMesh.Bootstrap,
IfName: joinMesh.IfName,
Port: joinMesh.WgPort,
}
var reply string
err := s.client.Call("IpcHandler.JoinMesh", &ipcRequest, &reply)
if err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
c.JSON(http.StatusOK, &gin.H{
"status": "success",
})
}
// GetMesh: given a meshId returns the corresponding mesh
// network.
func (s *SmegServer) GetMesh(c *gin.Context) {
meshidParam := c.Param("meshid")
var meshid string = meshidParam
getMeshReply := new(ipc.GetMeshReply)
err := s.client.Call("IpcHandler.GetMesh", &meshid, &getMeshReply)
if err != nil {
c.JSON(http.StatusNotFound,
&gin.H{
"error": fmt.Sprintf("could not find mesh %s", meshidParam),
})
return
}
mesh := meshToAPIMesh(meshidParam, getMeshReply.Nodes)
c.JSON(http.StatusOK, mesh)
}
func (s *SmegServer) GetMeshes(c *gin.Context) {
listMeshesReply := new(ipc.ListMeshReply)
err := s.client.Call("IpcHandler.ListMeshes", "", &listMeshesReply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
c.JSON(http.StatusBadRequest, nil)
return
}
meshes := make([]SmegMesh, 0)
for _, mesh := range listMeshesReply.Meshes {
getMeshReply := new(ipc.GetMeshReply)
err := s.client.Call("IpcHandler.GetMesh", &mesh, &getMeshReply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
c.JSON(http.StatusBadRequest, nil)
return
}
meshes = append(meshes, meshToAPIMesh(mesh, getMeshReply.Nodes))
}
c.JSON(http.StatusOK, meshes)
}
func (s *SmegServer) Run(addr string) error {
logging.Log.WriteInfof("Running API server")
return s.router.Run(addr)
}
func NewSmegServer() (ApiServer, error) {
client, err := ipcRpc.DialHTTP("unix", SockAddr)
if err != nil {
return nil, err
}
router := gin.Default()
router.Use(gin.LoggerWithConfig(gin.LoggerConfig{
Output: logging.Log.Writer(),
}))
smegServer := &SmegServer{
router: router,
client: client,
}
router.GET("/meshes", smegServer.GetMeshes)
router.GET("/mesh/:meshid", smegServer.GetMesh)
router.POST("/mesh/create", smegServer.CreateMesh)
router.POST("/mesh/join", smegServer.JoinMesh)
return smegServer, nil
}

30
pkg/api/types.go Normal file
View File

@ -0,0 +1,30 @@
package api
type SmegNode struct {
Alias string `json:"alias"`
WgHost string `json:"wgHost"`
WgEndpoint string `json:"wgEndpoint"`
Endpoint string `json:"endpoint"`
Timestamp int `json:"timestamp"`
Description string `json:"description"`
PublicKey string `json:"publicKey"`
Routes []string `json:"routes"`
Services map[string]string `json:"services"`
}
type SmegMesh struct {
MeshId string `json:"meshid"`
Nodes map[string]SmegNode `json:"nodes"`
}
type CreateMeshRequest struct {
IfName string `json:"ifName" binding:"required"`
WgPort int `json:"port" binding:"required,gte=1024,lt=65535"`
}
type JoinMeshRequest struct {
IfName string `json:"ifName" binding:"required"`
WgPort int `json:"port" binding:"required,gte=1024,lt=65535"`
Bootstrap string `json:"bootstrap" binding:"required"`
MeshId string `json:"meshid" binding:"required"`
}

View File

@ -20,11 +20,12 @@ import (
type CrdtMeshManager struct { type CrdtMeshManager struct {
MeshId string MeshId string
IfName string IfName string
NodeId string
Client *wgctrl.Client Client *wgctrl.Client
doc *automerge.Doc doc *automerge.Doc
LastHash automerge.ChangeHash LastHash automerge.ChangeHash
conf *conf.WgMeshConfiguration conf *conf.WgMeshConfiguration
cache *MeshCrdt
lastCacheHash automerge.ChangeHash
} }
func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) { func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
@ -35,14 +36,37 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
} }
crdt.Routes = make(map[string]interface{}) crdt.Routes = make(map[string]interface{})
crdt.Services = make(map[string]string)
crdt.Timestamp = time.Now().Unix() crdt.Timestamp = time.Now().Unix()
c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt) c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt)
} }
func (c *CrdtMeshManager) GetNodeIds() []string {
keys, _ := c.doc.Path("nodes").Map().Keys()
return keys
}
// GetMesh(): Converts the document into a struct // GetMesh(): Converts the document into a struct
func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) { func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
return automerge.As[*MeshCrdt](c.doc.Root()) changes, err := c.doc.Changes(c.lastCacheHash)
if err != nil {
return nil, err
}
if c.cache == nil || len(changes) > 3 {
c.lastCacheHash = c.LastHash
cache, err := automerge.As[*MeshCrdt](c.doc.Root())
if err != nil {
return nil, err
}
c.cache = cache
}
return c.cache, nil
} }
// GetMeshId returns the meshid of the mesh // GetMeshId returns the meshid of the mesh
@ -82,13 +106,23 @@ func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, erro
manager.IfName = params.DevName manager.IfName = params.DevName
manager.Client = params.Client manager.Client = params.Client
manager.conf = &params.Conf manager.conf = &params.Conf
manager.cache = nil
return &manager, nil return &manager, nil
} }
// GetNode: returns a mesh node crdt.Close releases resources used by a Client. // NodeExists: returns true if the node exists. Returns false
func (m *CrdtMeshManager) GetNode(endpoint string) (*MeshNodeCrdt, error) { func (m *CrdtMeshManager) NodeExists(key string) bool {
node, err := m.doc.Path("nodes").Map().Get(key)
return node.Kind() == automerge.KindMap && err != nil
}
func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
node, err := m.doc.Path("nodes").Map().Get(endpoint) node, err := m.doc.Path("nodes").Map().Get(endpoint)
if node.Kind() != automerge.KindMap {
return nil, fmt.Errorf("GetNode: something went wrong %s is not a map type")
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -178,6 +212,72 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro
return err return err
} }
func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
return err
}
if node.Kind() != automerge.KindMap {
return fmt.Errorf("%s does not exist", nodeId)
}
err = node.Map().Set("alias", alias)
if err == nil {
logging.Log.WriteInfof("Updated Alias for %s to %s", nodeId, alias)
}
return err
}
func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return fmt.Errorf("AddService: node %s does not exist", nodeId)
}
service, err := node.Map().Get("services")
if err != nil {
return err
}
if service.Kind() != automerge.KindMap {
return fmt.Errorf("AddService: services property does not exist in node")
}
return service.Map().Set(key, value)
}
func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return fmt.Errorf("RemoveService: node %s does not exist", nodeId)
}
service, err := node.Map().Get("services")
if err != nil {
return err
}
if service.Kind() != automerge.KindMap {
return fmt.Errorf("services property does not exist")
}
err = service.Map().Delete(key)
if err != nil {
return fmt.Errorf("service %s does not exist", key)
}
return nil
}
// AddRoutes: adds routes to the specific nodeId // AddRoutes: adds routes to the specific nodeId
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error { func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -336,6 +436,20 @@ func (m *MeshNodeCrdt) GetIdentifier() string {
return strings.Join(constituents, ":") return strings.Join(constituents, ":")
} }
func (m *MeshNodeCrdt) GetAlias() string {
return m.Alias
}
func (m *MeshNodeCrdt) GetServices() map[string]string {
services := make(map[string]string)
for key, service := range m.Services {
services[key] = service
}
return services
}
func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode { func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
nodes := make(map[string]mesh.MeshNode) nodes := make(map[string]mesh.MeshNode)
@ -348,6 +462,8 @@ func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
Timestamp: node.Timestamp, Timestamp: node.Timestamp,
Routes: node.Routes, Routes: node.Routes,
Description: node.Description, Description: node.Description,
Alias: node.Alias,
Services: node.GetServices(),
} }
} }

View File

@ -36,6 +36,8 @@ func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNod
// Always set the routes as empty. // Always set the routes as empty.
// Routes handled by external component // Routes handled by external component
Routes: map[string]interface{}{}, Routes: map[string]interface{}{},
Description: "",
Alias: "",
} }
} }

View File

@ -8,7 +8,9 @@ type MeshNodeCrdt struct {
WgHost string `automerge:"wgHost"` WgHost string `automerge:"wgHost"`
Timestamp int64 `automerge:"timestamp"` Timestamp int64 `automerge:"timestamp"`
Routes map[string]interface{} `automerge:"routes"` Routes map[string]interface{} `automerge:"routes"`
Alias string `automerge:"alias"`
Description string `automerge:"description"` Description string `automerge:"description"`
Services map[string]string `automerge:"services"`
} }
// MeshCrdt: Represents the mesh network as a whole // MeshCrdt: Represents the mesh network as a whole

View File

@ -49,6 +49,10 @@ type WgMeshConfiguration struct {
Timeout int `yaml:"timeout"` Timeout int `yaml:"timeout"`
// PruneTime number of seconds before we consider the 'node' as dead // PruneTime number of seconds before we consider the 'node' as dead
PruneTime int `yaml:"pruneTime"` PruneTime int `yaml:"pruneTime"`
// Profile whether or not to include a http server that profiles the code
Profile bool `yaml:"profile"`
// StubWg whether or not to stub the WireGuard types
StubWg bool `yaml:"stubWg"`
} }
func ValidateConfiguration(c *WgMeshConfiguration) error { func ValidateConfiguration(c *WgMeshConfiguration) error {

View File

@ -17,6 +17,9 @@ type MeshNode struct {
WgHost string WgHost string
Timestamp int64 Timestamp int64
Routes []string Routes []string
Description string
Alias string
Services map[string]string
} }
// Represents a WireGuard Mesh // Represents a WireGuard Mesh

132
pkg/hosts/hosts.go Normal file
View File

@ -0,0 +1,132 @@
// hosts: utility for modifying the /etc/hosts file
package hosts
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"os"
"strings"
)
// HOSTS_FILE is the hosts file location
const HOSTS_FILE = "/etc/hosts"
const DOMAIN_HEADER = "#WG AUTO GENERATED HOSTS"
const DOMAIN_TRAILER = "#WG AUTO GENERATED HOSTS END"
type HostsEntry struct {
Alias string
Ip net.IP
}
// Generic interface to manipulate /etc/hosts file
type HostsManipulator interface {
// AddrAddr associates an aliasd with a given IP address
AddAddr(hosts ...HostsEntry)
// Remove deletes the entry from /etc/hosts
Remove(hosts ...HostsEntry)
// Writes the changes to /etc/hosts file
Write() error
}
type HostsManipulatorImpl struct {
hosts map[string]HostsEntry
}
// AddAddr implements HostsManipulator.
func (m *HostsManipulatorImpl) AddAddr(hosts ...HostsEntry) {
changed := false
for _, host := range hosts {
prev, ok := m.hosts[host.Ip.String()]
if !ok || prev.Alias != host.Alias {
changed = true
}
m.hosts[host.Ip.String()] = host
}
if changed {
m.Write()
}
}
// Remove implements HostsManipulator.
func (m *HostsManipulatorImpl) Remove(hosts ...HostsEntry) {
lenBefore := len(m.hosts)
for _, host := range hosts {
delete(m.hosts, host.Alias)
}
if lenBefore != len(m.hosts) {
m.Write()
}
}
func (m *HostsManipulatorImpl) removeHosts() string {
hostsFile, err := os.ReadFile(HOSTS_FILE)
if err != nil {
return ""
}
var contents strings.Builder
scanner := bufio.NewScanner(bytes.NewReader(hostsFile))
hostsSection := false
for scanner.Scan() {
line := scanner.Text()
if err == io.EOF {
break
} else if err != nil {
return ""
}
if !hostsSection && strings.Contains(line, DOMAIN_HEADER) {
hostsSection = true
}
if !hostsSection {
contents.WriteString(line + "\n")
}
if hostsSection && strings.Contains(line, DOMAIN_TRAILER) {
hostsSection = false
}
}
if scanner.Err() != nil && scanner.Err() != io.EOF {
return ""
}
return contents.String()
}
// Write implements HostsManipulator
func (m *HostsManipulatorImpl) Write() error {
contents := m.removeHosts()
var nextHosts strings.Builder
nextHosts.WriteString(contents)
nextHosts.WriteString(DOMAIN_HEADER + "\n")
for _, host := range m.hosts {
nextHosts.WriteString(fmt.Sprintf("%s\t%s\n", host.Ip.String(), host.Alias))
}
nextHosts.WriteString(DOMAIN_TRAILER + "\n")
return os.WriteFile(HOSTS_FILE, []byte(nextHosts.String()), 0644)
}
func NewHostsManipulator() HostsManipulator {
return &HostsManipulatorImpl{hosts: make(map[string]HostsEntry)}
}

View File

@ -34,6 +34,11 @@ type JoinMeshArgs struct {
Endpoint string Endpoint string
} }
type PutServiceArgs struct {
Service string
Value string
}
type GetMeshReply struct { type GetMeshReply struct {
Nodes []ctrlserver.MeshNode Nodes []ctrlserver.MeshNode
} }
@ -47,6 +52,11 @@ type QueryMesh struct {
Query string Query string
} }
type GetNodeArgs struct {
NodeId string
MeshId string
}
type MeshIpc interface { type MeshIpc interface {
CreateMesh(args *NewMeshArgs, reply *string) error CreateMesh(args *NewMeshArgs, reply *string) error
ListMeshes(name string, reply *ListMeshReply) error ListMeshes(name string, reply *ListMeshReply) error
@ -57,13 +67,17 @@ type MeshIpc interface {
GetDOT(meshId string, reply *string) error GetDOT(meshId string, reply *string) error
Query(query QueryMesh, reply *string) error Query(query QueryMesh, reply *string) error
PutDescription(description string, reply *string) error PutDescription(description string, reply *string) error
PutAlias(alias string, reply *string) error
PutService(args PutServiceArgs, reply *string) error
GetNode(args GetNodeArgs, reply *string) error
DeleteService(service string, reply *string) error
} }
const SockAddr = "/tmp/wgmesh_ipc.sock" const SockAddr = "/tmp/wgmesh_ipc.sock"
func RunIpcHandler(server MeshIpc) error { func RunIpcHandler(server MeshIpc) error {
if err := os.RemoveAll(SockAddr); err != nil { if err := os.RemoveAll(SockAddr); err != nil {
return errors.New("Could not find to address") return errors.New("could not find to address")
} }
rpc.Register(server) rpc.Register(server)

View File

@ -2,6 +2,7 @@
package logging package logging
import ( import (
"io"
"os" "os"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -15,6 +16,7 @@ type Logger interface {
WriteInfof(msg string, args ...interface{}) WriteInfof(msg string, args ...interface{})
WriteErrorf(msg string, args ...interface{}) WriteErrorf(msg string, args ...interface{})
WriteWarnf(msg string, args ...interface{}) WriteWarnf(msg string, args ...interface{})
Writer() io.Writer
} }
type LogrusLogger struct { type LogrusLogger struct {
@ -33,6 +35,10 @@ func (l *LogrusLogger) WriteWarnf(msg string, args ...interface{}) {
l.logger.Warnf(msg, args...) l.logger.Warnf(msg, args...)
} }
func (l *LogrusLogger) Writer() io.Writer {
return l.logger.Writer()
}
func NewLogrusLogger() *LogrusLogger { func NewLogrusLogger() *LogrusLogger {
logger := logrus.New() logger := logrus.New()
logger.SetFormatter(&logrus.TextFormatter{FullTimestamp: true}) logger.SetFormatter(&logrus.TextFormatter{FullTimestamp: true})

46
pkg/mesh/alias.go Normal file
View File

@ -0,0 +1,46 @@
package mesh
import (
"fmt"
"github.com/tim-beatham/wgmesh/pkg/hosts"
)
type MeshAliasManager interface {
AddAliases(nodes []MeshNode)
RemoveAliases(node []MeshNode)
}
type AliasManager struct {
hosts hosts.HostsManipulator
}
// AddAliases: on node update or change add aliases to the hosts file
func (a *AliasManager) AddAliases(nodes []MeshNode) {
for _, node := range nodes {
if node.GetAlias() != "" {
a.hosts.AddAddr(hosts.HostsEntry{
Alias: fmt.Sprintf("%s.smeg", node.GetAlias()),
Ip: node.GetWgHost().IP,
})
}
}
}
// RemoveAliases: on node remove remove aliases from the hosts file
func (a *AliasManager) RemoveAliases(nodes []MeshNode) {
for _, node := range nodes {
if node.GetAlias() != "" {
a.hosts.Remove(hosts.HostsEntry{
Alias: fmt.Sprintf("%s.smeg", node.GetAlias()),
Ip: node.GetWgHost().IP,
})
}
}
}
func NewAliasManager() MeshAliasManager {
return &AliasManager{
hosts: hosts.NewHostsManipulator(),
}
}

View File

@ -25,11 +25,16 @@ type MeshManager interface {
GetSelf(meshId string) (MeshNode, error) GetSelf(meshId string) (MeshNode, error)
ApplyConfig() error ApplyConfig() error
SetDescription(description string) error SetDescription(description string) error
SetAlias(alias string) error
SetService(service string, value string) error
RemoveService(service string) error
UpdateTimeStamp() error UpdateTimeStamp() error
GetClient() *wgctrl.Client GetClient() *wgctrl.Client
GetMeshes() map[string]MeshProvider GetMeshes() map[string]MeshProvider
Prune() error Prune() error
Close() error Close() error
GetMonitor() MeshMonitor
GetNode(string, string) MeshNode
} }
type MeshManagerImpl struct { type MeshManagerImpl struct {
@ -46,6 +51,54 @@ type MeshManagerImpl struct {
idGenerator lib.IdGenerator idGenerator lib.IdGenerator
ipAllocator ip.IPAllocator ipAllocator ip.IPAllocator
interfaceManipulator wg.WgInterfaceManipulator interfaceManipulator wg.WgInterfaceManipulator
Monitor MeshMonitor
}
// RemoveService implements MeshManager.
func (m *MeshManagerImpl) RemoveService(service string) error {
for _, mesh := range m.Meshes {
err := mesh.RemoveService(m.HostParameters.HostEndpoint, service)
if err != nil {
return err
}
}
return nil
}
// SetService implements MeshManager.
func (m *MeshManagerImpl) SetService(service string, value string) error {
for _, mesh := range m.Meshes {
err := mesh.AddService(m.HostParameters.HostEndpoint, service, value)
if err != nil {
return err
}
}
return nil
}
func (m *MeshManagerImpl) GetNode(meshid, nodeId string) MeshNode {
mesh, ok := m.Meshes[meshid]
if !ok {
return nil
}
node, err := mesh.GetNode(nodeId)
if err != nil {
return nil
}
return node
}
// GetMonitor implements MeshManager.
func (m *MeshManagerImpl) GetMonitor() MeshMonitor {
return m.Monitor
} }
// Prune implements MeshManager. // Prune implements MeshManager.
@ -81,6 +134,7 @@ func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) {
return "", fmt.Errorf("error creating mesh: %w", err) return "", fmt.Errorf("error creating mesh: %w", err)
} }
if !m.conf.StubWg {
err = m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{ err = m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: devName, IfName: devName,
Port: port, Port: port,
@ -89,6 +143,7 @@ func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) {
if err != nil { if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err) return "", fmt.Errorf("error creating mesh: %w", err)
} }
}
m.Meshes[meshId] = nodeManager m.Meshes[meshId] = nodeManager
return meshId, nil return meshId, nil
@ -123,12 +178,16 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
m.Meshes[params.MeshId] = meshProvider m.Meshes[params.MeshId] = meshProvider
if !m.conf.StubWg {
return m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{ return m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: params.DevName, IfName: params.DevName,
Port: params.WgPort, Port: params.WgPort,
}) })
} }
return nil
}
// HasChanges returns true if the mesh has changes // HasChanges returns true if the mesh has changes
func (m *MeshManagerImpl) HasChanges(meshId string) bool { func (m *MeshManagerImpl) HasChanges(meshId string) bool {
return m.Meshes[meshId].HasChanges() return m.Meshes[meshId].HasChanges()
@ -159,6 +218,11 @@ func (s *MeshManagerImpl) EnableInterface(meshId string) error {
// GetPublicKey: Gets the public key of the WireGuard mesh // GetPublicKey: Gets the public key of the WireGuard mesh
func (s *MeshManagerImpl) GetPublicKey(meshId string) (*wgtypes.Key, error) { func (s *MeshManagerImpl) GetPublicKey(meshId string) (*wgtypes.Key, error) {
if s.conf.StubWg {
zeroedKey := make([]byte, wgtypes.KeyLen)
return (*wgtypes.Key)(zeroedKey), nil
}
mesh, ok := s.Meshes[meshId] mesh, ok := s.Meshes[meshId]
if !ok { if !ok {
@ -180,7 +244,6 @@ type AddSelfParams struct {
// WgPort is the WireGuard port to advertise // WgPort is the WireGuard port to advertise
WgPort int WgPort int
// Endpoint is the alias of the machine to send routable packets // Endpoint is the alias of the machine to send routable packets
// to
Endpoint string Endpoint string
} }
@ -211,6 +274,7 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
Endpoint: params.Endpoint, Endpoint: params.Endpoint,
}) })
if !s.conf.StubWg {
device, err := mesh.GetDevice() device, err := mesh.GetDevice()
if err != nil { if err != nil {
@ -222,6 +286,7 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
if err != nil { if err != nil {
return fmt.Errorf("addSelf: failed to add address to dev %w", err) return fmt.Errorf("addSelf: failed to add address to dev %w", err)
} }
}
s.Meshes[params.MeshId].AddNode(node) s.Meshes[params.MeshId].AddNode(node)
return s.RouteManager.UpdateRoutes() return s.RouteManager.UpdateRoutes()
@ -241,13 +306,16 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
return err return err
} }
device, err := mesh.GetDevice() if !s.conf.StubWg {
device, e := mesh.GetDevice()
if err != nil { if e != nil {
return err return err
} }
err = s.interfaceManipulator.RemoveInterface(device.Name) err = s.interfaceManipulator.RemoveInterface(device.Name)
}
delete(s.Meshes, meshId) delete(s.Meshes, meshId)
return err return err
} }
@ -259,15 +327,9 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
return nil, fmt.Errorf("mesh %s does not exist", meshId) return nil, fmt.Errorf("mesh %s does not exist", meshId)
} }
snapshot, err := meshInstance.GetMesh() node, err := meshInstance.GetNode(s.HostParameters.HostEndpoint)
if err != nil { if err != nil {
return nil, err
}
node, ok := snapshot.GetNodes()[s.HostParameters.HostEndpoint]
if !ok {
return nil, errors.New("the node doesn't exist in the mesh") return nil, errors.New("the node doesn't exist in the mesh")
} }
@ -281,34 +343,42 @@ func (s *MeshManagerImpl) ApplyConfig() error {
return err return err
} }
return s.RouteManager.InstallRoutes() return nil
} }
func (s *MeshManagerImpl) SetDescription(description string) error { func (s *MeshManagerImpl) SetDescription(description string) error {
for _, mesh := range s.Meshes { for _, mesh := range s.Meshes {
if mesh.NodeExists(s.HostParameters.HostEndpoint) {
err := mesh.SetDescription(s.HostParameters.HostEndpoint, description) err := mesh.SetDescription(s.HostParameters.HostEndpoint, description)
if err != nil { if err != nil {
return err return err
} }
} }
}
return nil return nil
} }
// SetAlias implements MeshManager.
func (s *MeshManagerImpl) SetAlias(alias string) error {
for _, mesh := range s.Meshes {
if mesh.NodeExists(s.HostParameters.HostEndpoint) {
err := mesh.SetAlias(s.HostParameters.HostEndpoint, alias)
if err != nil {
return err
}
}
}
return nil
}
// UpdateTimeStamp updates the timestamp of this node in all meshes // UpdateTimeStamp updates the timestamp of this node in all meshes
func (s *MeshManagerImpl) UpdateTimeStamp() error { func (s *MeshManagerImpl) UpdateTimeStamp() error {
for _, mesh := range s.Meshes { for _, mesh := range s.Meshes {
snapshot, err := mesh.GetMesh() if mesh.NodeExists(s.HostParameters.HostEndpoint) {
err := mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint)
if err != nil {
return err
}
_, exists := snapshot.GetNodes()[s.HostParameters.HostEndpoint]
if exists {
err = mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint)
if err != nil { if err != nil {
return err return err
@ -327,7 +397,12 @@ func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider {
return s.Meshes return s.Meshes
} }
// Close the mesh manager
func (s *MeshManagerImpl) Close() error { func (s *MeshManagerImpl) Close() error {
if s.conf.StubWg {
return nil
}
for _, mesh := range s.Meshes { for _, mesh := range s.Meshes {
dev, err := mesh.GetDevice() dev, err := mesh.GetDevice()
@ -359,7 +434,7 @@ type NewMeshManagerParams struct {
} }
// Creates a new instance of a mesh manager with the given parameters // Creates a new instance of a mesh manager with the given parameters
func NewMeshManager(params *NewMeshManagerParams) *MeshManagerImpl { func NewMeshManager(params *NewMeshManagerParams) MeshManager {
hostParams := HostParameters{} hostParams := HostParameters{}
switch params.Conf.Endpoint { switch params.Conf.Endpoint {
@ -390,5 +465,11 @@ func NewMeshManager(params *NewMeshManagerParams) *MeshManagerImpl {
m.idGenerator = params.IdGenerator m.idGenerator = params.IdGenerator
m.ipAllocator = params.IPAllocator m.ipAllocator = params.IPAllocator
m.interfaceManipulator = params.InterfaceManipulator m.interfaceManipulator = params.InterfaceManipulator
m.Monitor = NewMeshMonitor(m)
aliasManager := NewAliasManager()
m.Monitor.AddUpdateCallback(aliasManager.AddAliases)
m.Monitor.AddRemoveCallback(aliasManager.RemoveAliases)
return m return m
} }

View File

@ -22,7 +22,7 @@ func getMeshConfiguration() *conf.WgMeshConfiguration {
} }
} }
func getMeshManager() *MeshManagerImpl { func getMeshManager() MeshManager {
manager := NewMeshManager(&NewMeshManagerParams{ manager := NewMeshManager(&NewMeshManagerParams{
Conf: *getMeshConfiguration(), Conf: *getMeshConfiguration(),
Client: nil, Client: nil,
@ -51,7 +51,7 @@ func TestCreateMeshCreatesANewMeshProvider(t *testing.T) {
t.Fatal(`meshId should not be empty`) t.Fatal(`meshId should not be empty`)
} }
_, exists := manager.Meshes[meshId] _, exists := manager.GetMeshes()[meshId]
if !exists { if !exists {
t.Fatal(`mesh was not created when it should be`) t.Fatal(`mesh was not created when it should be`)
@ -190,7 +190,7 @@ func TestLeaveMeshDeletesMesh(t *testing.T) {
t.Fatalf("%s", err.Error()) t.Fatalf("%s", err.Error())
} }
_, exists := manager.Meshes[meshId] _, exists := manager.GetMeshes()[meshId]
if exists { if exists {
t.Fatalf(`expected mesh to have been deleted`) t.Fatalf(`expected mesh to have been deleted`)

81
pkg/mesh/monitor.go Normal file
View File

@ -0,0 +1,81 @@
package mesh
type OnChange = func([]MeshNode)
type MeshMonitor interface {
AddUpdateCallback(cb OnChange)
AddRemoveCallback(cb OnChange)
Trigger() error
}
type MeshMonitorImpl struct {
updateCbs []OnChange
removeCbs []OnChange
nodes map[string]MeshNode
manager MeshManager
}
// Trigger causes the mesh monitor to trigger all of
// the callbacks.
func (m *MeshMonitorImpl) Trigger() error {
changedNodes := make([]MeshNode, 0)
removedNodes := make([]MeshNode, 0)
nodes := make(map[string]MeshNode)
for _, mesh := range m.manager.GetMeshes() {
snapshot, err := mesh.GetMesh()
if err != nil {
return err
}
for _, node := range snapshot.GetNodes() {
previous, exists := m.nodes[node.GetWgHost().String()]
if !exists || !NodeEquals(previous, node) {
changedNodes = append(changedNodes, node)
}
nodes[node.GetWgHost().String()] = node
}
}
for _, previous := range m.nodes {
_, ok := nodes[previous.GetWgHost().String()]
if !ok {
removedNodes = append(removedNodes, previous)
}
}
if len(removedNodes) > 0 {
for _, cb := range m.removeCbs {
cb(removedNodes)
}
}
if len(changedNodes) > 0 {
for _, cb := range m.updateCbs {
cb(changedNodes)
}
}
return nil
}
func (m *MeshMonitorImpl) AddUpdateCallback(cb OnChange) {
m.updateCbs = append(m.updateCbs, cb)
}
func (m *MeshMonitorImpl) AddRemoveCallback(cb OnChange) {
m.removeCbs = append(m.removeCbs, cb)
}
func NewMeshMonitor(manager MeshManager) MeshMonitor {
return &MeshMonitorImpl{
updateCbs: make([]OnChange, 0),
nodes: make(map[string]MeshNode),
manager: manager,
}
}

View File

@ -21,6 +21,16 @@ type MeshNodeStub struct {
description string description string
} }
// GetServices implements MeshNode.
func (*MeshNodeStub) GetServices() map[string]string {
return make(map[string]string)
}
// GetAlias implements MeshNode.
func (*MeshNodeStub) GetAlias() string {
return ""
}
func (m *MeshNodeStub) GetHostEndpoint() string { func (m *MeshNodeStub) GetHostEndpoint() string {
return m.hostEndpoint return m.hostEndpoint
} }
@ -66,6 +76,36 @@ type MeshProviderStub struct {
snapshot *MeshSnapshotStub snapshot *MeshSnapshotStub
} }
// GetNodeIds implements MeshProvider.
func (*MeshProviderStub) GetNodeIds() []string {
panic("unimplemented")
}
// GetNode implements MeshProvider.
func (*MeshProviderStub) GetNode(string) (MeshNode, error) {
panic("unimplemented")
}
// NodeExists implements MeshProvider.
func (*MeshProviderStub) NodeExists(string) bool {
panic("unimplemented")
}
// AddService implements MeshProvider.
func (*MeshProviderStub) AddService(nodeId string, key string, value string) error {
panic("unimplemented")
}
// RemoveService implements MeshProvider.
func (*MeshProviderStub) RemoveService(nodeId string, key string) error {
panic("unimplemented")
}
// SetAlias implements MeshProvider.
func (*MeshProviderStub) SetAlias(nodeId string, alias string) error {
panic("unimplemented")
}
// RemoveRoutes implements MeshProvider. // RemoveRoutes implements MeshProvider.
func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error { func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error {
panic("unimplemented") panic("unimplemented")
@ -171,6 +211,31 @@ type MeshManagerStub struct {
meshes map[string]MeshProvider meshes map[string]MeshProvider
} }
// GetNode implements MeshManager.
func (*MeshManagerStub) GetNode(string, string) MeshNode {
panic("unimplemented")
}
// RemoveService implements MeshManager.
func (*MeshManagerStub) RemoveService(service string) error {
panic("unimplemented")
}
// SetService implements MeshManager.
func (*MeshManagerStub) SetService(service string, value string) error {
panic("unimplemented")
}
// GetMonitor implements MeshManager.
func (*MeshManagerStub) GetMonitor() MeshMonitor {
panic("unimplemented")
}
// SetAlias implements MeshManager.
func (*MeshManagerStub) SetAlias(alias string) error {
panic("unimplemented")
}
// Close implements MeshManager. // Close implements MeshManager.
func (*MeshManagerStub) Close() error { func (*MeshManagerStub) Close() error {
panic("unimplemented") panic("unimplemented")

View File

@ -4,6 +4,7 @@ package mesh
import ( import (
"net" "net"
"slices"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/conf"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
@ -28,6 +29,51 @@ type MeshNode interface {
GetIdentifier() string GetIdentifier() string
// GetDescription: returns the description for this node // GetDescription: returns the description for this node
GetDescription() string GetDescription() string
// GetAlias: associates the node with an alias. Potentially used
// for DNS and so forth.
GetAlias() string
// GetServices: returns a list of services offered by the node
GetServices() map[string]string
}
// NodeEquals: determines if two mesh nodes are equivalent to one another
func NodeEquals(node1, node2 MeshNode) bool {
if node1.GetHostEndpoint() != node2.GetHostEndpoint() {
return false
}
node1Pub, _ := node1.GetPublicKey()
node2Pub, _ := node2.GetPublicKey()
if node1Pub != node2Pub {
return false
}
if node1.GetWgEndpoint() != node2.GetWgEndpoint() {
return false
}
if node1.GetWgHost() != node2.GetWgHost() {
return false
}
if !slices.Equal(node1.GetRoutes(), node2.GetRoutes()) {
return false
}
if node1.GetIdentifier() != node2.GetIdentifier() {
return false
}
if node1.GetDescription() != node2.GetDescription() {
return false
}
if node1.GetAlias() != node2.GetAlias() {
return false
}
return true
} }
type MeshSnapshot interface { type MeshSnapshot interface {
@ -46,7 +92,7 @@ type MeshSyncer interface {
type MeshProvider interface { type MeshProvider interface {
// AddNode() adds a node to the mesh // AddNode() adds a node to the mesh
AddNode(node MeshNode) AddNode(node MeshNode)
// GetMesh() returns a snapshot of the mesh provided by the mesh provider // GetMesh() returns a snapshot of the mesh provided by the mesh provider.
GetMesh() (MeshSnapshot, error) GetMesh() (MeshSnapshot, error)
// GetMeshId() returns the ID of the mesh network // GetMeshId() returns the ID of the mesh network
GetMeshId() string GetMeshId() string
@ -68,11 +114,22 @@ type MeshProvider interface {
RemoveRoutes(nodeId string, route ...string) error RemoveRoutes(nodeId string, route ...string) error
// GetSyncer: returns the automerge syncer for sync // GetSyncer: returns the automerge syncer for sync
GetSyncer() MeshSyncer GetSyncer() MeshSyncer
// GetNode get a particular not within the mesh
GetNode(string) (MeshNode, error)
// NodeExists: returns true if a particular node exists false otherwise
NodeExists(string) bool
// SetDescription: sets the description of this automerge data type // SetDescription: sets the description of this automerge data type
SetDescription(nodeId string, description string) error SetDescription(nodeId string, description string) error
// SetAlias: set the alias of the nodeId
SetAlias(nodeId string, alias string) error
// AddService: adds the service to the given node
AddService(nodeId, key, value string) error
// RemoveService: removes the service form the node. throws an error if the service does not exist
RemoveService(nodeId, key string) error
// Prune: prunes all nodes that have not updated their timestamp in // Prune: prunes all nodes that have not updated their timestamp in
// pruneAmount seconds // pruneAmount seconds
Prune(pruneAmount int) error Prune(pruneAmount int) error
GetNodeIds() []string
} }
// HostParameters contains the IDs of a node // HostParameters contains the IDs of a node

View File

@ -31,6 +31,8 @@ type QueryNode struct {
Timestamp int64 `json:"timestmap"` Timestamp int64 `json:"timestmap"`
Description string `json:"description"` Description string `json:"description"`
Routes []string `json:"routes"` Routes []string `json:"routes"`
Alias string `json:"alias"`
Services map[string]string `json:"services"`
} }
func (m *QueryError) Error() string { func (m *QueryError) Error() string {
@ -51,7 +53,7 @@ func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
return nil, err return nil, err
} }
nodes := lib.Map(lib.MapValues(snapshot.GetNodes()), meshNodeToQueryNode) nodes := lib.Map(lib.MapValues(snapshot.GetNodes()), MeshNodeToQueryNode)
result, err := jmespath.Search(queryParams, nodes) result, err := jmespath.Search(queryParams, nodes)
@ -63,7 +65,7 @@ func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
return bytes, err return bytes, err
} }
func meshNodeToQueryNode(node mesh.MeshNode) *QueryNode { func MeshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode := new(QueryNode) queryNode := new(QueryNode)
queryNode.HostEndpoint = node.GetHostEndpoint() queryNode.HostEndpoint = node.GetHostEndpoint()
pubKey, _ := node.GetPublicKey() pubKey, _ := node.GetPublicKey()
@ -76,6 +78,9 @@ func meshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode.Timestamp = node.GetTimeStamp() queryNode.Timestamp = node.GetTimeStamp()
queryNode.Routes = node.GetRoutes() queryNode.Routes = node.GetRoutes()
queryNode.Description = node.GetDescription() queryNode.Description = node.GetDescription()
queryNode.Alias = node.GetAlias()
queryNode.Services = node.GetServices()
return queryNode return queryNode
} }

View File

@ -2,6 +2,7 @@ package robin
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
@ -10,6 +11,7 @@ import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/wgmesh/pkg/rpc"
) )
@ -117,6 +119,11 @@ func (n *IpcHandler) LeaveMesh(meshId string, reply *string) error {
func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error { func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
mesh := n.Server.GetMeshManager().GetMesh(meshId) mesh := n.Server.GetMeshManager().GetMesh(meshId)
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
meshSnapshot, err := mesh.GetMesh() meshSnapshot, err := mesh.GetMesh()
if err != nil { if err != nil {
@ -144,6 +151,9 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
WgHost: node.GetWgHost().String(), WgHost: node.GetWgHost().String(),
Timestamp: node.GetTimeStamp(), Timestamp: node.GetTimeStamp(),
Routes: node.GetRoutes(), Routes: node.GetRoutes(),
Description: node.GetDescription(),
Alias: node.GetAlias(),
Services: node.GetServices(),
} }
nodes[i] = node nodes[i] = node
@ -201,6 +211,60 @@ func (n *IpcHandler) PutDescription(description string, reply *string) error {
return nil return nil
} }
func (n *IpcHandler) PutAlias(alias string, reply *string) error {
err := n.Server.GetMeshManager().SetAlias(alias)
if err != nil {
return err
}
*reply = fmt.Sprintf("Set alias to %s", alias)
return nil
}
func (n *IpcHandler) PutService(service ipc.PutServiceArgs, reply *string) error {
err := n.Server.GetMeshManager().SetService(service.Service, service.Value)
if err != nil {
return err
}
*reply = "success"
return nil
}
func (n *IpcHandler) DeleteService(service string, reply *string) error {
err := n.Server.GetMeshManager().RemoveService(service)
if err != nil {
return err
}
*reply = "success"
return nil
}
func (n *IpcHandler) GetNode(args ipc.GetNodeArgs, reply *string) error {
node := n.Server.GetMeshManager().GetNode(args.MeshId, args.NodeId)
if node == nil {
*reply = "nil"
return nil
}
queryNode := query.MeshNodeToQueryNode(node)
bytes, err := json.Marshal(queryNode)
if err != nil {
*reply = err.Error()
return nil
}
*reply = string(bytes)
return nil
}
type RobinIpcParams struct { type RobinIpcParams struct {
CtrlServer ctrlserver.CtrlServer CtrlServer ctrlserver.CtrlServer
} }

View File

@ -1,7 +1,6 @@
package sync package sync
import ( import (
"errors"
"math/rand" "math/rand"
"sync" "sync"
"time" "time"
@ -30,35 +29,22 @@ type SyncerImpl struct {
// Sync: Sync random nodes // Sync: Sync random nodes
func (s *SyncerImpl) Sync(meshId string) error { func (s *SyncerImpl) Sync(meshId string) error {
logging.Log.WriteInfof("UPDATING WG CONF")
err := s.manager.ApplyConfig()
if err != nil {
logging.Log.WriteInfof("Failed to update config %w", err)
}
if !s.manager.HasChanges(meshId) && s.infectionCount == 0 { if !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
logging.Log.WriteInfof("No changes for %s", meshId) logging.Log.WriteInfof("No changes for %s", meshId)
return nil return nil
} }
theMesh := s.manager.GetMesh(meshId) logging.Log.WriteInfof("UPDATING WG CONF")
if theMesh == nil { if s.manager.HasChanges(meshId) {
return errors.New("the provided mesh does not exist") err := s.manager.ApplyConfig()
}
snapshot, err := theMesh.GetMesh()
if err != nil { if err != nil {
return err logging.Log.WriteInfof("Failed to update config %w", err)
}
} }
nodes := snapshot.GetNodes() nodeNames := s.manager.GetMesh(meshId).GetNodeIds()
if len(nodes) <= 1 {
return nil
}
self, err := s.manager.GetSelf(meshId) self, err := s.manager.GetSelf(meshId)
@ -66,17 +52,6 @@ func (s *SyncerImpl) Sync(meshId string) error {
return err return err
} }
excludedNodes := map[string]struct{}{
self.GetHostEndpoint(): {},
}
meshNodes := lib.MapValuesWithExclude(nodes, excludedNodes)
getNames := func(node mesh.MeshNode) string {
return node.GetHostEndpoint()
}
nodeNames := lib.Map(meshNodes, getNames)
neighbours := s.cluster.GetNeighbours(nodeNames, self.GetHostEndpoint()) neighbours := s.cluster.GetNeighbours(nodeNames, self.GetHostEndpoint())
randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate) randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
@ -86,7 +61,7 @@ func (s *SyncerImpl) Sync(meshId string) error {
before := time.Now() before := time.Now()
if len(meshNodes) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance { if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
logging.Log.WriteInfof("Sending to random cluster") logging.Log.WriteInfof("Sending to random cluster")
interCluster := s.cluster.GetInterCluster(nodeNames, self.GetHostEndpoint()) interCluster := s.cluster.GetInterCluster(nodeNames, self.GetHostEndpoint())
randomSubset = append(randomSubset, interCluster) randomSubset = append(randomSubset, interCluster)
@ -107,16 +82,20 @@ func (s *SyncerImpl) Sync(meshId string) error {
waitGroup.Wait() waitGroup.Wait()
s.syncCount++ s.syncCount++
logging.Log.WriteInfof("SYNC TIME: %v", time.Now().Sub(before)) logging.Log.WriteInfof("SYNC TIME: %v", time.Since(before))
logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount) logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount)
s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount) s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount)
// Check if any changes have occurred and trigger callbacks
// if changes have occurred.
// return s.manager.GetMonitor().Trigger()
return nil return nil
} }
// SyncMeshes: Sync all meshes // SyncMeshes: Sync all meshes
func (s *SyncerImpl) SyncMeshes() error { func (s *SyncerImpl) SyncMeshes() error {
for meshId, _ := range s.manager.GetMeshes() { for meshId := range s.manager.GetMeshes() {
err := s.Sync(meshId) err := s.Sync(meshId)
if err != nil { if err != nil {

1
smegmesh-web Submodule

Submodule smegmesh-web added at c1128bcd98