1
0
forked from extern/smegmesh

Compare commits

...

28 Commits

Author SHA1 Message Date
388153e706 Stubbing out WireGuard components
Stubbing our WireGuard components so that I can use docker/podman
network_mode=host. This is much more efficient than the docker/podman
userspace network.
2023-11-20 11:28:12 +00:00
023565d985 Merge pull request #17 from tim-beatham/25-ability-to-aliases
25 ability to aliases
2023-11-17 22:20:57 +00:00
36c264b38e 25-ability-aliases
Fixed unit tests failing
2023-11-17 22:18:53 +00:00
68db795f47 Ability to specify aliases
Ability to specify aliases that automatically append to /etc/hosts
2023-11-17 22:13:51 +00:00
f6160fe138 Adding aliases that automatically gets added 2023-11-17 19:13:20 +00:00
2c5289afb0 Merge pull request #16 from tim-beatham/15-add-rest-api
Developed a rest API
2023-11-15 12:57:05 +00:00
7199d07a76 Added smegmesh submodule 2023-11-13 10:46:52 +00:00
5f176e731f Developed a rest API 2023-11-13 10:44:14 +00:00
44f119b45c Updating examples 2023-11-08 09:19:24 +00:00
5215d5d54d Merge pull request #14 from tim-beatham/13-netlink-api
Removed interface manipulation via os.Exec into
2023-11-07 19:53:39 +00:00
1a864b7c80 Removed interface manipulation via os.Exec into
rtnetlink calls
2023-11-07 19:48:53 +00:00
4c19ebd81f Merge pull request #12 from tim-beatham/11-health-system
11 health system
2023-11-06 13:40:04 +00:00
acbeb689b5 Prune nodes if they exceed their timeout time 2023-11-06 13:37:28 +00:00
bc6cd4fdd5 Modified syncer 2023-11-06 10:05:23 +00:00
c88012cf71 Added health system to count how many times a node
fails to conenct.
2023-11-06 09:54:06 +00:00
4dc85f3861 Merge pull request #10 from tim-beatham/9-add-ci-support
9 add ci support
2023-11-05 18:07:52 +00:00
ef614f5961 Add cert dependencies 2023-11-05 18:06:24 +00:00
9454d62417 Adding stubs and writing tests 2023-11-05 18:03:58 +00:00
bb07d35dcb Unit testing the automerge library and lib functions 2023-11-05 12:13:40 +00:00
76dda2cf6f Update go.mod 2023-11-05 10:54:38 +00:00
1b286dd3c1 Update go.yml 2023-11-05 10:53:57 +00:00
2d45c2d298 Run go mod tidy in workflow 2023-11-05 10:51:24 +00:00
900c67a121 Update go.mod 2023-11-05 10:49:18 +00:00
b2fa08a642 Reverted go version 2023-11-05 10:48:35 +00:00
a4e9a5cd0f Updated go version in workflow 2023-11-05 10:47:10 +00:00
275eb423fb Create GitHub hosted test runner go.yml 2023-11-05 10:45:39 +00:00
d17dce3b1e Added clustering and clean up 2023-11-03 15:26:09 +00:00
e2c6db3a4f Merge pull request #8 from tim-beatham/7-create-rotating-window-of-connections
Implemented clustering betweeen nodes
2023-11-03 15:25:30 +00:00
70 changed files with 4273 additions and 415 deletions

31
.github/workflows/go.yml vendored Normal file
View File

@ -0,0 +1,31 @@
# This workflow will build a golang project
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go
name: Go
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: '1.21'
- name: Tidy
run: go mod tidy
- name: Build
run: go build -v ./...
- name: Test
run: go test -v ./...

3
.gitmodules vendored Normal file
View File

@ -0,0 +1,3 @@
[submodule "smegmesh-web"]
path = smegmesh-web
url = git@github.com:tim-beatham/smegmesh-web.git

11
Containerfile Normal file
View File

@ -0,0 +1,11 @@
FROM docker.io/library/golang:bookworm
COPY ./ /wgmesh
RUN apt-get update && apt-get install -y \
wireguard \
wireguard-tools \
iproute2 \
iputils-ping \
tmux \
vim
WORKDIR /wgmesh
RUN go build -o /usr/local/bin ./...

1
Dockerfile Symbolic link
View File

@ -0,0 +1 @@
Containerfile

17
cmd/api/main.go Normal file
View File

@ -0,0 +1,17 @@
package main
import (
"log"
"github.com/tim-beatham/wgmesh/pkg/api"
)
func main() {
apiServer, err := api.NewSmegServer()
if err != nil {
log.Fatal(err.Error())
}
apiServer.Run(":40000")
}

View File

@ -171,6 +171,68 @@ func putDescription(client *ipcRpc.Client, description string) {
fmt.Println(reply)
}
// putAlias: puts an alias for the node
func putAlias(client *ipcRpc.Client, alias string) {
var reply string
err := client.Call("IpcHandler.PutAlias", &alias, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func setService(client *ipcRpc.Client, service, value string) {
var reply string
serviceArgs := &ipc.PutServiceArgs{
Service: service,
Value: value,
}
err := client.Call("IpcHandler.PutService", serviceArgs, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func deleteService(client *ipcRpc.Client, service string) {
var reply string
err := client.Call("IpcHandler.PutService", &service, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func getNode(client *ipcRpc.Client, nodeId, meshId string) {
var reply string
args := &ipc.GetNodeArgs{
NodeId: nodeId,
MeshId: meshId,
}
err := client.Call("IpcHandler.GetNode", &args, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func main() {
parser := argparse.NewParser("wg-mesh",
"wg-mesh Manipulate WireGuard meshes")
@ -184,6 +246,10 @@ func main() {
leaveMeshCmd := parser.NewCommand("leave-mesh", "Leave a mesh network")
queryMeshCmd := parser.NewCommand("query-mesh", "Query a mesh network using JMESPath")
putDescriptionCmd := parser.NewCommand("put-description", "Place a description for the node")
putAliasCmd := parser.NewCommand("put-alias", "Place an alias for the node")
setServiceCmd := parser.NewCommand("set-service", "Place a service into your advertisements")
deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements")
getNodeCmd := parser.NewCommand("get-node", "Get a specific node from the mesh")
var newMeshIfName *string = newMeshCmd.String("f", "ifname", &argparse.Options{Required: true})
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
@ -195,8 +261,6 @@ func main() {
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{})
// var getMeshId *string = getMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true})
var getGraphMeshId *string = getGraphCmd.String("m", "mesh", &argparse.Options{Required: true})
@ -208,6 +272,16 @@ func main() {
var description *string = putDescriptionCmd.String("d", "description", &argparse.Options{Required: true})
var alias *string = putAliasCmd.String("a", "alias", &argparse.Options{Required: true})
var serviceKey *string = setServiceCmd.String("s", "service", &argparse.Options{Required: true})
var serviceValue *string = setServiceCmd.String("v", "value", &argparse.Options{Required: true})
var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{Required: true})
var getNodeNodeId *string = getNodeCmd.String("n", "nodeid", &argparse.Options{Required: true})
var getNodeMeshId *string = getNodeCmd.String("m", "meshid", &argparse.Options{Required: true})
err := parser.Parse(os.Args)
if err != nil {
@ -245,10 +319,6 @@ func main() {
}))
}
// if getMeshCmd.Happened() {
// getMesh(client, *getMeshId)
// }
if getGraphCmd.Happened() {
getGraph(client, *getGraphMeshId)
}
@ -268,4 +338,20 @@ func main() {
if putDescriptionCmd.Happened() {
putDescription(client, *description)
}
if putAliasCmd.Happened() {
putAlias(client, *alias)
}
if setServiceCmd.Happened() {
setService(client, *serviceKey, *serviceValue)
}
if deleteServiceCmd.Happened() {
deleteService(client, *deleteServiceKey)
}
if getNodeCmd.Happened() {
getNode(client, *getNodeNodeId, *getNodeMeshId)
}
}

View File

@ -2,6 +2,7 @@ certificatePath: "/wgmesh/cert/cert.pem"
privateKeyPath: "/wgmesh/cert/priv.pem"
caCertificatePath: "/wgmesh/cert/cacert.pem"
skipCertVerification: true
timeout: 5
gRPCPort: "21906"
advertiseRoutes: true
clusterSize: 32
@ -9,4 +10,5 @@ syncRate: 1
interClusterChance: 0.15
branchRate: 3
infectionCount: 3
keepAliveRate: 60
keepAliveTime: 10
pruneTime: 20

View File

@ -1,7 +1,8 @@
package main
import (
"log"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
@ -9,7 +10,7 @@ import (
ctrlserver "github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/middleware"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/robin"
"github.com/tim-beatham/wgmesh/pkg/sync"
"github.com/tim-beatham/wgmesh/pkg/timestamp"
@ -35,14 +36,18 @@ func main() {
return
}
if conf.Profile {
go func() {
http.ListenAndServe("localhost:6060", nil)
}()
}
var robinRpc robin.WgRpc
var robinIpc robin.IpcHandler
var authProvider middleware.AuthRpcProvider
var syncProvider sync.SyncServiceImpl
ctrlServerParams := ctrlserver.NewCtrlServerParams{
Conf: conf,
AuthProvider: &authProvider,
CtrlProvider: &robinRpc,
SyncProvider: &syncProvider,
Client: client,
@ -53,6 +58,7 @@ func main() {
syncRequester := sync.NewSyncRequester(ctrlServer)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester)
timestampScheduler := timestamp.NewTimestampScheduler(ctrlServer)
pruneScheduler := mesh.NewPruner(ctrlServer.MeshManager, *conf)
robinIpcParams := robin.RobinIpcParams{
CtrlServer: ctrlServer,
@ -66,11 +72,12 @@ func main() {
return
}
log.Println("Running IPC Handler")
logging.Log.WriteInfof("Running IPC Handler")
go ipc.RunIpcHandler(&robinIpc)
go syncScheduler.Run()
go timestampScheduler.Run()
go pruneScheduler.Run()
closeResources := func() {
logging.Log.WriteInfof("Closing resources")

View File

@ -0,0 +1,95 @@
version: '3'
networks:
net-1:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.0.0/17
net-2:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.155.0/17
services:
wg-1:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-2:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-3:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-4:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
sysctls:
- net.ipv6.conf.all.forwarding=1
networks:
- net-1
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-5:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-6:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-7:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"

View File

@ -0,0 +1,14 @@
certificatePath: "/wgmesh/cert/cert.pem"
privateKeyPath: "/wgmesh/cert/priv.pem"
caCertificatePath: "/wgmesh/cert/cacert.pem"
skipCertVerification: true
timeout: 5
gRPCPort: "21906"
advertiseRoutes: true
clusterSize: 32
syncRate: 1
interClusterChance: 0.15
branchRate: 3
infectionCount: 3
keepAliveTime: 10
pruneTime: 20

View File

@ -0,0 +1,42 @@
version: '3'
networks:
net-1:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.0.0/17
services:
wg-1:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-2:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-3:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"

View File

@ -0,0 +1,14 @@
certificatePath: "/wgmesh/cert/cert.pem"
privateKeyPath: "/wgmesh/cert/priv.pem"
caCertificatePath: "/wgmesh/cert/cacert.pem"
skipCertVerification: true
timeout: 5
gRPCPort: "21906"
advertiseRoutes: true
clusterSize: 32
syncRate: 1
interClusterChance: 0.15
branchRate: 3
infectionCount: 3
keepAliveTime: 10
pruneTime: 20

24
go.mod
View File

@ -1,13 +1,16 @@
module github.com/tim-beatham/wgmesh
go 1.21.1
go 1.21.3
require (
github.com/akamensky/argparse v1.4.0
github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9
github.com/gin-gonic/gin v1.9.1
github.com/google/uuid v1.3.0
github.com/jmespath/go-jmespath v0.4.0
github.com/jsimonetti/rtnetlink v1.3.5
github.com/sirupsen/logrus v1.9.3
golang.org/x/sys v0.14.0
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
google.golang.org/grpc v1.58.1
google.golang.org/protobuf v1.31.0
@ -15,16 +18,33 @@ require (
)
require (
github.com/bytedance/sonic v1.9.1 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.13.0 // indirect
golang.org/x/net v0.15.0 // indirect
golang.org/x/sync v0.3.0 // indirect
golang.org/x/sys v0.12.0 // indirect
golang.org/x/text v0.13.0 // indirect
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect

205
pkg/api/apiserver.go Normal file
View File

@ -0,0 +1,205 @@
package api
import (
"fmt"
"net/http"
ipcRpc "net/rpc"
"github.com/gin-gonic/gin"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log"
)
const SockAddr = "/tmp/wgmesh_ipc.sock"
type ApiServer interface {
GetMeshes(c *gin.Context)
Run(addr string) error
}
type SmegServer struct {
router *gin.Engine
client *ipcRpc.Client
}
func meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNode {
if meshNode.Routes == nil {
meshNode.Routes = make([]string, 0)
}
return &SmegNode{
WgHost: meshNode.WgHost,
WgEndpoint: meshNode.WgEndpoint,
Endpoint: meshNode.HostEndpoint,
Timestamp: int(meshNode.Timestamp),
Description: meshNode.Description,
Routes: meshNode.Routes,
PublicKey: meshNode.PublicKey,
Alias: meshNode.Alias,
Services: meshNode.Services,
}
}
func meshToAPIMesh(meshId string, nodes []ctrlserver.MeshNode) SmegMesh {
var smegMesh SmegMesh
smegMesh.MeshId = meshId
smegMesh.Nodes = make(map[string]SmegNode)
for _, node := range nodes {
smegMesh.Nodes[node.WgHost] = *meshNodeToAPIMeshNode(node)
}
return smegMesh
}
// CreateMesh: creates a mesh network
func (s *SmegServer) CreateMesh(c *gin.Context) {
var createMesh CreateMeshRequest
if err := c.ShouldBindJSON(&createMesh); err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
ipcRequest := ipc.NewMeshArgs{
IfName: createMesh.IfName,
WgPort: createMesh.WgPort,
}
var reply string
err := s.client.Call("IpcHandler.CreateMesh", &ipcRequest, &reply)
if err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
c.JSON(http.StatusOK, &gin.H{
"meshid": reply,
})
}
// JoinMesh: joins a mesh network
func (s *SmegServer) JoinMesh(c *gin.Context) {
var joinMesh JoinMeshRequest
if err := c.ShouldBindJSON(&joinMesh); err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
ipcRequest := ipc.JoinMeshArgs{
MeshId: joinMesh.MeshId,
IpAdress: joinMesh.Bootstrap,
IfName: joinMesh.IfName,
Port: joinMesh.WgPort,
}
var reply string
err := s.client.Call("IpcHandler.JoinMesh", &ipcRequest, &reply)
if err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
c.JSON(http.StatusOK, &gin.H{
"status": "success",
})
}
// GetMesh: given a meshId returns the corresponding mesh
// network.
func (s *SmegServer) GetMesh(c *gin.Context) {
meshidParam := c.Param("meshid")
var meshid string = meshidParam
getMeshReply := new(ipc.GetMeshReply)
err := s.client.Call("IpcHandler.GetMesh", &meshid, &getMeshReply)
if err != nil {
c.JSON(http.StatusNotFound,
&gin.H{
"error": fmt.Sprintf("could not find mesh %s", meshidParam),
})
return
}
mesh := meshToAPIMesh(meshidParam, getMeshReply.Nodes)
c.JSON(http.StatusOK, mesh)
}
func (s *SmegServer) GetMeshes(c *gin.Context) {
listMeshesReply := new(ipc.ListMeshReply)
err := s.client.Call("IpcHandler.ListMeshes", "", &listMeshesReply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
c.JSON(http.StatusBadRequest, nil)
return
}
meshes := make([]SmegMesh, 0)
for _, mesh := range listMeshesReply.Meshes {
getMeshReply := new(ipc.GetMeshReply)
err := s.client.Call("IpcHandler.GetMesh", &mesh, &getMeshReply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
c.JSON(http.StatusBadRequest, nil)
return
}
meshes = append(meshes, meshToAPIMesh(mesh, getMeshReply.Nodes))
}
c.JSON(http.StatusOK, meshes)
}
func (s *SmegServer) Run(addr string) error {
logging.Log.WriteInfof("Running API server")
return s.router.Run(addr)
}
func NewSmegServer() (ApiServer, error) {
client, err := ipcRpc.DialHTTP("unix", SockAddr)
if err != nil {
return nil, err
}
router := gin.Default()
router.Use(gin.LoggerWithConfig(gin.LoggerConfig{
Output: logging.Log.Writer(),
}))
smegServer := &SmegServer{
router: router,
client: client,
}
router.GET("/meshes", smegServer.GetMeshes)
router.GET("/mesh/:meshid", smegServer.GetMesh)
router.POST("/mesh/create", smegServer.CreateMesh)
router.POST("/mesh/join", smegServer.JoinMesh)
return smegServer, nil
}

30
pkg/api/types.go Normal file
View File

@ -0,0 +1,30 @@
package api
type SmegNode struct {
Alias string `json:"alias"`
WgHost string `json:"wgHost"`
WgEndpoint string `json:"wgEndpoint"`
Endpoint string `json:"endpoint"`
Timestamp int `json:"timestamp"`
Description string `json:"description"`
PublicKey string `json:"publicKey"`
Routes []string `json:"routes"`
Services map[string]string `json:"services"`
}
type SmegMesh struct {
MeshId string `json:"meshid"`
Nodes map[string]SmegNode `json:"nodes"`
}
type CreateMeshRequest struct {
IfName string `json:"ifName" binding:"required"`
WgPort int `json:"port" binding:"required,gte=1024,lt=65535"`
}
type JoinMeshRequest struct {
IfName string `json:"ifName" binding:"required"`
WgPort int `json:"port" binding:"required,gte=1024,lt=65535"`
Bootstrap string `json:"bootstrap" binding:"required"`
MeshId string `json:"meshid" binding:"required"`
}

View File

@ -18,13 +18,14 @@ import (
// CrdtMeshManager manages nodes in the crdt mesh
type CrdtMeshManager struct {
MeshId string
IfName string
NodeId string
Client *wgctrl.Client
doc *automerge.Doc
LastHash automerge.ChangeHash
conf *conf.WgMeshConfiguration
MeshId string
IfName string
Client *wgctrl.Client
doc *automerge.Doc
LastHash automerge.ChangeHash
conf *conf.WgMeshConfiguration
cache *MeshCrdt
lastCacheHash automerge.ChangeHash
}
func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
@ -34,15 +35,38 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
panic("node must be of type *MeshNodeCrdt")
}
crdt.Routes = make(map[string]interface{})
crdt.Services = make(map[string]string)
crdt.Timestamp = time.Now().Unix()
c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt)
nodeVal, _ := c.doc.Path("nodes").Map().Get(crdt.HostEndpoint)
nodeVal.Map().Set("routes", automerge.NewMap())
}
func (c *CrdtMeshManager) GetNodeIds() []string {
keys, _ := c.doc.Path("nodes").Map().Keys()
return keys
}
// GetMesh(): Converts the document into a struct
func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
return automerge.As[*MeshCrdt](c.doc.Root())
changes, err := c.doc.Changes(c.lastCacheHash)
if err != nil {
return nil, err
}
if c.cache == nil || len(changes) > 3 {
c.lastCacheHash = c.LastHash
cache, err := automerge.As[*MeshCrdt](c.doc.Root())
if err != nil {
return nil, err
}
c.cache = cache
}
return c.cache, nil
}
// GetMeshId returns the meshid of the mesh
@ -82,23 +106,23 @@ func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, erro
manager.IfName = params.DevName
manager.Client = params.Client
manager.conf = &params.Conf
manager.cache = nil
return &manager, nil
}
func (c *CrdtMeshManager) removeNode(endpoint string) error {
err := c.doc.Path("nodes").Map().Delete(endpoint)
if err != nil {
return err
}
return nil
// NodeExists: returns true if the node exists. Returns false
func (m *CrdtMeshManager) NodeExists(key string) bool {
node, err := m.doc.Path("nodes").Map().Get(key)
return node.Kind() == automerge.KindMap && err != nil
}
// GetNode: returns a mesh node crdt.Close releases resources used by a Client.
func (m *CrdtMeshManager) GetNode(endpoint string) (*MeshNodeCrdt, error) {
func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
node, err := m.doc.Path("nodes").Map().Get(endpoint)
if node.Kind() != automerge.KindMap {
return nil, fmt.Errorf("GetNode: something went wrong %s is not a map type")
}
if err != nil {
return nil, err
}
@ -176,7 +200,7 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro
}
if node.Kind() != automerge.KindMap {
return errors.New(fmt.Sprintf("%s does not exist", nodeId))
return fmt.Errorf("%s does not exist", nodeId)
}
err = node.Map().Set("description", description)
@ -188,6 +212,72 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro
return err
}
func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
return err
}
if node.Kind() != automerge.KindMap {
return fmt.Errorf("%s does not exist", nodeId)
}
err = node.Map().Set("alias", alias)
if err == nil {
logging.Log.WriteInfof("Updated Alias for %s to %s", nodeId, alias)
}
return err
}
func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return fmt.Errorf("AddService: node %s does not exist", nodeId)
}
service, err := node.Map().Get("services")
if err != nil {
return err
}
if service.Kind() != automerge.KindMap {
return fmt.Errorf("AddService: services property does not exist in node")
}
return service.Map().Set(key, value)
}
func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return fmt.Errorf("RemoveService: node %s does not exist", nodeId)
}
service, err := node.Map().Get("services")
if err != nil {
return err
}
if service.Kind() != automerge.KindMap {
return fmt.Errorf("services property does not exist")
}
err = service.Map().Delete(key)
if err != nil {
return fmt.Errorf("service %s does not exist", key)
}
return nil
}
// AddRoutes: adds routes to the specific nodeId
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -197,6 +287,10 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
return err
}
if nodeVal.Kind() != automerge.KindMap {
return fmt.Errorf("node does not exist")
}
routeMap, err := nodeVal.Map().Get("routes")
if err != nil {
@ -210,14 +304,90 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
return err
}
}
return nil
}
// DeleteRoutes deletes the specified routes
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
return err
}
if nodeVal.Kind() != automerge.KindMap {
return fmt.Errorf("node is not a map")
}
routeMap, err := nodeVal.Map().Get("routes")
if err != nil {
return err
}
for _, route := range routes {
err = routeMap.Map().Delete(route)
}
return err
}
func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer {
return NewAutomergeSync(m)
}
func (m *CrdtMeshManager) Prune(pruneTime int) error {
nodes, err := m.doc.Path("nodes").Get()
if err != nil {
return err
}
if nodes.Kind() != automerge.KindMap {
return errors.New("node must be a map")
}
values, err := nodes.Map().Values()
if err != nil {
return err
}
deletionNodes := make([]string, 0)
for nodeId, node := range values {
if node.Kind() != automerge.KindMap {
return errors.New("node must be a map")
}
nodeMap := node.Map()
timeStamp, err := nodeMap.Get("timestamp")
if err != nil {
return err
}
if timeStamp.Kind() != automerge.KindInt64 {
return errors.New("timestamp is not int64")
}
timeValue := timeStamp.Int64()
nowValue := time.Now().Unix()
if nowValue-timeValue >= int64(pruneTime) {
deletionNodes = append(deletionNodes, nodeId)
}
}
for _, node := range deletionNodes {
logging.Log.WriteInfof("Pruning %s", node)
nodes.Map().Delete(node)
}
return nil
}
func (m1 *MeshNodeCrdt) Compare(m2 *MeshNodeCrdt) int {
return strings.Compare(m1.PublicKey, m2.PublicKey)
}
@ -266,6 +436,20 @@ func (m *MeshNodeCrdt) GetIdentifier() string {
return strings.Join(constituents, ":")
}
func (m *MeshNodeCrdt) GetAlias() string {
return m.Alias
}
func (m *MeshNodeCrdt) GetServices() map[string]string {
services := make(map[string]string)
for key, service := range m.Services {
services[key] = service
}
return services
}
func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
nodes := make(map[string]mesh.MeshNode)
@ -278,6 +462,8 @@ func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
Timestamp: node.Timestamp,
Routes: node.Routes,
Description: node.Description,
Alias: node.Alias,
Services: node.GetServices(),
}
}

View File

@ -0,0 +1,366 @@
package crdt
import (
"slices"
"strings"
"testing"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type TestParams struct {
manager *CrdtMeshManager
}
func setUpTests() *TestParams {
manager, _ := NewCrdtNodeManager(&NewCrdtNodeMangerParams{
MeshId: "timsmesh123",
DevName: "wg0",
Port: 5000,
Client: nil,
Conf: conf.WgMeshConfiguration{},
})
return &TestParams{
manager: manager,
}
}
func getTestNode() mesh.MeshNode {
return &MeshNodeCrdt{
HostEndpoint: "public-endpoint:8080",
WgEndpoint: "public-endpoint:21906",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d218/128",
PublicKey: "AAAAAAAAAAAA",
Timestamp: time.Now().Unix(),
Description: "A node that we are adding",
}
}
func getTestNode2() mesh.MeshNode {
return &MeshNodeCrdt{
HostEndpoint: "public-endpoint:8081",
WgEndpoint: "public-endpoint:21907",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d219/128",
PublicKey: "BBBBBBBBB",
Timestamp: time.Now().Unix(),
Description: "A node that we are adding",
}
}
func TestAddNodeNodeExists(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getTestNode())
node, err := testParams.manager.GetNode("public-endpoint:8080")
if err != nil {
t.Error(err)
}
if node == nil {
t.Fatalf(`node not added to the mesh when it should be`)
}
}
func TestAddNodeAddRoute(t *testing.T) {
testParams := setUpTests()
testNode := getTestNode()
testParams.manager.AddNode(testNode)
testParams.manager.AddRoutes(testNode.GetHostEndpoint(), "fd:1c64:1d00::/48")
updatedNode, err := testParams.manager.GetNode(testNode.GetHostEndpoint())
if err != nil {
t.Error(err)
}
if updatedNode == nil {
t.Fatalf(`Node does not exist in the mesh`)
}
routes := updatedNode.GetRoutes()
if !slices.Contains(routes, "fd:1c64:1d00::/48") {
t.Fatal("Route node not added")
}
if len(routes) != 1 {
t.Fatal(`Route length mismatch`)
}
}
func TestGetMeshIdReturnsTheMeshId(t *testing.T) {
testParams := setUpTests()
if len(testParams.manager.GetMeshId()) == 0 {
t.Fatal(`Meshid is less than 0`)
}
}
// Add 2 nodes to the mesh and then get the mesh.s
// It should return the 2 nodes that have been added to the mesh
func TestAdd2NodesGetMesh(t *testing.T) {
testParams := setUpTests()
node1 := getTestNode()
node2 := getTestNode2()
testParams.manager.AddNode(node1)
testParams.manager.AddNode(node2)
mesh, err := testParams.manager.GetMesh()
if err != nil {
t.Error(err)
}
nodes := mesh.GetNodes()
if len(nodes) != 2 {
t.Fatalf(`Mismatch in node slice`)
}
for _, node := range nodes {
if node.GetHostEndpoint() != node1.GetHostEndpoint() && node.GetHostEndpoint() != node2.GetHostEndpoint() {
t.Fatalf(`Node should not exist`)
}
}
}
func TestSaveMeshReturnsMeshBytes(t *testing.T) {
testParams := setUpTests()
node1 := getTestNode()
testParams.manager.AddNode(node1)
bytes := testParams.manager.Save()
if len(bytes) <= 0 {
t.Fatalf(`bytes in the mesh is less than 0`)
}
}
func TestSaveMeshThenLoad(t *testing.T) {
testParams := setUpTests()
testParams2 := setUpTests()
node1 := getTestNode()
testParams.manager.AddNode(node1)
bytes := testParams.manager.Save()
err := testParams2.manager.Load(bytes)
if err != nil {
t.Error(err)
}
if len(bytes) <= 0 {
t.Fatalf(`bytes in the mesh is less than 0`)
}
mesh2, err := testParams2.manager.GetMesh()
if err != nil {
t.Error(err)
}
nodes := mesh2.GetNodes()
if lib.MapValues(nodes)[0].GetHostEndpoint() != node1.GetHostEndpoint() {
t.Fatalf(`Node should be in the list of nodes`)
}
}
func TestLengthNoNodes(t *testing.T) {
testParams := setUpTests()
if testParams.manager.Length() != 0 {
t.Fatalf(`Number of nodes should be 0`)
}
}
func TestLength1Node(t *testing.T) {
testParams := setUpTests()
node := getTestNode()
testParams.manager.AddNode(node)
if testParams.manager.Length() != 1 {
t.Fatalf(`Number of nodes should be 1`)
}
}
func TestLengthMultipleNodes(t *testing.T) {
testParams := setUpTests()
node := getTestNode()
node1 := getTestNode2()
testParams.manager.AddNode(node)
testParams.manager.AddNode(node1)
if testParams.manager.Length() != 2 {
t.Fatalf(`Number of nodes should be 2`)
}
}
func TestHasChangesNoChanges(t *testing.T) {
testParams := setUpTests()
if testParams.manager.HasChanges() {
t.Fatalf(`Should not have changes just created document`)
}
}
func TestHasChangesChanges(t *testing.T) {
testParams := setUpTests()
node := getTestNode()
testParams.manager.AddNode(node)
if !testParams.manager.HasChanges() {
t.Fatalf(`Should have changes just added node`)
}
}
func TestHasChangesSavedChanges(t *testing.T) {
testParams := setUpTests()
node := getTestNode()
testParams.manager.AddNode(node)
testParams.manager.SaveChanges()
if testParams.manager.HasChanges() {
t.Fatalf(`Should not have changes just saved document`)
}
}
func TestUpdateTimeStampNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.UpdateTimeStamp("AAAAAA")
if err == nil {
t.Fatalf(`Error should have returned`)
}
}
func TestUpdateTimeStampNodeExists(t *testing.T) {
testParams := setUpTests()
node := getTestNode()
testParams.manager.AddNode(node)
err := testParams.manager.UpdateTimeStamp(node.GetHostEndpoint())
if err != nil {
t.Error(err)
}
}
func TestSetDescriptionNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.SetDescription("AAAAA", "Bob 123")
if err == nil {
t.Fatalf(`Error should have returned`)
}
}
func TestSetDescriptionNodeExists(t *testing.T) {
testParams := setUpTests()
node := getTestNode()
err := testParams.manager.SetDescription(node.GetHostEndpoint(), "Bob 123")
if err == nil {
t.Fatalf(`Error should have returned`)
}
}
func TestAddRoutesNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.AddRoutes("AAAAA", "fd:1c64:1d00::/48")
if err == nil {
t.Error(err)
}
}
func TestCompareComparesByPublicKey(t *testing.T) {
node := getTestNode().(*MeshNodeCrdt)
node2 := getTestNode2().(*MeshNodeCrdt)
if node.Compare(node2) != -1 {
t.Fatalf(`node is alphabetically before node2`)
}
if node2.Compare(node) != 1 {
t.Fatalf(`node is alphabetical;y before node2`)
}
if node.Compare(node) != 0 {
t.Fatalf(`node is equal to node`)
}
}
func TestGetHostEndpoint(t *testing.T) {
node := getTestNode()
if (node.(*MeshNodeCrdt)).HostEndpoint != node.GetHostEndpoint() {
t.Fatalf(`get hostendpoint should get the host endpoint`)
}
}
func TestGetPublicKey(t *testing.T) {
key1, _ := wgtypes.GenerateKey()
node := getTestNode()
node.(*MeshNodeCrdt).PublicKey = key1.String()
pubKey, err := node.GetPublicKey()
if err != nil {
t.Error(err)
}
if pubKey.String() != key1.String() {
t.Fatalf(`Expected %s got %s`, key1.String(), pubKey.String())
}
}
func TestGetWgEndpoint(t *testing.T) {
node := getTestNode()
if node.(*MeshNodeCrdt).WgEndpoint != node.GetWgEndpoint() {
t.Fatal(`Did not return the correct wgEndpoint`)
}
}
func TestGetWgHost(t *testing.T) {
node := getTestNode()
ip := node.GetWgHost()
if node.(*MeshNodeCrdt).WgHost != ip.String() {
t.Fatal(`Did not parse WgHost correctly`)
}
}
func TestGetTimeStamp(t *testing.T) {
node := getTestNode()
if node.(*MeshNodeCrdt).Timestamp != node.GetTimeStamp() {
t.Fatal(`Did not return return the correct timestamp`)
}
}
func TestGetIdentifierDoesNotContainPrefix(t *testing.T) {
node := getTestNode()
if strings.Contains(node.GetIdentifier(), "/128") {
t.Fatal(`Identifier should not contain prefix`)
}
}

View File

@ -35,7 +35,9 @@ func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNod
WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()),
// Always set the routes as empty.
// Routes handled by external component
Routes: map[string]interface{}{},
Routes: map[string]interface{}{},
Description: "",
Alias: "",
}
}

View File

@ -8,7 +8,9 @@ type MeshNodeCrdt struct {
WgHost string `automerge:"wgHost"`
Timestamp int64 `automerge:"timestamp"`
Routes map[string]interface{} `automerge:"routes"`
Alias string `automerge:"alias"`
Description string `automerge:"description"`
Services map[string]string `automerge:"services"`
}
// MeshCrdt: Represents the mesh network as a whole

View File

@ -8,6 +8,14 @@ import (
"gopkg.in/yaml.v3"
)
type WgMeshConfigurationError struct {
msg string
}
func (m *WgMeshConfigurationError) Error() string {
return m.msg
}
type WgMeshConfiguration struct {
// CertificatePath is the path to the certificate to use in mTLS
CertificatePath string `yaml:"certificatePath"`
@ -24,13 +32,109 @@ type WgMeshConfiguration struct {
AdvertiseRoutes bool `yaml:"advertiseRoutes"`
// Endpoint is the IP in which this computer is publicly reachable.
// usecase is when the node has multiple IP addresses
Endpoint string `yaml:"publicEndpoint"`
ClusterSize int `yaml:"clusterSize"`
SyncRate float64 `yaml:"syncRate"`
Endpoint string `yaml:"publicEndpoint"`
// ClusterSize size of the cluster to split on
ClusterSize int `yaml:"clusterSize"`
// SyncRate number of times per second to perform a sync
SyncRate float64 `yaml:"syncRate"`
// InterClusterChance proability of inter-cluster communication in a sync round
InterClusterChance float64 `yaml:"interClusterChance"`
BranchRate int `yaml:"branchRate"`
InfectionCount int `yaml:"infectionCount"`
KeepAliveRate int `yaml:"keepAliveRate"`
// BranchRate number of nodes to randomly communicate with
BranchRate int `yaml:"branchRate"`
// InfectionCount number of times we sync before we can no longer catch the udpate
InfectionCount int `yaml:"infectionCount"`
// KeepAliveTime number of seconds before we update node indicating that we are still alive
KeepAliveTime int `yaml:"keepAliveTime"`
// Timeout number of seconds before we consider the node as dead
Timeout int `yaml:"timeout"`
// PruneTime number of seconds before we consider the 'node' as dead
PruneTime int `yaml:"pruneTime"`
// Profile whether or not to include a http server that profiles the code
Profile bool `yaml:"profile"`
// StubWg whether or not to stub the WireGuard types
StubWg bool `yaml:"stubWg"`
}
func ValidateConfiguration(c *WgMeshConfiguration) error {
if len(c.CertificatePath) == 0 {
return &WgMeshConfigurationError{
msg: "A public certificate must be specified for mTLS",
}
}
if len(c.PrivateKeyPath) == 0 {
return &WgMeshConfigurationError{
msg: "A private key must be specified for mTLS",
}
}
if len(c.CaCertificatePath) == 0 {
return &WgMeshConfigurationError{
msg: "A ca certificate must be specified for mTLS",
}
}
if len(c.GrpcPort) == 0 {
return &WgMeshConfigurationError{
msg: "A grpc port must be specified",
}
}
if c.ClusterSize <= 0 {
return &WgMeshConfigurationError{
msg: "A cluster size must not be 0",
}
}
if c.SyncRate <= 0 {
return &WgMeshConfigurationError{
msg: "SyncRate cannot be negative",
}
}
if c.BranchRate <= 0 {
return &WgMeshConfigurationError{
msg: "Branch rate cannot be negative",
}
}
if c.InfectionCount <= 0 {
return &WgMeshConfigurationError{
msg: "Infection count cannot be less than 1",
}
}
if c.KeepAliveTime <= 0 {
return &WgMeshConfigurationError{
msg: "KeepAliveRate cannot be less than negative",
}
}
if c.InterClusterChance <= 0 {
return &WgMeshConfigurationError{
msg: "Intercluster chance cannot be less than 0",
}
}
if c.Timeout < 1 {
return &WgMeshConfigurationError{
msg: "Timeout should be greater than or equal to 1",
}
}
if c.PruneTime <= 1 {
return &WgMeshConfigurationError{
msg: "Prune time cannot be <= 1",
}
}
if c.KeepAliveTime <= 1 {
return &WgMeshConfigurationError{
msg: "Prune time cannot be less than keep alive time",
}
}
return nil
}
// ParseConfiguration parses the mesh configuration
@ -51,5 +155,5 @@ func ParseConfiguration(filePath string) (*WgMeshConfiguration, error) {
return nil, err
}
return &conf, nil
return &conf, ValidateConfiguration(&conf)
}

165
pkg/conf/conf_test.go Normal file
View File

@ -0,0 +1,165 @@
package conf
import "testing"
func getExampleConfiguration() *WgMeshConfiguration {
return &WgMeshConfiguration{
CertificatePath: "./cert/cert.pem",
PrivateKeyPath: "./cert/key.pem",
CaCertificatePath: "./cert/ca.pems",
SkipCertVerification: true,
GrpcPort: "8080",
AdvertiseRoutes: true,
Endpoint: "localhost",
ClusterSize: 1,
SyncRate: 1,
InterClusterChance: 0.1,
BranchRate: 2,
KeepAliveTime: 4,
InfectionCount: 1,
Timeout: 2,
PruneTime: 20,
}
}
func TestConfigurationCertificatePathEmpty(t *testing.T) {
conf := getExampleConfiguration()
conf.CertificatePath = ""
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestConfigurationPrivateKeyPathEmpty(t *testing.T) {
conf := getExampleConfiguration()
conf.PrivateKeyPath = ""
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestConfigurationCaCertificatePathEmpty(t *testing.T) {
conf := getExampleConfiguration()
conf.CaCertificatePath = ""
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestConfigurationGrpcPortEmpty(t *testing.T) {
conf := getExampleConfiguration()
conf.GrpcPort = ""
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestClusterSizeZero(t *testing.T) {
conf := getExampleConfiguration()
conf.ClusterSize = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func SyncRateZero(t *testing.T) {
conf := getExampleConfiguration()
conf.SyncRate = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func BranchRateZero(t *testing.T) {
conf := getExampleConfiguration()
conf.BranchRate = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func InfectionCountZero(t *testing.T) {
conf := getExampleConfiguration()
conf.InfectionCount = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func KeepAliveRateZero(t *testing.T) {
conf := getExampleConfiguration()
conf.KeepAliveTime = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestValidCOnfiguration(t *testing.T) {
conf := getExampleConfiguration()
err := ValidateConfiguration(conf)
if err != nil {
t.Error(err)
}
}
func TestTimeout(t *testing.T) {
conf := getExampleConfiguration()
conf.Timeout = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestPruneTimeZero(t *testing.T) {
conf := getExampleConfiguration()
conf.PruneTime = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatalf(`Error should be thrown`)
}
}
func TestPruneTimeLessThanKeepAliveTime(t *testing.T) {
conf := getExampleConfiguration()
conf.PruneTime = 1
err := ValidateConfiguration(conf)
if err == nil {
t.Fatalf(`Error should be thrown`)
}
}

75
pkg/conn/cluster.go Normal file
View File

@ -0,0 +1,75 @@
package conn
import (
"errors"
"math"
"math/rand"
"slices"
)
// ConnCluster splits nodes into clusters where nodes in a cluster communicate
// frequently and nodes outside of a cluster communicate infrequently
type ConnCluster interface {
GetNeighbours(global []string, selfId string) []string
GetInterCluster(global []string, selfId string) string
}
type ConnClusterImpl struct {
clusterSize int
}
func binarySearch(global []string, selfId string, groupSize int) (int, int) {
slices.Sort(global)
lower := 0
higher := len(global) - 1
mid := (lower + higher) / 2
for (higher+1)-lower > groupSize {
if global[mid] < selfId {
lower = mid + 1
} else if global[mid] > selfId {
higher = mid - 1
} else {
break
}
mid = (lower + higher) / 2
}
return lower, int(math.Min(float64(lower+groupSize), float64(len(global))))
}
// GetNeighbours return the neighbours 'nearest' to you. In this implementation the
// neighbours aren't actually the ones nearest to you but just the ones nearest
// to you alphabetically. Perform binary search to get the total group
func (i *ConnClusterImpl) GetNeighbours(global []string, selfId string) []string {
slices.Sort(global)
lower, higher := binarySearch(global, selfId, i.clusterSize)
// slice the list to get the neighbours
return global[lower:higher]
}
// GetInterCluster get nodes not in your cluster. Every round there is a given chance
// you will communicate with a random node that is not in your cluster.
func (i *ConnClusterImpl) GetInterCluster(global []string, selfId string) string {
// Doesn't matter if not in it. Get index of where the node 'should' be
index, _ := binarySearch(global, selfId, 1)
numClusters := math.Ceil(float64(len(global)) / float64(i.clusterSize))
randomCluster := rand.Intn(int(numClusters)-1) + 1
neighbourIndex := (index + randomCluster) % len(global)
return global[neighbourIndex]
}
func NewConnCluster(clusterSize int) (ConnCluster, error) {
log2Cluster := math.Log2(float64(clusterSize))
if float64((log2Cluster))-log2Cluster != 0 {
return nil, errors.New("cluster must be a power of 2")
}
return &ConnClusterImpl{clusterSize: clusterSize}, nil
}

116
pkg/conn/cluster_test.go Normal file
View File

@ -0,0 +1,116 @@
package conn
import (
"math/rand"
"slices"
"testing"
)
func TestGetNeighboursClusterSizeTwo(t *testing.T) {
cluster := &ConnClusterImpl{
clusterSize: 2,
}
neighbours := []string{
"a",
"b",
"c",
"d",
}
result := cluster.GetNeighbours(neighbours, "b")
if len(result) != 2 {
t.Fatalf(`neighbour length should be 2`)
}
if result[0] != "a" && result[1] != "b" {
t.Fatalf(`Expected value b`)
}
}
func TestGetNeighboursGlobalListLessThanClusterSize(t *testing.T) {
cluster := &ConnClusterImpl{
clusterSize: 4,
}
neighbours := []string{
"a",
"b",
"c",
}
result := cluster.GetNeighbours(neighbours, "a")
if len(result) != 3 {
t.Fatalf(`neighbour length should be 3`)
}
slices.Sort(result)
if !slices.Equal(result, neighbours) {
t.Fatalf(`Cluster and neighbours should be equal`)
}
}
func TestGetNeighboursClusterSize4(t *testing.T) {
cluster := &ConnClusterImpl{
clusterSize: 4,
}
neighbours := []string{
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
"k", "l", "m", "n", "o",
}
result := cluster.GetNeighbours(neighbours, "k")
if len(result) != 4 {
t.Fatalf(`cluster size must be 4`)
}
slices.Sort(result)
if !slices.Equal(neighbours[8:12], result) {
t.Fatalf(`Cluster should be i, j, k, l`)
}
}
func TestGetNeighboursClusterSize4OneReturned(t *testing.T) {
cluster := &ConnClusterImpl{
clusterSize: 4,
}
neighbours := []string{
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
"k", "l", "m", "n", "o",
}
result := cluster.GetNeighbours(neighbours, "o")
if len(result) != 3 {
t.Fatalf(`Cluster should be of length 3`)
}
if !slices.Equal(neighbours[12:15], result) {
t.Fatalf(`Cluster should be m, n, o`)
}
}
func TestInterClusterNotInCluster(t *testing.T) {
rand.Seed(1)
cluster := &ConnClusterImpl{
clusterSize: 4,
}
global := []string{
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
"k", "l", "m", "n", "o",
}
neighbours := cluster.GetNeighbours(global, "c")
interCluster := cluster.GetInterCluster(global, "c")
if slices.Contains(neighbours, interCluster) {
t.Fatalf(`intercluster cannot be in your cluster`)
}
}

View File

@ -18,6 +18,8 @@ type PeerConnection interface {
GetClient() (*grpc.ClientConn, error)
}
type PeerConnectionFactory = func(clientConfig *tls.Config, server string) (PeerConnection, error)
// WgCtrlConnection implements PeerConnection.
type WgCtrlConnection struct {
clientConfig *tls.Config
@ -26,12 +28,12 @@ type WgCtrlConnection struct {
}
// NewWgCtrlConnection creates a new instance of a WireGuard control connection
func NewWgCtrlConnection(clientConfig *tls.Config, server string) (*WgCtrlConnection, error) {
func NewWgCtrlConnection(clientConfig *tls.Config, server string) (PeerConnection, error) {
var conn WgCtrlConnection
conn.clientConfig = clientConfig
conn.endpoint = server
if err := conn.createGrpcConn(); err != nil {
if err := conn.CreateGrpcConnection(); err != nil {
return nil, err
}
@ -39,7 +41,7 @@ func NewWgCtrlConnection(clientConfig *tls.Config, server string) (*WgCtrlConnec
}
// ConnectWithToken: Connects to a new gRPC peer given the address of the other server.
func (c *WgCtrlConnection) createGrpcConn() error {
func (c *WgCtrlConnection) CreateGrpcConnection() error {
conn, err := grpc.Dial(c.endpoint,
grpc.WithTransportCredentials(credentials.NewTLS(c.clientConfig)))
@ -62,7 +64,7 @@ func (c *WgCtrlConnection) GetClient() (*grpc.ClientConn, error) {
var err error = nil
if c.conn == nil {
err = errors.New("The client's config does not exist")
err = errors.New("the client's config does not exist")
}
return c.conn, err

View File

@ -34,10 +34,11 @@ type ConnectionManagerImpl struct {
clientConnections map[string]PeerConnection
serverConfig *tls.Config
clientConfig *tls.Config
connFactory PeerConnectionFactory
}
// Create a new instance of a connection manager.
type NewConnectionManageParams struct {
type NewConnectionManagerParams struct {
// The path to the certificate
CertificatePath string
// The private key of the node
@ -45,11 +46,12 @@ type NewConnectionManageParams struct {
// Whether or not to skip certificate verification
SkipCertVerification bool
CaCert string
ConnFactory PeerConnectionFactory
}
// NewConnectionManager: Creates a new instance of a ConnectionManager or an error
// if something went wrong.
func NewConnectionManager(params *NewConnectionManageParams) (ConnectionManager, error) {
func NewConnectionManager(params *NewConnectionManagerParams) (ConnectionManager, error) {
cert, err := tls.LoadX509KeyPair(params.CertificatePath, params.PrivateKey)
if err != nil {
@ -94,10 +96,12 @@ func NewConnectionManager(params *NewConnectionManageParams) (ConnectionManager,
}
connections := make(map[string]PeerConnection)
connMgr := ConnectionManagerImpl{sync.RWMutex{},
connMgr := ConnectionManagerImpl{
sync.RWMutex{},
connections,
serverConfig,
clientConfig,
params.ConnFactory,
}
return &connMgr, nil
@ -127,7 +131,7 @@ func (m *ConnectionManagerImpl) AddConnection(endPoint string) (PeerConnection,
return conn, nil
}
connections, err := NewWgCtrlConnection(m.clientConfig, endPoint)
connections, err := m.connFactory(m.clientConfig, endPoint)
if err != nil {
return nil, err

View File

@ -0,0 +1,145 @@
package conn
import (
"crypto/tls"
"errors"
"log"
"testing"
)
func getConnectionManagerParams() *NewConnectionManagerParams {
return &NewConnectionManagerParams{
CertificatePath: "./test/cert.pem",
PrivateKey: "./test/priv.pem",
CaCert: "./test/cacert.pem",
SkipCertVerification: false,
ConnFactory: MockFactory,
}
}
func TestNewConnectionManagerCertificatePathDoesNotExist(t *testing.T) {
params := getConnectionManagerParams()
params.CertificatePath = "./cert/sdfjdskjdsjkd.pem"
_, err := NewConnectionManager(params)
if err == nil {
t.Fatalf(`Expected error as certificate does not exist`)
}
}
func TestNewConnectionManagerPrivateKeyDoesNotExist(t *testing.T) {
params := getConnectionManagerParams()
params.PrivateKey = "./cert/sdjdjdks.pem"
_, err := NewConnectionManager(params)
if err == nil {
t.Fatalf(`Expected error as private key does not exist`)
}
}
func TestNewConnectionManagerCACertDoesNotExistAndVerify(t *testing.T) {
params := getConnectionManagerParams()
params.CaCert = "./cert/sdjdsjdksjdks.pem"
params.SkipCertVerification = false
_, err := NewConnectionManager(params)
if err == nil {
t.Fatal(`Expected error as ca cert does not exist and skip is false`)
}
}
func TestNewConnectionManagerCACertDoesNotExistAndNotVerify(t *testing.T) {
params := getConnectionManagerParams()
params.CaCert = ""
params.SkipCertVerification = true
_, err := NewConnectionManager(params)
if err != nil {
t.Fatal(`an error should not be thrown`)
}
}
func TestGetConnectionConnectionDoesNotExistAddsConnection(t *testing.T) {
params := getConnectionManagerParams()
m, _ := NewConnectionManager(params)
conn, err := m.GetConnection("abc-123.com")
if err != nil {
t.Error(err)
}
if conn == nil {
t.Fatal(`the connection should not be nil`)
}
conn2, _ := m.GetConnection("abc-123.com")
if conn != conn2 {
log.Fatalf(`should return the same connection instance`)
}
}
func TestAddConnectionThrowsAnErrorIfFactoryThrowsError(t *testing.T) {
params := getConnectionManagerParams()
params.ConnFactory = func(clientConfig *tls.Config, server string) (PeerConnection, error) {
return nil, errors.New("this is an error")
}
m, _ := NewConnectionManager(params)
_, err := m.AddConnection("abc-123.com")
if err == nil || err.Error() != "this is an error" {
t.Error(err)
}
}
func TestAddConnectionConnectionDoesNotExist(t *testing.T) {
params := getConnectionManagerParams()
m, _ := NewConnectionManager(params)
conn, err := m.AddConnection("abc-123.com")
if err != nil {
t.Error(err)
}
if conn == nil {
t.Fatal(`connection should not be nil`)
}
conn1, _ := m.GetConnection("abc-123.com")
if conn != conn1 {
t.Fatal(`underlying connections should be the same`)
}
}
func TestHasConnectionConnectionDoesNotExist(t *testing.T) {
params := getConnectionManagerParams()
m, _ := NewConnectionManager(params)
if m.HasConnection("abc-123.com") {
t.Fatal(`should return that the connection does not exist`)
}
}
func TestHasConnectionConnectionExists(t *testing.T) {
params := getConnectionManagerParams()
m, _ := NewConnectionManager(params)
m.AddConnection("abc-123.com")
if !m.HasConnection("abc-123.com") {
t.Fatal(`should return that the connection exists`)
}
}

View File

@ -16,9 +16,7 @@ type ConnectionServer struct {
// tlsConfiguration of the server
serverConfig *tls.Config
// server an instance of the grpc server
server *grpc.Server
// the authentication service to authenticate nodes
authProvider rpc.AuthenticationServer
server *grpc.Server // the authentication service to authenticate nodes
// the ctrl service to manage node
ctrlProvider rpc.MeshCtrlServerServer
// the sync service to synchronise nodes
@ -30,7 +28,6 @@ type ConnectionServer struct {
// NewConnectionServerParams contains params for creating a new connection server
type NewConnectionServerParams struct {
Conf *conf.WgMeshConfiguration
AuthProvider rpc.AuthenticationServer
CtrlProvider rpc.MeshCtrlServerServer
SyncProvider rpc.SyncServiceServer
}
@ -59,14 +56,12 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
grpc.Creds(credentials.NewTLS(serverConfig)),
)
authProvider := params.AuthProvider
ctrlProvider := params.CtrlProvider
syncProvider := params.SyncProvider
connServer := ConnectionServer{
serverConfig: serverConfig,
server: server,
authProvider: authProvider,
ctrlProvider: ctrlProvider,
syncProvider: syncProvider,
Conf: params.Conf,
@ -78,7 +73,6 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
// Listen for incoming requests. Returns an error if something went wrong.
func (s *ConnectionServer) Listen() error {
rpc.RegisterMeshCtrlServerServer(s.server, s.ctrlProvider)
rpc.RegisterAuthenticationServer(s.server, s.authProvider)
rpc.RegisterSyncServiceServer(s.server, s.syncProvider)

51
pkg/conn/stub.go Normal file
View File

@ -0,0 +1,51 @@
package conn
import (
"crypto/tls"
"google.golang.org/grpc"
)
type ConnectionManagerStub struct {
Endpoints map[string]PeerConnection
}
func (s *ConnectionManagerStub) AddConnection(endPoint string) (PeerConnection, error) {
mock := &PeerConnectionMock{}
s.Endpoints[endPoint] = mock
return mock, nil
}
func (s *ConnectionManagerStub) GetConnection(endPoint string) (PeerConnection, error) {
endpoint, ok := s.Endpoints[endPoint]
if !ok {
return s.AddConnection(endPoint)
}
return endpoint, nil
}
func (s *ConnectionManagerStub) HasConnection(endPoint string) bool {
_, ok := s.Endpoints[endPoint]
return ok
}
func (s *ConnectionManagerStub) Close() error {
return nil
}
type PeerConnectionMock struct {
}
func (c *PeerConnectionMock) Close() error {
return nil
}
func (c *PeerConnectionMock) GetClient() (*grpc.ClientConn, error) {
return &grpc.ClientConn{}, nil
}
var MockFactory PeerConnectionFactory = func(clientConfig *tls.Config, server string) (PeerConnection, error) {
return &PeerConnectionMock{}, nil
}

21
pkg/conn/test/cacert.pem Normal file
View File

@ -0,0 +1,21 @@
-----BEGIN CERTIFICATE-----
MIIDazCCAlOgAwIBAgIUDRIRI8UnHU2a4znsun0gxFwlrFQwDQYJKoZIhvcNAQEL
BQAwRTELMAkGA1UEBhMCR0IxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMzEwMjcxNTIzMDZaFw0yNDEw
MjYxNTIzMDZaMEUxCzAJBgNVBAYTAkdCMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw
HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB
AQUAA4IBDwAwggEKAoIBAQDJ5hOmzilimA/zM5hYP7CQf4iRmICtSbVLgt6/rTDP
p3JsGGQWZ4pZNofzGnGa7aEMoXS2Ztl7GzZbr1p4+rd6MBbVt8XZ/hP+X4zasCXi
/YubG0TYyBuAt+JrcYb0cbsTBkMXXnFcNIXDfeYFsNq+pfyJwq2ElMUUZ6SQmVhH
ovn1Wk9Fv4t2GJMhmUcObrSIoYdgo4Vf9CfQnn0PCaRf+RjspY/Kz33oyqDI6xJx
I0rfJR7f9B6ZKosfAkt4oTTfT9P8w/d1I95oBENhDkalgkdJCuNJ/AwKGxZrYf/P
aefcc91HheauObjBYPFrSn6bUj3LMJEfj4IeBK+fOZCfAgMBAAGjUzBRMB0GA1Ud
DgQWBBSpcF7jtpd9n73VM3xhPmI1GMEkFjAfBgNVHSMEGDAWgBSpcF7jtpd9n73V
M3xhPmI1GMEkFjAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQCK
GplAveP9nVo9zmg+/mkDpyVoo5rp64oJh4DFtm6X+EI31FmH6Cb71Kn2ZzXhQvSq
qrP7+VoGeBDxk4guJtAs/fhnuDupJG2SjsctjiFnDbSrJjWJjGhC0kuL0wcjLU5G
qUpCEJu13GkDlYHKKw0z+oLUOw+OHmvE5/sD23sKl2KxBWKItx0hwSCkGtm0RQld
8mfjOsHqJ2V/FOcHK6X2DSV1728PAhu4l/PRSB0drBA+7kdeCuWIRZw5RA/OyxvU
CuC5dfUh75MrK7KL6sZsXklsoXo8BZp4rRRUt/v1D3r/SMBJPULSGXh6QDjXQX1D
km71c3DEDyKznHTpGxPt
-----END CERTIFICATE-----

View File

@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDJ5hOmzilimA/z
M5hYP7CQf4iRmICtSbVLgt6/rTDPp3JsGGQWZ4pZNofzGnGa7aEMoXS2Ztl7GzZb
r1p4+rd6MBbVt8XZ/hP+X4zasCXi/YubG0TYyBuAt+JrcYb0cbsTBkMXXnFcNIXD
feYFsNq+pfyJwq2ElMUUZ6SQmVhHovn1Wk9Fv4t2GJMhmUcObrSIoYdgo4Vf9CfQ
nn0PCaRf+RjspY/Kz33oyqDI6xJxI0rfJR7f9B6ZKosfAkt4oTTfT9P8w/d1I95o
BENhDkalgkdJCuNJ/AwKGxZrYf/Paefcc91HheauObjBYPFrSn6bUj3LMJEfj4Ie
BK+fOZCfAgMBAAECggEADqAjoUxC9Dj2wtPkf9QRSs5qSr3E6Iiz4OX4k+MMa6aC
I/F6YqMagw7vtz0dqK75ISybA1GdBI16mRaxU5056FiOdunqo7mDokQytG7ZN8HN
OK23hYqtb1wiw0zEjXWlqyGjf5BgXuERJZG7tYLTvcbRbftTzYxnYGyHn8/z9LBp
GsTJ5X8XMLM5+bTvg1Ovv5s0q31FCeqAuw+auHH4pBNP+ylV6dF5XOWq4HO3TJ2b
grHxWB94JZChZnDC/K+HxQ6aHJfbZ5XCoXfIaIVkoXfnyPzgjvgK+/IpHEF8f/3I
uT/NBiArTpRl29pX5flEO4R121VaW93eM1tuzL32VQKBgQD6Trctx9SYuhzgfiO7
kdefvR43Kl9SFyEw3hN3HW1cxSNGCCFotjmdem+QdtMBtUd27UJ9tuiKJC0lcCER
t3WRz4kVd/cb0eC1DPzpGHA81o1rUUR3nMr1o7aBfvQ06VAxFUrFAOPpF8nD7tI4
0CiOh7/sL1ElThA3bOPUpXkYHQKBgQDOfYbP8dppIkC8pRTnHWe0qUY0G4YXxg7r
UtTo4GYOLJeKH/MKoK8MjBDS5VN5n5TAHJ8yUVzhpWXZIPIGzNEhIRDMa56sRPgI
9mLJNs5z/ZIxd/7ZQbDHrD4T3PKeTjzVUtjXrhLowokPlPB/RMQL6ZT+qMao+3bS
fDITSfLG6wKBgBpbcZSDh1JxvpqxDagxqkfqzSS39IObZeZUbC5NzfdH1vgH4SS6
k4SOoPLQYFW8tgLC5w5/1Sq+tnZLwV+xNtMczG2TTVUDm6rU7EjLRv5RBWE4lIIX
45NMIuqt6J8ttkEE4fOurVEdLSTRoBdVa//eMYp4TQ4lkzWS5Ma+ierNAoGAYO3z
1rFFQYzerq8ffM4E3H2JgvRYodhLMJQVdavAvG6aRDBzOk3rXgxx6U3VPYZ3oSbO
ZCRlYVbu1FnuwtpqYQ7Qf+UU+vD1Ld/ax3F+wFwLwET/0KRRg6mLCm/xQ/ad/9WA
DN6d6b1H8ZSMwHFbRexEELbRaomAYZYDO6K+4DkCgYEAv5De85hPnWtAvKhPzwQi
9mtyWo/cfQgtwL8IKNu6hBHl5RXDpPgX/+pNbXLJfBPwVR3H62x1CMYJDkWVuE6/
ZjtF7FSucZMz/mR6r1GhSOXy3YLwQ6JLPjjKzvnEjahGlKwALJNL0O2ZucjsZxHE
PM4rmhRZT9opiapiltEhRm0=
-----END PRIVATE KEY-----

19
pkg/conn/test/cert.pem Normal file
View File

@ -0,0 +1,19 @@
-----BEGIN CERTIFICATE-----
MIIDCjCCAfICFB/Vd2eOXWdNdrakThJhFIRtZmhUMA0GCSqGSIb3DQEBCwUAMEUx
CzAJBgNVBAYTAkdCMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRl
cm5ldCBXaWRnaXRzIFB0eSBMdGQwHhcNMjMxMDI3MTUzNDM1WhcNMjMxMTI2MTUz
NDM1WjA+MQswCQYDVQQGEwJHQjENMAsGA1UECAwERmlmZTENMAsGA1UEBwwEY2l0
eTERMA8GA1UECgwITWVzaCBMdGQwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEK
AoIBAQDVgcLtNU5AYfPML/mE5PyC7YYKvZn2mt6vEiJ7M/6EzYeTXFeYexD5ZqHg
ewGEd1fwiQWQsATsWd+EM4OnCAXAaNiOH6gGY7FR8CThfT+k8yIGPrl1BovzHHYS
Orekna17UFeIyFMHDPIjl4d2WiJPvmNn5PhLEppPHPBWPhl3J3sMrSbqyRuYbtta
oFIzN8mFcikixLg0SnBPtwlLC72ah9G+MF5CwEcU/E0bYbLQZXv+WhG5aw5JEzes
K2GLxVNgM0xXB7hSyLoX1wBc8DdQyLCMkOp55Hl04UKTxtVE82MiuAOVqMUuKFjR
u2a1C+/Gbk/PS5SHgenGjdZ8sZGpAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAHMc
jIFG5Rn9KaVmo7E+/UAq+3ld/3y2yMHg5wq7oG8b7/z0mlSGErHdFMzo75AFLN4r
kOuiF5ItF6dRLNrG8IUFSNMGVH3b3ukw1EI8E89L8ak3CM+wpLT6GVP3BfV8ah+X
4RRix40Tmx4C81l+Lf5W10rHIdlXBCanJy/Fa0ae+S+oXFc9jeXHlK9qlgszrECT
Pa3VCR95LAIc6o9pDL2Z8tpEkSbyzvIWhp53fnC80PyXpSsFMfIw657shagBc/Ov
e7/aPpPf3V3CafJlEIraQp24MDI5ZM59lT5vhRq2AC50gelL6UPV16mVVUlGVhWE
vYyejod5i5ZbuLFOy2g=
-----END CERTIFICATE-----

28
pkg/conn/test/priv.pem Normal file
View File

@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDVgcLtNU5AYfPM
L/mE5PyC7YYKvZn2mt6vEiJ7M/6EzYeTXFeYexD5ZqHgewGEd1fwiQWQsATsWd+E
M4OnCAXAaNiOH6gGY7FR8CThfT+k8yIGPrl1BovzHHYSOrekna17UFeIyFMHDPIj
l4d2WiJPvmNn5PhLEppPHPBWPhl3J3sMrSbqyRuYbttaoFIzN8mFcikixLg0SnBP
twlLC72ah9G+MF5CwEcU/E0bYbLQZXv+WhG5aw5JEzesK2GLxVNgM0xXB7hSyLoX
1wBc8DdQyLCMkOp55Hl04UKTxtVE82MiuAOVqMUuKFjRu2a1C+/Gbk/PS5SHgenG
jdZ8sZGpAgMBAAECggEARJNAggLYhtpPPVp9WJ9ZsU3L+0AppujYND/tXkf1bD89
V+nVYq7IZWp+/MRVWPAiCSphZLb8ZdN59JK9KtVrT4D9aSymwaKcjfZFSj15xyem
Wn4j///hzGxsSe+dE1znnw9PhindbQrN7Pua8TsDATzj3bdPvoETmexwDysz765i
u4zXvxP+xAessz1OYa5IUaDXdlWOf0e1zNXWwanjRggzCeWR3lTofG49GX087oVC
Sb9ASy+AScnOlwpTdQ8sKy1r9gXmE5ey4AULVb0nJ8LDvrCoBBhKBtVE5mJHepE6
bdC9l6poL6roGvHfMAo3SmiUUT5XceqUxBtHcyHX3wKBgQD1uh+Dv0PrH3CTW9cF
bwHL1rmQNJrbDzDAaounGBe9mcot1RrBhyQAoGw1no4c+QWDAwYRuBP2+Rp6JLU/
XnEXSyN85rJN6LajlrLEr+BNmKw6ghNsnAFUZBLaJ7epRi6OjACUwmtvH6hRIef8
aMg4WiOyDT+Z4Xe81pdXb91HXwKBgQDebs3idgVEau3LCKGYnqvmUhzv8iiQiJmD
R29o2G5Xrahf3r1O5gJdGLO1DaCBtdrI7J4xUOlM935KaEYFe5B7RVGXg23tNWgb
2M+YQqu5qz61bDxhg7dGkegHrdvKNcSkV6GUSm5w9rdxJlY8+l45p/7QpSkatcbd
IRiVzMNr9wKBgQC/+Z5fbpFgYxqvdaPicdxkZShqOj71f8OlwFfEvrTlgv4KmqAh
rDP7bVm89leu2PpuZXFbbIXkgK8n1//mNyGBgkmCbjXFWlc+LSETOxixZuK/fxov
0x3S0bBM0ZTSYatD4KsfjVkj4wa8BBJbB33NUNbsZx9WWGkUlk58mD+3XwKBgQDV
mgR+n6WJQUIfwqckH+Ol517AkYSg33zEE9qKDaVQ74QMpKKY3MqSSkFw8agcR93V
K1zysOeJsPYHUEFFzJY/up6S6HSs4aebbkZUylmMkEVFBa6qWkmrLDxs+2lgsuem
hjy1YhDSzCn3L8CLCEdqCMjr5l8ltkBFZB3u5NcZmwKBgHE9ODedQm783JfvDNBb
lB/IoUjMhMR0J2vHC3zxgTU4nIK+MR0vXvA7fmZebpaQNwYrHY9gvrL0/QevOrmG
PtXlkQ9GITMxTlqfHWV5jXZuRBIGTqh1QW3tKbVAhUhNlM0XDNBmBvjKIFjxUIo3
zMRw/o4R4cIaazyVxguZbsa2
-----END PRIVATE KEY-----

84
pkg/conn/window.go Normal file
View File

@ -0,0 +1,84 @@
package conn
import (
"errors"
"slices"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
// ConnectionWindow maintains a sliding window of connections between users
type ConnectionWindow interface {
// GetWindow is a list of connections to choose from
GetWindow() []string
// SlideConnection removes a node from the window and adds a random node
// not already in the window. connList represents the list of possible
// connections to choose from
SlideConnection(connList []string) error
// PushConneciton is used when connection list less than window size.
PutConnection(conn []string) error
// IsFull returns true if the window is full. In which case we must slide the window
IsFull() bool
}
type ConnectionWindowImpl struct {
window []string
windowSize int
}
// GetWindow gets the current list of active connections in
// the window
func (c *ConnectionWindowImpl) GetWindow() []string {
return c.window
}
// SlideConnection slides the connection window by one shuffling items
// in the windows
func (c *ConnectionWindowImpl) SlideConnection(connList []string) error {
// If the number of peer connections is less than the length of the window
// then exit early. Can't slide the window it should contain all nodes!
if len(c.window) < c.windowSize {
return nil
}
filter := func(node string) bool {
return !slices.Contains(c.window, node)
}
pool := lib.Filter(connList, filter)
newNode := lib.RandomSubsetOfLength(pool, 1)
if len(newNode) == 0 {
return errors.New("could not slide window")
}
for i := len(c.window) - 1; i >= 1; i-- {
c.window[i] = c.window[i-1]
}
c.window[0] = newNode[0]
return nil
}
// PutConnection put random connections in the connection
func (c *ConnectionWindowImpl) PutConnection(connList []string) error {
if len(c.window) >= c.windowSize {
return errors.New("cannot place connection. Window full need to slide")
}
c.window = lib.RandomSubsetOfLength(connList, c.windowSize)
return nil
}
func (c *ConnectionWindowImpl) IsFull() bool {
return len(c.window) >= c.windowSize
}
func NewConnectionWindow(windowLength int) ConnectionWindow {
window := &ConnectionWindowImpl{
window: make([]string, 0),
windowSize: windowLength,
}
return window
}

View File

@ -18,7 +18,6 @@ import (
type NewCtrlServerParams struct {
Conf *conf.WgMeshConfiguration
Client *wgctrl.Client
AuthProvider rpc.AuthenticationServer
CtrlProvider rpc.MeshCtrlServerServer
SyncProvider rpc.SyncServiceServer
Querier query.Querier
@ -36,6 +35,8 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
ipAllocator := &ip.ULABuilder{}
interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client)
configApplyer := mesh.NewWgMeshConfigApplyer()
meshManagerParams := &mesh.NewMeshManagerParams{
Conf: *params.Conf,
Client: params.Client,
@ -44,16 +45,19 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
IdGenerator: idGenerator,
IPAllocator: ipAllocator,
InterfaceManipulator: interfaceManipulator,
ConfigApplyer: configApplyer,
}
ctrlServer.MeshManager = mesh.NewMeshManager(meshManagerParams)
configApplyer.SetMeshManager(ctrlServer.MeshManager)
ctrlServer.Conf = params.Conf
connManagerParams := conn.NewConnectionManageParams{
connManagerParams := conn.NewConnectionManagerParams{
CertificatePath: params.Conf.CertificatePath,
PrivateKey: params.Conf.PrivateKeyPath,
SkipCertVerification: params.Conf.SkipCertVerification,
CaCert: params.Conf.CaCertificatePath,
ConnFactory: conn.NewWgCtrlConnection,
}
connMgr, err := conn.NewConnectionManager(&connManagerParams)
@ -65,7 +69,6 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
ctrlServer.ConnectionManager = connMgr
connServerParams := conn.NewConnectionServerParams{
Conf: params.Conf,
AuthProvider: params.AuthProvider,
CtrlProvider: params.CtrlProvider,
SyncProvider: params.SyncProvider,
}
@ -82,12 +85,36 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
return ctrlServer, nil
}
func (s *MeshCtrlServer) GetConfiguration() *conf.WgMeshConfiguration {
return s.Conf
}
func (s *MeshCtrlServer) GetClient() *wgctrl.Client {
return s.Client
}
func (s *MeshCtrlServer) GetQuerier() query.Querier {
return s.Querier
}
func (s *MeshCtrlServer) GetMeshManager() mesh.MeshManager {
return s.MeshManager
}
func (s *MeshCtrlServer) GetConnectionManager() conn.ConnectionManager {
return s.ConnectionManager
}
// Close closes the ctrl server tearing down any connections that exist
func (s *MeshCtrlServer) Close() error {
if err := s.ConnectionManager.Close(); err != nil {
logging.Log.WriteErrorf(err.Error())
}
if err := s.MeshManager.Close(); err != nil {
logging.Log.WriteErrorf(err.Error())
}
if err := s.ConnectionServer.Close(); err != nil {
logging.Log.WriteErrorf(err.Error())
}

View File

@ -17,6 +17,9 @@ type MeshNode struct {
WgHost string
Timestamp int64
Routes []string
Description string
Alias string
Services map[string]string
}
// Represents a WireGuard Mesh
@ -25,10 +28,19 @@ type Mesh struct {
Nodes map[string]MeshNode
}
type CtrlServer interface {
GetConfiguration() *conf.WgMeshConfiguration
GetClient() *wgctrl.Client
GetQuerier() query.Querier
GetMeshManager() mesh.MeshManager
Close() error
GetConnectionManager() conn.ConnectionManager
}
// Represents a ctrlserver to be used in WireGuard
type MeshCtrlServer struct {
Client *wgctrl.Client
MeshManager *mesh.MeshManager
MeshManager mesh.MeshManager
ConnectionManager conn.ConnectionManager
ConnectionServer *conn.ConnectionServer
Conf *conf.WgMeshConfiguration

51
pkg/ctrlserver/stub.go Normal file
View File

@ -0,0 +1,51 @@
package ctrlserver
import (
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query"
"golang.zx2c4.com/wireguard/wgctrl"
)
type CtrlServerStub struct {
manager mesh.MeshManager
querier query.Querier
connectionManager conn.ConnectionManager
}
func NewCtrlServerStub() *CtrlServerStub {
var manager mesh.MeshManager = mesh.NewMeshManagerStub()
return &CtrlServerStub{
manager: manager,
querier: query.NewJmesQuerier(manager),
connectionManager: &conn.ConnectionManagerStub{},
}
}
func (c *CtrlServerStub) GetConfiguration() *conf.WgMeshConfiguration {
return &conf.WgMeshConfiguration{
GrpcPort: "8080",
Endpoint: "abc.com",
}
}
func (c *CtrlServerStub) GetClient() *wgctrl.Client {
return &wgctrl.Client{}
}
func (c *CtrlServerStub) GetQuerier() query.Querier {
return c.querier
}
func (c *CtrlServerStub) GetMeshManager() mesh.MeshManager {
return c.manager
}
func (c *CtrlServerStub) Close() error {
return nil
}
func (c *CtrlServerStub) GetConnectionManager() conn.ConnectionManager {
return c.connectionManager
}

132
pkg/hosts/hosts.go Normal file
View File

@ -0,0 +1,132 @@
// hosts: utility for modifying the /etc/hosts file
package hosts
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"os"
"strings"
)
// HOSTS_FILE is the hosts file location
const HOSTS_FILE = "/etc/hosts"
const DOMAIN_HEADER = "#WG AUTO GENERATED HOSTS"
const DOMAIN_TRAILER = "#WG AUTO GENERATED HOSTS END"
type HostsEntry struct {
Alias string
Ip net.IP
}
// Generic interface to manipulate /etc/hosts file
type HostsManipulator interface {
// AddrAddr associates an aliasd with a given IP address
AddAddr(hosts ...HostsEntry)
// Remove deletes the entry from /etc/hosts
Remove(hosts ...HostsEntry)
// Writes the changes to /etc/hosts file
Write() error
}
type HostsManipulatorImpl struct {
hosts map[string]HostsEntry
}
// AddAddr implements HostsManipulator.
func (m *HostsManipulatorImpl) AddAddr(hosts ...HostsEntry) {
changed := false
for _, host := range hosts {
prev, ok := m.hosts[host.Ip.String()]
if !ok || prev.Alias != host.Alias {
changed = true
}
m.hosts[host.Ip.String()] = host
}
if changed {
m.Write()
}
}
// Remove implements HostsManipulator.
func (m *HostsManipulatorImpl) Remove(hosts ...HostsEntry) {
lenBefore := len(m.hosts)
for _, host := range hosts {
delete(m.hosts, host.Alias)
}
if lenBefore != len(m.hosts) {
m.Write()
}
}
func (m *HostsManipulatorImpl) removeHosts() string {
hostsFile, err := os.ReadFile(HOSTS_FILE)
if err != nil {
return ""
}
var contents strings.Builder
scanner := bufio.NewScanner(bytes.NewReader(hostsFile))
hostsSection := false
for scanner.Scan() {
line := scanner.Text()
if err == io.EOF {
break
} else if err != nil {
return ""
}
if !hostsSection && strings.Contains(line, DOMAIN_HEADER) {
hostsSection = true
}
if !hostsSection {
contents.WriteString(line + "\n")
}
if hostsSection && strings.Contains(line, DOMAIN_TRAILER) {
hostsSection = false
}
}
if scanner.Err() != nil && scanner.Err() != io.EOF {
return ""
}
return contents.String()
}
// Write implements HostsManipulator
func (m *HostsManipulatorImpl) Write() error {
contents := m.removeHosts()
var nextHosts strings.Builder
nextHosts.WriteString(contents)
nextHosts.WriteString(DOMAIN_HEADER + "\n")
for _, host := range m.hosts {
nextHosts.WriteString(fmt.Sprintf("%s\t%s\n", host.Ip.String(), host.Alias))
}
nextHosts.WriteString(DOMAIN_TRAILER + "\n")
return os.WriteFile(HOSTS_FILE, []byte(nextHosts.String()), 0644)
}
func NewHostsManipulator() HostsManipulator {
return &HostsManipulatorImpl{hosts: make(map[string]HostsEntry)}
}

View File

@ -34,6 +34,11 @@ type JoinMeshArgs struct {
Endpoint string
}
type PutServiceArgs struct {
Service string
Value string
}
type GetMeshReply struct {
Nodes []ctrlserver.MeshNode
}
@ -47,6 +52,11 @@ type QueryMesh struct {
Query string
}
type GetNodeArgs struct {
NodeId string
MeshId string
}
type MeshIpc interface {
CreateMesh(args *NewMeshArgs, reply *string) error
ListMeshes(name string, reply *ListMeshReply) error
@ -57,13 +67,17 @@ type MeshIpc interface {
GetDOT(meshId string, reply *string) error
Query(query QueryMesh, reply *string) error
PutDescription(description string, reply *string) error
PutAlias(alias string, reply *string) error
PutService(args PutServiceArgs, reply *string) error
GetNode(args GetNodeArgs, reply *string) error
DeleteService(service string, reply *string) error
}
const SockAddr = "/tmp/wgmesh_ipc.sock"
func RunIpcHandler(server MeshIpc) error {
if err := os.RemoveAll(SockAddr); err != nil {
return errors.New("Could not find to address")
return errors.New("could not find to address")
}
rpc.Register(server)

View File

@ -30,7 +30,7 @@ func MapKeys[K comparable, V any](m map[K]V) []K {
values := make([]K, len(m))
i := 0
for k, _ := range m {
for k := range m {
values[i] = k
i++
}
@ -58,7 +58,7 @@ type filterFunc[V any] func(V) bool
func Filter[V any](list []V, f filterFunc[V]) []V {
newList := make([]V, 0)
for _, elem := range newList {
for _, elem := range list {
if f(elem) {
newList = append(newList, elem)
}

144
pkg/lib/conv_test.go Normal file
View File

@ -0,0 +1,144 @@
package lib
import (
"slices"
"testing"
)
func stringToInt(input string) int {
return len(input)
}
func intDiv(input int) int {
return input / 2
}
func TestMapValuesMapsValues(t *testing.T) {
values := []int{1, 4, 11, 92}
var theMap map[string]int = map[string]int{
"mynameisjeff": values[0],
"tim": values[1],
"bob": values[2],
"derek": values[3],
}
mapValues := MapValues(theMap)
for _, elem := range mapValues {
if !slices.Contains(values, elem) {
t.Fatalf(`%d is not an expected value`, elem)
}
}
if len(mapValues) != len(theMap) {
t.Fatalf(`Expected length %d got %d`, len(theMap), len(mapValues))
}
}
func TestMapValuesWithExcludeExcludesValues(t *testing.T) {
values := []int{1, 9, 22}
var theMap map[string]int = map[string]int{
"mynameisbob": values[0],
"tim": values[1],
"bob": values[2],
}
exclude := map[string]struct{}{
"tim": {},
}
mapValues := MapValuesWithExclude(theMap, exclude)
if slices.Contains(mapValues, values[1]) {
t.Fatalf(`Failed to exclude expected value`)
}
if len(mapValues) != 2 {
t.Fatalf(`Incorrect expected length`)
}
for _, value := range theMap {
if !slices.Contains(values, value) {
t.Fatalf(`Element does not exist in the list of
expected values`)
}
}
}
func TestMapKeys(t *testing.T) {
keys := []string{"1", "2", "3"}
theMap := map[string]int{
keys[0]: 1,
keys[1]: 2,
keys[2]: 3,
}
mapKeys := MapKeys(theMap)
for _, elem := range mapKeys {
if !slices.Contains(keys, elem) {
t.Fatalf(`%s elem is not an expected key`, elem)
}
}
if len(mapKeys) != len(theMap) {
t.Fatalf(`Missing expected values`)
}
}
func TestMapValues(t *testing.T) {
array := []string{"mynameisjeff", "tim", "bob", "derek"}
intArray := Map(array, stringToInt)
for index, elem := range intArray {
if len(array[index]) != elem {
t.Fatalf(`Have %d want %d`, elem, len(array[index]))
}
}
}
func TestFilterFilterAll(t *testing.T) {
values := []int{1, 2, 3, 4, 5}
filterFunc := func(n int) bool {
return false
}
newValues := Filter(values, filterFunc)
if len(newValues) != 0 {
t.Fatalf(`Expected value was 0`)
}
}
func TestFilterFilterNone(t *testing.T) {
values := []int{1, 2, 3, 4, 5}
filterFunc := func(n int) bool {
return true
}
newValues := Filter(values, filterFunc)
if !slices.Equal(values, newValues) {
t.Fatalf(`Expected lists to be the same`)
}
}
func TestFilterFilterSome(t *testing.T) {
values := []int{1, 2, 3, 4, 5}
filterFunc := func(n int) bool {
return n < 3
}
expected := []int{1, 2}
actual := Filter(values, filterFunc)
if !slices.Equal(expected, actual) {
t.Fatalf(`Expected expected and actual to be the same`)
}
}

46
pkg/lib/random_test.go Normal file
View File

@ -0,0 +1,46 @@
package lib
import (
"slices"
"testing"
)
// Test that a random subset of length 0 produces a zero length
// list
func TestRandomSubsetOfLength0(t *testing.T) {
values := []int{1, 2, 3, 4, 5, 6, 7, 8}
randomValues := RandomSubsetOfLength(values, 0)
if len(randomValues) != 0 {
t.Fatalf(`Expected length to be 0`)
}
}
func TestRandomSubsetOfLength1(t *testing.T) {
values := []int{1, 2, 3, 4, 5}
randomValues := RandomSubsetOfLength(values, 1)
if len(randomValues) != 1 {
t.Fatalf(`Expected length to be 1`)
}
if !slices.Contains(values, randomValues[0]) {
t.Fatalf(`Expected length to be 1`)
}
}
func TestRandomSubsetEntireList(t *testing.T) {
values := []int{1, 2, 3, 4, 5}
randomValues := RandomSubsetOfLength(values, len(values))
if len(randomValues) != len(values) {
t.Fatalf(`Expected length to be %d was %d`, len(values), len(randomValues))
}
slices.Sort(randomValues)
if !slices.Equal(values, randomValues) {
t.Fatalf(`Expected slices to be equal`)
}
}

300
pkg/lib/rtnetlink.go Normal file
View File

@ -0,0 +1,300 @@
package lib
import (
"encoding/binary"
"fmt"
"net"
"github.com/jsimonetti/rtnetlink"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"golang.org/x/sys/unix"
)
type RtNetlinkConfig struct {
conn *rtnetlink.Conn
}
func NewRtNetlinkConfig() (*RtNetlinkConfig, error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return nil, err
}
return &RtNetlinkConfig{conn: conn}, nil
}
const WIREGUARD_MTU = 1420
// Create a netlink interface if it does not exist. ifName is the name of the netlink interface
func (c *RtNetlinkConfig) CreateLink(ifName string) error {
_, err := net.InterfaceByName(ifName)
if err == nil {
return fmt.Errorf("interface %s already exists", ifName)
}
err = c.conn.Link.New(&rtnetlink.LinkMessage{
Family: unix.AF_UNSPEC,
Flags: unix.IFF_UP,
Attributes: &rtnetlink.LinkAttributes{
Name: ifName,
Info: &rtnetlink.LinkInfo{Kind: "wireguard"},
MTU: uint32(WIREGUARD_MTU),
},
})
if err != nil {
return fmt.Errorf("failed to create wireguard interface: %w", err)
}
return nil
}
// Delete link delete the specified interface
func (c *RtNetlinkConfig) DeleteLink(ifName string) error {
iface, err := net.InterfaceByName(ifName)
if err != nil {
return fmt.Errorf("failed to get interface %s %w", ifName, err)
}
err = c.conn.Link.Delete(uint32(iface.Index))
if err != nil {
return fmt.Errorf("failed to delete wg interface %w", err)
}
return nil
}
// AddAddress adds an address to the given interface.
func (c *RtNetlinkConfig) AddAddress(ifName string, address string) error {
iface, err := net.InterfaceByName(ifName)
if err != nil {
return fmt.Errorf("failed to get interface %s error: %w", ifName, err)
}
addr, cidr, err := net.ParseCIDR(address)
if err != nil {
return fmt.Errorf("failed to parse CIDR %s error: %w", addr, err)
}
family := unix.AF_INET6
ipv4 := cidr.IP.To4()
if ipv4 != nil {
family = unix.AF_INET
}
// Calculate the prefix length
ones, _ := cidr.Mask.Size()
// Calculate the broadcast IP
// Only used when family is AF_INET
var brd net.IP
if ipv4 != nil {
brd = make(net.IP, len(ipv4))
binary.BigEndian.PutUint32(brd, binary.BigEndian.Uint32(ipv4)|^binary.BigEndian.Uint32(net.IP(cidr.Mask).To4()))
}
err = c.conn.Address.New(&rtnetlink.AddressMessage{
Family: uint8(family),
PrefixLength: uint8(ones),
Scope: unix.RT_SCOPE_UNIVERSE,
Index: uint32(iface.Index),
Attributes: &rtnetlink.AddressAttributes{
Address: addr,
Local: addr,
Broadcast: brd,
},
})
if err != nil {
err = fmt.Errorf("failed to add address to link %w", err)
}
return err
}
// AddRoute: adds a route to the routing table.
// ifName is the intrface to add the route to
// gateway is the IP of the gateway device to hop to
// dst is the network prefix of the advertised destination
func (c *RtNetlinkConfig) AddRoute(ifName string, route Route) error {
iface, err := net.InterfaceByName(ifName)
if err != nil {
return fmt.Errorf("failed accessing interface %s error %w", ifName, err)
}
gw := route.Gateway
dst := route.Destination
var family uint8 = unix.AF_INET6
if dst.IP.To4() != nil {
family = unix.AF_INET
}
attr := rtnetlink.RouteAttributes{
Dst: dst.IP,
OutIface: uint32(iface.Index),
Gateway: gw,
}
ones, _ := dst.Mask.Size()
err = c.conn.Route.Replace(&rtnetlink.RouteMessage{
Family: family,
Table: unix.RT_TABLE_MAIN,
Protocol: unix.RTPROT_BOOT,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
DstLength: uint8(ones),
Attributes: attr,
})
if err != nil {
return fmt.Errorf("failed to add route %w", err)
}
return nil
}
// DeleteRoute deletes routes with the gateway and destination
func (c *RtNetlinkConfig) DeleteRoute(ifName string, route Route) error {
iface, err := net.InterfaceByName(ifName)
if err != nil {
return fmt.Errorf("failed accessing interface %s error %w", ifName, err)
}
gw := route.Gateway
dst := route.Destination
var family uint8 = unix.AF_INET6
if dst.IP.To4() != nil {
family = unix.AF_INET
}
attr := rtnetlink.RouteAttributes{
Dst: dst.IP,
OutIface: uint32(iface.Index),
Gateway: gw,
}
ones, _ := dst.Mask.Size()
err = c.conn.Route.Delete(&rtnetlink.RouteMessage{
Family: family,
Table: unix.RT_TABLE_MAIN,
Protocol: unix.RTPROT_BOOT,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
DstLength: uint8(ones),
Attributes: attr,
})
if err != nil {
return fmt.Errorf("failed to delete route %w", err)
}
return nil
}
type Route struct {
Gateway net.IP
Destination net.IPNet
}
func (r1 Route) equal(r2 Route) bool {
return r1.Gateway.String() == r2.Gateway.String() &&
r1.Destination.String() == r2.Destination.String()
}
// DeleteRoutes deletes all routes not in exclude
func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...Route) error {
routes := make([]rtnetlink.RouteMessage, 0)
if len(exclude) != 0 {
lRoutes, err := c.listRoutes(ifName, family, exclude[0].Gateway)
if err != nil {
return err
}
routes = lRoutes
}
ifRoutes := make([]Route, 0)
for _, rtRoute := range routes {
logging.Log.WriteInfof("Routes: %s", rtRoute.Attributes.Dst.String())
maskSize := 128
if family == unix.AF_INET {
maskSize = 32
}
cidr := net.CIDRMask(int(rtRoute.DstLength), maskSize)
route := Route{
Gateway: rtRoute.Attributes.Gateway,
Destination: net.IPNet{IP: rtRoute.Attributes.Dst, Mask: cidr},
}
ifRoutes = append(ifRoutes, route)
}
shouldExclude := func(r Route) bool {
for _, route := range exclude {
if route.equal(r) {
return false
}
}
return true
}
toDelete := Filter(ifRoutes, shouldExclude)
for _, route := range toDelete {
logging.Log.WriteInfof("Deleting route %s", route.Destination.String())
err := c.DeleteRoute(ifName, route)
if err != nil {
return err
}
}
return nil
}
// listRoutes lists all routes on the interface
func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8, gateway net.IP) ([]rtnetlink.RouteMessage, error) {
iface, err := net.InterfaceByName(ifName)
if err != nil {
return nil, fmt.Errorf("failed accessing interface %s error %w", ifName, err)
}
routes, err := c.conn.Route.List()
if err != nil {
return nil, fmt.Errorf("failed to get route %w", err)
}
filterFunc := func(r rtnetlink.RouteMessage) bool {
return r.Attributes.Gateway.Equal(gateway) && r.Attributes.OutIface == uint32(iface.Index)
}
routes = Filter(routes, filterFunc)
return routes, nil
}
func (c *RtNetlinkConfig) Close() error {
return c.conn.Close()
}

42
pkg/lib/timer.go Normal file
View File

@ -0,0 +1,42 @@
package lib
import "time"
type TimerFunc = func() error
type Timer struct {
f TimerFunc
quit chan struct{}
updateRate int
}
func (t *Timer) Run() error {
ticker := time.NewTicker(time.Duration(t.updateRate) * time.Second)
t.quit = make(chan struct{})
for {
select {
case <-ticker.C:
err := t.f()
if err != nil {
return err
}
case <-t.quit:
break
}
}
}
func (t *Timer) Stop() error {
close(t.quit)
return nil
}
func NewTimer(f TimerFunc, updateRate int) *Timer {
return &Timer{
f: f,
updateRate: updateRate,
}
}

View File

@ -2,6 +2,7 @@
package logging
import (
"io"
"os"
"github.com/sirupsen/logrus"
@ -15,6 +16,7 @@ type Logger interface {
WriteInfof(msg string, args ...interface{})
WriteErrorf(msg string, args ...interface{})
WriteWarnf(msg string, args ...interface{})
Writer() io.Writer
}
type LogrusLogger struct {
@ -33,6 +35,10 @@ func (l *LogrusLogger) WriteWarnf(msg string, args ...interface{}) {
l.logger.Warnf(msg, args...)
}
func (l *LogrusLogger) Writer() io.Writer {
return l.logger.Writer()
}
func NewLogrusLogger() *LogrusLogger {
logger := logrus.New()
logger.SetFormatter(&logrus.TextFormatter{FullTimestamp: true})

46
pkg/mesh/alias.go Normal file
View File

@ -0,0 +1,46 @@
package mesh
import (
"fmt"
"github.com/tim-beatham/wgmesh/pkg/hosts"
)
type MeshAliasManager interface {
AddAliases(nodes []MeshNode)
RemoveAliases(node []MeshNode)
}
type AliasManager struct {
hosts hosts.HostsManipulator
}
// AddAliases: on node update or change add aliases to the hosts file
func (a *AliasManager) AddAliases(nodes []MeshNode) {
for _, node := range nodes {
if node.GetAlias() != "" {
a.hosts.AddAddr(hosts.HostsEntry{
Alias: fmt.Sprintf("%s.smeg", node.GetAlias()),
Ip: node.GetWgHost().IP,
})
}
}
}
// RemoveAliases: on node remove remove aliases from the hosts file
func (a *AliasManager) RemoveAliases(nodes []MeshNode) {
for _, node := range nodes {
if node.GetAlias() != "" {
a.hosts.Remove(hosts.HostsEntry{
Alias: fmt.Sprintf("%s.smeg", node.GetAlias()),
Ip: node.GetWgHost().IP,
})
}
}
}
func NewAliasManager() MeshAliasManager {
return &AliasManager{
hosts: hosts.NewHostsManipulator(),
}
}

View File

@ -1,7 +1,6 @@
package mesh
import (
"errors"
"fmt"
"net"
@ -12,11 +11,12 @@ import (
type MeshConfigApplyer interface {
ApplyConfig() error
RemovePeers(meshId string) error
SetMeshManager(manager MeshManager)
}
// WgMeshConfigApplyer applies WireGuard configuration
type WgMeshConfigApplyer struct {
meshManager *MeshManager
meshManager MeshManager
}
func convertMeshNode(node MeshNode) (*wgtypes.PeerConfig, error) {
@ -82,11 +82,11 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
return err
}
return m.meshManager.Client.ConfigureDevice(dev.Name, cfg)
return m.meshManager.GetClient().ConfigureDevice(dev.Name, cfg)
}
func (m *WgMeshConfigApplyer) ApplyConfig() error {
for _, mesh := range m.meshManager.Meshes {
for _, mesh := range m.meshManager.GetMeshes() {
err := m.updateWgConf(mesh)
if err != nil {
@ -101,7 +101,7 @@ func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error {
mesh := m.meshManager.GetMesh(meshId)
if mesh == nil {
return errors.New(fmt.Sprintf("mesh %s does not exist", meshId))
return fmt.Errorf("mesh %s does not exist", meshId)
}
dev, err := mesh.GetDevice()
@ -110,7 +110,7 @@ func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error {
return err
}
m.meshManager.Client.ConfigureDevice(dev.Name, wgtypes.Config{
m.meshManager.GetClient().ConfigureDevice(dev.Name, wgtypes.Config{
ReplacePeers: true,
Peers: make([]wgtypes.PeerConfig, 1),
})
@ -118,6 +118,10 @@ func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error {
return nil
}
func NewWgMeshConfigApplyer(manager *MeshManager) MeshConfigApplyer {
return &WgMeshConfigApplyer{meshManager: manager}
func (m *WgMeshConfigApplyer) SetMeshManager(manager MeshManager) {
m.meshManager = manager
}
func NewWgMeshConfigApplyer() MeshConfigApplyer {
return &WgMeshConfigApplyer{}
}

View File

@ -15,7 +15,7 @@ type MeshGraphConverter interface {
}
type MeshDOTConverter struct {
manager *MeshManager
manager MeshManager
}
func (c *MeshDOTConverter) Generate(meshId string) (string, error) {
@ -34,7 +34,7 @@ func (c *MeshDOTConverter) Generate(meshId string) (string, error) {
}
for _, node := range snapshot.GetNodes() {
c.graphNode(g, node)
c.graphNode(g, node, meshId)
}
nodes := lib.MapValues(snapshot.GetNodes())
@ -55,11 +55,13 @@ func (c *MeshDOTConverter) Generate(meshId string) (string, error) {
}
// graphNode: graphs a node within the mesh
func (c *MeshDOTConverter) graphNode(g *graph.Graph, node MeshNode) {
func (c *MeshDOTConverter) graphNode(g *graph.Graph, node MeshNode, meshId string) {
nodeId := fmt.Sprintf("\"%s\"", node.GetIdentifier())
g.PutNode(nodeId, graph.CIRCLE)
if node.GetHostEndpoint() == c.manager.HostParameters.HostEndpoint {
self, _ := c.manager.GetSelf(meshId)
if node.GetHostEndpoint() == self.GetHostEndpoint() {
return
}
@ -70,6 +72,6 @@ func (c *MeshDOTConverter) graphNode(g *graph.Graph, node MeshNode) {
}
}
func NewMeshDotConverter(m *MeshManager) MeshGraphConverter {
func NewMeshDotConverter(m MeshManager) MeshGraphConverter {
return &MeshDOTConverter{manager: m}
}

View File

@ -13,7 +13,31 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type MeshManager struct {
type MeshManager interface {
CreateMesh(devName string, port int) (string, error)
AddMesh(params *AddMeshParams) error
HasChanges(meshid string) bool
GetMesh(meshId string) MeshProvider
EnableInterface(meshId string) error
GetPublicKey(meshId string) (*wgtypes.Key, error)
AddSelf(params *AddSelfParams) error
LeaveMesh(meshId string) error
GetSelf(meshId string) (MeshNode, error)
ApplyConfig() error
SetDescription(description string) error
SetAlias(alias string) error
SetService(service string, value string) error
RemoveService(service string) error
UpdateTimeStamp() error
GetClient() *wgctrl.Client
GetMeshes() map[string]MeshProvider
Prune() error
Close() error
GetMonitor() MeshMonitor
GetNode(string, string) MeshNode
}
type MeshManagerImpl struct {
Meshes map[string]MeshProvider
RouteManager RouteManager
Client *wgctrl.Client
@ -27,10 +51,71 @@ type MeshManager struct {
idGenerator lib.IdGenerator
ipAllocator ip.IPAllocator
interfaceManipulator wg.WgInterfaceManipulator
Monitor MeshMonitor
}
// RemoveService implements MeshManager.
func (m *MeshManagerImpl) RemoveService(service string) error {
for _, mesh := range m.Meshes {
err := mesh.RemoveService(m.HostParameters.HostEndpoint, service)
if err != nil {
return err
}
}
return nil
}
// SetService implements MeshManager.
func (m *MeshManagerImpl) SetService(service string, value string) error {
for _, mesh := range m.Meshes {
err := mesh.AddService(m.HostParameters.HostEndpoint, service, value)
if err != nil {
return err
}
}
return nil
}
func (m *MeshManagerImpl) GetNode(meshid, nodeId string) MeshNode {
mesh, ok := m.Meshes[meshid]
if !ok {
return nil
}
node, err := mesh.GetNode(nodeId)
if err != nil {
return nil
}
return node
}
// GetMonitor implements MeshManager.
func (m *MeshManagerImpl) GetMonitor() MeshMonitor {
return m.Monitor
}
// Prune implements MeshManager.
func (m *MeshManagerImpl) Prune() error {
for _, mesh := range m.Meshes {
err := mesh.Prune(m.conf.PruneTime)
if err != nil {
return err
}
}
return nil
}
// CreateMesh: Creates a new mesh, stores it and returns the mesh id
func (m *MeshManager) CreateMesh(devName string, port int) (string, error) {
func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) {
meshId, err := m.idGenerator.GetId()
if err != nil {
@ -46,25 +131,21 @@ func (m *MeshManager) CreateMesh(devName string, port int) (string, error) {
})
if err != nil {
return "", err
return "", fmt.Errorf("error creating mesh: %w", err)
}
err = m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: devName,
Port: port,
})
if !m.conf.StubWg {
err = m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: devName,
Port: port,
})
if err != nil {
return "", nil
if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err)
}
}
m.Meshes[meshId] = nodeManager
err = m.configApplyer.RemovePeers(meshId)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
return meshId, nil
}
@ -76,7 +157,7 @@ type AddMeshParams struct {
}
// AddMesh: Add the mesh to the list of meshes
func (m *MeshManager) AddMesh(params *AddMeshParams) error {
func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: params.DevName,
Port: params.WgPort,
@ -97,54 +178,51 @@ func (m *MeshManager) AddMesh(params *AddMeshParams) error {
m.Meshes[params.MeshId] = meshProvider
return m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: params.DevName,
Port: params.WgPort,
})
if !m.conf.StubWg {
return m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: params.DevName,
Port: params.WgPort,
})
}
return nil
}
// HasChanges returns true if the mesh has changes
func (m *MeshManager) HasChanges(meshId string) bool {
func (m *MeshManagerImpl) HasChanges(meshId string) bool {
return m.Meshes[meshId].HasChanges()
}
// GetMesh returns the mesh with the given meshid
func (m *MeshManager) GetMesh(meshId string) MeshProvider {
theMesh, _ := m.Meshes[meshId]
func (m *MeshManagerImpl) GetMesh(meshId string) MeshProvider {
theMesh := m.Meshes[meshId]
return theMesh
}
// EnableInterface: Enables the given WireGuard interface.
func (s *MeshManager) EnableInterface(meshId string) error {
func (s *MeshManagerImpl) EnableInterface(meshId string) error {
err := s.configApplyer.ApplyConfig()
if err != nil {
return err
}
meshNode, err := s.GetSelf(meshId)
err = s.RouteManager.InstallRoutes()
if err != nil {
return err
}
mesh := s.GetMesh(meshId)
if err != nil {
return err
}
dev, err := mesh.GetDevice()
if err != nil {
return err
}
return s.interfaceManipulator.EnableInterface(dev.Name, meshNode.GetWgHost().String())
return nil
}
// GetPublicKey: Gets the public key of the WireGuard mesh
func (s *MeshManager) GetPublicKey(meshId string) (*wgtypes.Key, error) {
func (s *MeshManagerImpl) GetPublicKey(meshId string) (*wgtypes.Key, error) {
if s.conf.StubWg {
zeroedKey := make([]byte, wgtypes.KeyLen)
return (*wgtypes.Key)(zeroedKey), nil
}
mesh, ok := s.Meshes[meshId]
if !ok {
@ -166,12 +244,17 @@ type AddSelfParams struct {
// WgPort is the WireGuard port to advertise
WgPort int
// Endpoint is the alias of the machine to send routable packets
// to
Endpoint string
}
// AddSelf adds this host to the mesh
func (s *MeshManager) AddSelf(params *AddSelfParams) error {
func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
mesh := s.GetMesh(params.MeshId)
if mesh == nil {
return fmt.Errorf("addself: mesh %s does not exist", params.MeshId)
}
pubKey, err := s.GetPublicKey(params.MeshId)
if err != nil {
@ -191,78 +274,146 @@ func (s *MeshManager) AddSelf(params *AddSelfParams) error {
Endpoint: params.Endpoint,
})
if !s.conf.StubWg {
device, err := mesh.GetDevice()
if err != nil {
return fmt.Errorf("failed to get device %w", err)
}
err = s.interfaceManipulator.AddAddress(device.Name, fmt.Sprintf("%s/64", nodeIP))
if err != nil {
return fmt.Errorf("addSelf: failed to add address to dev %w", err)
}
}
s.Meshes[params.MeshId].AddNode(node)
return s.RouteManager.UpdateRoutes()
}
// LeaveMesh leaves the mesh network
func (s *MeshManager) LeaveMesh(meshId string) error {
_, exists := s.Meshes[meshId]
func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
mesh, exists := s.Meshes[meshId]
if !exists {
return errors.New(fmt.Sprintf("mesh %s does not exist", meshId))
return fmt.Errorf("mesh %s does not exist", meshId)
}
err := s.RouteManager.RemoveRoutes(meshId)
if err != nil {
return err
}
if !s.conf.StubWg {
device, e := mesh.GetDevice()
if e != nil {
return err
}
err = s.interfaceManipulator.RemoveInterface(device.Name)
}
// For now just delete the mesh with the ID.
delete(s.Meshes, meshId)
return nil
return err
}
func (s *MeshManager) GetSelf(meshId string) (MeshNode, error) {
func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
meshInstance, ok := s.Meshes[meshId]
if !ok {
return nil, errors.New(fmt.Sprintf("mesh %s does not exist", meshId))
return nil, fmt.Errorf("mesh %s does not exist", meshId)
}
snapshot, err := meshInstance.GetMesh()
node, err := meshInstance.GetNode(s.HostParameters.HostEndpoint)
if err != nil {
return nil, err
}
node, ok := snapshot.GetNodes()[s.HostParameters.HostEndpoint]
if !ok {
return nil, errors.New("the node doesn't exist in the mesh")
}
return node, nil
}
func (s *MeshManager) ApplyConfig() error {
return s.configApplyer.ApplyConfig()
func (s *MeshManagerImpl) ApplyConfig() error {
err := s.configApplyer.ApplyConfig()
if err != nil {
return err
}
return nil
}
func (s *MeshManager) SetDescription(description string) error {
func (s *MeshManagerImpl) SetDescription(description string) error {
for _, mesh := range s.Meshes {
err := mesh.SetDescription(s.HostParameters.HostEndpoint, description)
if mesh.NodeExists(s.HostParameters.HostEndpoint) {
err := mesh.SetDescription(s.HostParameters.HostEndpoint, description)
if err != nil {
return err
if err != nil {
return err
}
}
}
return nil
}
// UpdateTimeStamp updates the timestamp of this node in all meshes
func (s *MeshManager) UpdateTimeStamp() error {
// SetAlias implements MeshManager.
func (s *MeshManagerImpl) SetAlias(alias string) error {
for _, mesh := range s.Meshes {
snapshot, err := mesh.GetMesh()
if mesh.NodeExists(s.HostParameters.HostEndpoint) {
err := mesh.SetAlias(s.HostParameters.HostEndpoint, alias)
if err != nil {
return err
}
}
}
return nil
}
// UpdateTimeStamp updates the timestamp of this node in all meshes
func (s *MeshManagerImpl) UpdateTimeStamp() error {
for _, mesh := range s.Meshes {
if mesh.NodeExists(s.HostParameters.HostEndpoint) {
err := mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint)
if err != nil {
return err
}
}
}
return nil
}
func (s *MeshManagerImpl) GetClient() *wgctrl.Client {
return s.Client
}
func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider {
return s.Meshes
}
// Close the mesh manager
func (s *MeshManagerImpl) Close() error {
if s.conf.StubWg {
return nil
}
for _, mesh := range s.Meshes {
dev, err := mesh.GetDevice()
if err != nil {
return err
}
_, exists := snapshot.GetNodes()[s.HostParameters.HostEndpoint]
err = s.interfaceManipulator.RemoveInterface(dev.Name)
if exists {
err = mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint)
if err != nil {
return err
}
if err != nil {
return err
}
}
@ -278,10 +429,12 @@ type NewMeshManagerParams struct {
IdGenerator lib.IdGenerator
IPAllocator ip.IPAllocator
InterfaceManipulator wg.WgInterfaceManipulator
ConfigApplyer MeshConfigApplyer
RouteManager RouteManager
}
// Creates a new instance of a mesh manager with the given parameters
func NewMeshManager(params *NewMeshManagerParams) *MeshManager {
func NewMeshManager(params *NewMeshManagerParams) MeshManager {
hostParams := HostParameters{}
switch params.Conf.Endpoint {
@ -293,7 +446,7 @@ func NewMeshManager(params *NewMeshManagerParams) *MeshManager {
logging.Log.WriteInfof("Endpoint %s", hostParams.HostEndpoint)
m := &MeshManager{
m := &MeshManagerImpl{
Meshes: make(map[string]MeshProvider),
HostParameters: &hostParams,
meshProviderFactory: params.MeshProvider,
@ -301,10 +454,22 @@ func NewMeshManager(params *NewMeshManagerParams) *MeshManager {
Client: params.Client,
conf: &params.Conf,
}
m.configApplyer = NewWgMeshConfigApplyer(m)
m.RouteManager = NewRouteManager(m)
m.configApplyer = params.ConfigApplyer
m.RouteManager = params.RouteManager
if m.RouteManager == nil {
m.RouteManager = NewRouteManager(m)
}
m.idGenerator = params.IdGenerator
m.ipAllocator = params.IPAllocator
m.interfaceManipulator = params.InterfaceManipulator
m.Monitor = NewMeshMonitor(m)
aliasManager := NewAliasManager()
m.Monitor.AddUpdateCallback(aliasManager.AddAliases)
m.Monitor.AddRemoveCallback(aliasManager.RemoveAliases)
return m
}

247
pkg/mesh/manager_test.go Normal file
View File

@ -0,0 +1,247 @@
package mesh
import (
"testing"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/wg"
)
func getMeshConfiguration() *conf.WgMeshConfiguration {
return &conf.WgMeshConfiguration{
GrpcPort: "8080",
Endpoint: "abc.com",
ClusterSize: 64,
SyncRate: 4,
BranchRate: 3,
InterClusterChance: 0.15,
InfectionCount: 2,
KeepAliveTime: 60,
}
}
func getMeshManager() MeshManager {
manager := NewMeshManager(&NewMeshManagerParams{
Conf: *getMeshConfiguration(),
Client: nil,
MeshProvider: &StubMeshProviderFactory{},
NodeFactory: &StubNodeFactory{Config: getMeshConfiguration()},
IdGenerator: &lib.UUIDGenerator{},
IPAllocator: &ip.ULABuilder{},
InterfaceManipulator: &wg.WgInterfaceManipulatorStub{},
ConfigApplyer: &MeshConfigApplyerStub{},
RouteManager: &RouteManagerStub{},
})
return manager
}
func TestCreateMeshCreatesANewMeshProvider(t *testing.T) {
manager := getMeshManager()
meshId, err := manager.CreateMesh("wg0", 5000)
if err != nil {
t.Error(err)
}
if len(meshId) == 0 {
t.Fatal(`meshId should not be empty`)
}
_, exists := manager.GetMeshes()[meshId]
if !exists {
t.Fatal(`mesh was not created when it should be`)
}
}
func TestAddMeshAddsAMesh(t *testing.T) {
manager := getMeshManager()
meshId := "meshid123"
manager.AddMesh(&AddMeshParams{
MeshId: meshId,
DevName: "wg0",
WgPort: 6000,
MeshBytes: make([]byte, 0),
})
mesh := manager.GetMesh(meshId)
if mesh == nil || mesh.GetMeshId() != meshId {
t.Fatalf(`mesh has not been added to the list of meshes`)
}
}
func TestAddMeshMeshAlreadyExistsReplacesIt(t *testing.T) {
manager := getMeshManager()
meshId := "meshid123"
for i := 0; i < 2; i++ {
err := manager.AddMesh(&AddMeshParams{
MeshId: meshId,
DevName: "wg0",
WgPort: 6000,
MeshBytes: make([]byte, 0),
})
if err != nil {
t.Error(err)
}
}
mesh := manager.GetMesh(meshId)
if mesh == nil || mesh.GetMeshId() != meshId {
t.Fatalf(`mesh has not been added to the list of meshes`)
}
}
func TestAddSelfAddsSelfToTheMesh(t *testing.T) {
manager := getMeshManager()
meshId := "meshid123"
err := manager.AddMesh(&AddMeshParams{
MeshId: meshId,
DevName: "wg0",
WgPort: 6000,
MeshBytes: make([]byte, 0),
})
if err != nil {
t.Error(err)
}
err = manager.AddSelf(&AddSelfParams{
MeshId: meshId,
WgPort: 5000,
Endpoint: "abc.com",
})
if err != nil {
t.Error(err)
}
mesh, err := manager.GetMesh(meshId).GetMesh()
if err != nil {
t.Error(err)
}
_, ok := mesh.GetNodes()["abc.com"]
if !ok {
t.Fatalf(`node has not been added`)
}
}
func TestAddSelfToMeshAlreadyInMesh(t *testing.T) {
TestAddSelfAddsSelfToTheMesh(t)
TestAddSelfAddsSelfToTheMesh(t)
}
func TestAddSelfToMeshMeshDoesNotExist(t *testing.T) {
manager := getMeshManager()
meshId := "meshid123"
err := manager.AddSelf(&AddSelfParams{
MeshId: meshId,
WgPort: 5000,
Endpoint: "abc.com",
})
if err == nil {
t.Fatalf(`Expected error to be thrown`)
}
}
func TestLeaveMeshMeshDoesNotExist(t *testing.T) {
manager := getMeshManager()
meshId := "meshid123"
err := manager.LeaveMesh(meshId)
if err == nil {
t.Fatalf(`Expected error to be thrown`)
}
}
func TestLeaveMeshDeletesMesh(t *testing.T) {
manager := getMeshManager()
meshId := "meshid123"
err := manager.AddMesh(&AddMeshParams{
MeshId: meshId,
DevName: "wg0",
WgPort: 6000,
MeshBytes: make([]byte, 0),
})
if err != nil {
t.Error(err)
}
err = manager.LeaveMesh(meshId)
if err != nil {
t.Fatalf("%s", err.Error())
}
_, exists := manager.GetMeshes()[meshId]
if exists {
t.Fatalf(`expected mesh to have been deleted`)
}
}
func TestSetDescription(t *testing.T) {
manager := getMeshManager()
description := "wooooo"
meshId1, _ := manager.CreateMesh("wg0", 5000)
meshId2, _ := manager.CreateMesh("wg0", 5001)
manager.AddSelf(&AddSelfParams{
MeshId: meshId1,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId2,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
err := manager.SetDescription(description)
if err != nil {
t.Fatalf(`failed to set the descriptions`)
}
}
func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) {
manager := getMeshManager()
meshId1, _ := manager.CreateMesh("wg0", 5000)
meshId2, _ := manager.CreateMesh("wg0", 5001)
manager.AddSelf(&AddSelfParams{
MeshId: meshId1,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId2,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
err := manager.UpdateTimeStamp()
if err != nil {
t.Fatalf(`failed to update the timestamp`)
}
}

81
pkg/mesh/monitor.go Normal file
View File

@ -0,0 +1,81 @@
package mesh
type OnChange = func([]MeshNode)
type MeshMonitor interface {
AddUpdateCallback(cb OnChange)
AddRemoveCallback(cb OnChange)
Trigger() error
}
type MeshMonitorImpl struct {
updateCbs []OnChange
removeCbs []OnChange
nodes map[string]MeshNode
manager MeshManager
}
// Trigger causes the mesh monitor to trigger all of
// the callbacks.
func (m *MeshMonitorImpl) Trigger() error {
changedNodes := make([]MeshNode, 0)
removedNodes := make([]MeshNode, 0)
nodes := make(map[string]MeshNode)
for _, mesh := range m.manager.GetMeshes() {
snapshot, err := mesh.GetMesh()
if err != nil {
return err
}
for _, node := range snapshot.GetNodes() {
previous, exists := m.nodes[node.GetWgHost().String()]
if !exists || !NodeEquals(previous, node) {
changedNodes = append(changedNodes, node)
}
nodes[node.GetWgHost().String()] = node
}
}
for _, previous := range m.nodes {
_, ok := nodes[previous.GetWgHost().String()]
if !ok {
removedNodes = append(removedNodes, previous)
}
}
if len(removedNodes) > 0 {
for _, cb := range m.removeCbs {
cb(removedNodes)
}
}
if len(changedNodes) > 0 {
for _, cb := range m.updateCbs {
cb(changedNodes)
}
}
return nil
}
func (m *MeshMonitorImpl) AddUpdateCallback(cb OnChange) {
m.updateCbs = append(m.updateCbs, cb)
}
func (m *MeshMonitorImpl) AddRemoveCallback(cb OnChange) {
m.removeCbs = append(m.removeCbs, cb)
}
func NewMeshMonitor(manager MeshManager) MeshMonitor {
return &MeshMonitorImpl{
updateCbs: make([]OnChange, 0),
nodes: make(map[string]MeshNode),
manager: manager,
}
}

16
pkg/mesh/pruner.go Normal file
View File

@ -0,0 +1,16 @@
package mesh
import (
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
func pruneFunction(m MeshManager) lib.TimerFunc {
return func() error {
return m.Prune()
}
}
func NewPruner(m MeshManager, conf conf.WgMeshConfiguration) *lib.Timer {
return lib.NewTimer(pruneFunction(m), conf.PruneTime/2)
}

View File

@ -1,22 +1,29 @@
package mesh
import (
"fmt"
"net"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/route"
"golang.org/x/sys/unix"
)
type RouteManager interface {
UpdateRoutes() error
InstallRoutes() error
RemoveRoutes(meshId string) error
}
type RouteManagerImpl struct {
meshManager *MeshManager
meshManager MeshManager
routeInstaller route.RouteInstaller
}
func (r *RouteManagerImpl) UpdateRoutes() error {
meshes := r.meshManager.Meshes
meshes := r.meshManager.GetMeshes()
ulaBuilder := new(ip.ULABuilder)
for _, mesh1 := range meshes {
@ -32,7 +39,13 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
return err
}
err = mesh1.AddRoutes(r.meshManager.HostParameters.HostEndpoint, ipNet.String())
self, err := r.meshManager.GetSelf(mesh1.GetMeshId())
if err != nil {
return err
}
err = mesh1.AddRoutes(self.GetHostEndpoint(), ipNet.String())
if err != nil {
return err
@ -43,6 +56,129 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
return nil
}
func NewRouteManager(m *MeshManager) RouteManager {
// removeRoutes: removes all meshes we are no longer a part of
func (r *RouteManagerImpl) RemoveRoutes(meshId string) error {
ulaBuilder := new(ip.ULABuilder)
meshes := r.meshManager.GetMeshes()
ipNet, err := ulaBuilder.GetIPNet(meshId)
if err != nil {
return err
}
for _, mesh1 := range meshes {
self, err := r.meshManager.GetSelf(meshId)
if err != nil {
return err
}
mesh1.RemoveRoutes(self.GetHostEndpoint(), ipNet.String())
}
return nil
}
// AddRoute adds a route to the given interface
func (m *RouteManagerImpl) addRoute(ifName string, meshPrefix string, routes ...lib.Route) error {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return fmt.Errorf("failed to create config: %w", err)
}
defer rtnl.Close()
// Delete any routes that may be vacant
err = rtnl.DeleteRoutes(ifName, unix.AF_INET6, routes...)
if err != nil {
return err
}
for _, route := range routes {
if route.Destination.String() == meshPrefix {
continue
}
err = rtnl.AddRoute(ifName, route)
if err != nil {
return err
}
}
return nil
}
func (m *RouteManagerImpl) installRoute(ifName string, meshid string, node MeshNode) error {
routeMapFunc := func(route string) lib.Route {
_, cidr, _ := net.ParseCIDR(route)
r := lib.Route{
Destination: *cidr,
Gateway: node.GetWgHost().IP,
}
return r
}
ipBuilder := &ip.ULABuilder{}
ipNet, err := ipBuilder.GetIPNet(meshid)
if err != nil {
return err
}
routes := lib.Map(append(node.GetRoutes(), ipNet.String()), routeMapFunc)
return m.addRoute(ifName, ipNet.String(), routes...)
}
func (m *RouteManagerImpl) installRoutes(meshProvider MeshProvider) error {
mesh, err := meshProvider.GetMesh()
if err != nil {
return err
}
dev, err := meshProvider.GetDevice()
if err != nil {
return err
}
self, err := m.meshManager.GetSelf(meshProvider.GetMeshId())
if err != nil {
return err
}
for _, node := range mesh.GetNodes() {
if self.GetHostEndpoint() == node.GetHostEndpoint() {
continue
}
err = m.installRoute(dev.Name, meshProvider.GetMeshId(), node)
if err != nil {
return err
}
}
return nil
}
// InstallRoutes installs all routes to the RIB
func (r *RouteManagerImpl) InstallRoutes() error {
for _, mesh := range r.meshManager.GetMeshes() {
err := r.installRoutes(mesh)
if err != nil {
return err
}
}
return nil
}
func NewRouteManager(m MeshManager) RouteManager {
return &RouteManagerImpl{meshManager: m, routeInstaller: route.NewRouteInstaller()}
}

16
pkg/mesh/route_stub.go Normal file
View File

@ -0,0 +1,16 @@
package mesh
type RouteManagerStub struct {
}
func (r *RouteManagerStub) UpdateRoutes() error {
return nil
}
func (r *RouteManagerStub) InstallRoutes() error {
return nil
}
func (r *RouteManagerStub) RemoveRoutes(meshId string) error {
return nil
}

315
pkg/mesh/stub_types.go Normal file
View File

@ -0,0 +1,315 @@
package mesh
import (
"fmt"
"net"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type MeshNodeStub struct {
hostEndpoint string
publicKey wgtypes.Key
wgEndpoint string
wgHost *net.IPNet
timeStamp int64
routes []string
identifier string
description string
}
// GetServices implements MeshNode.
func (*MeshNodeStub) GetServices() map[string]string {
return make(map[string]string)
}
// GetAlias implements MeshNode.
func (*MeshNodeStub) GetAlias() string {
return ""
}
func (m *MeshNodeStub) GetHostEndpoint() string {
return m.hostEndpoint
}
func (m *MeshNodeStub) GetPublicKey() (wgtypes.Key, error) {
return m.publicKey, nil
}
func (m *MeshNodeStub) GetWgEndpoint() string {
return m.wgEndpoint
}
func (m *MeshNodeStub) GetWgHost() *net.IPNet {
return m.wgHost
}
func (m *MeshNodeStub) GetTimeStamp() int64 {
return m.timeStamp
}
func (m *MeshNodeStub) GetRoutes() []string {
return m.routes
}
func (m *MeshNodeStub) GetIdentifier() string {
return m.identifier
}
func (m *MeshNodeStub) GetDescription() string {
return m.description
}
type MeshSnapshotStub struct {
nodes map[string]MeshNode
}
func (s *MeshSnapshotStub) GetNodes() map[string]MeshNode {
return s.nodes
}
type MeshProviderStub struct {
meshId string
snapshot *MeshSnapshotStub
}
// GetNodeIds implements MeshProvider.
func (*MeshProviderStub) GetNodeIds() []string {
panic("unimplemented")
}
// GetNode implements MeshProvider.
func (*MeshProviderStub) GetNode(string) (MeshNode, error) {
panic("unimplemented")
}
// NodeExists implements MeshProvider.
func (*MeshProviderStub) NodeExists(string) bool {
panic("unimplemented")
}
// AddService implements MeshProvider.
func (*MeshProviderStub) AddService(nodeId string, key string, value string) error {
panic("unimplemented")
}
// RemoveService implements MeshProvider.
func (*MeshProviderStub) RemoveService(nodeId string, key string) error {
panic("unimplemented")
}
// SetAlias implements MeshProvider.
func (*MeshProviderStub) SetAlias(nodeId string, alias string) error {
panic("unimplemented")
}
// RemoveRoutes implements MeshProvider.
func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error {
panic("unimplemented")
}
// Prune implements MeshProvider.
func (*MeshProviderStub) Prune(pruneAmount int) error {
return nil
}
// UpdateTimeStamp implements MeshProvider.
func (*MeshProviderStub) UpdateTimeStamp(nodeId string) error {
return nil
}
func (s *MeshProviderStub) AddNode(node MeshNode) {
s.snapshot.nodes[node.GetHostEndpoint()] = node
}
func (s *MeshProviderStub) GetMesh() (MeshSnapshot, error) {
return s.snapshot, nil
}
func (s *MeshProviderStub) GetMeshId() string {
return s.meshId
}
func (s *MeshProviderStub) Save() []byte {
return make([]byte, 0)
}
func (s *MeshProviderStub) Load(bytes []byte) error {
return nil
}
func (s *MeshProviderStub) GetDevice() (*wgtypes.Device, error) {
pubKey, _ := wgtypes.GenerateKey()
return &wgtypes.Device{
PublicKey: pubKey,
}, nil
}
func (s *MeshProviderStub) SaveChanges() {}
func (s *MeshProviderStub) HasChanges() bool {
return false
}
func (s *MeshProviderStub) AddRoutes(nodeId string, route ...string) error {
return nil
}
func (s *MeshProviderStub) GetSyncer() MeshSyncer {
return nil
}
func (s *MeshProviderStub) SetDescription(nodeId string, description string) error {
return nil
}
type StubMeshProviderFactory struct{}
func (s *StubMeshProviderFactory) CreateMesh(params *MeshProviderFactoryParams) (MeshProvider, error) {
return &MeshProviderStub{
meshId: params.MeshId,
snapshot: &MeshSnapshotStub{nodes: make(map[string]MeshNode)},
}, nil
}
type StubNodeFactory struct {
Config *conf.WgMeshConfiguration
}
func (s *StubNodeFactory) Build(params *MeshNodeFactoryParams) MeshNode {
_, wgHost, _ := net.ParseCIDR(fmt.Sprintf("%s/128", params.NodeIP.String()))
return &MeshNodeStub{
hostEndpoint: params.Endpoint,
publicKey: *params.PublicKey,
wgEndpoint: fmt.Sprintf("%s:%s", params.Endpoint, s.Config.GrpcPort),
wgHost: wgHost,
timeStamp: time.Now().Unix(),
routes: make([]string, 0),
identifier: "abc",
description: "A Mesh Node Stub",
}
}
type MeshConfigApplyerStub struct{}
func (a *MeshConfigApplyerStub) ApplyConfig() error {
return nil
}
func (a *MeshConfigApplyerStub) RemovePeers(meshId string) error {
return nil
}
func (a *MeshConfigApplyerStub) SetMeshManager(manager MeshManager) {
}
type MeshManagerStub struct {
meshes map[string]MeshProvider
}
// GetNode implements MeshManager.
func (*MeshManagerStub) GetNode(string, string) MeshNode {
panic("unimplemented")
}
// RemoveService implements MeshManager.
func (*MeshManagerStub) RemoveService(service string) error {
panic("unimplemented")
}
// SetService implements MeshManager.
func (*MeshManagerStub) SetService(service string, value string) error {
panic("unimplemented")
}
// GetMonitor implements MeshManager.
func (*MeshManagerStub) GetMonitor() MeshMonitor {
panic("unimplemented")
}
// SetAlias implements MeshManager.
func (*MeshManagerStub) SetAlias(alias string) error {
panic("unimplemented")
}
// Close implements MeshManager.
func (*MeshManagerStub) Close() error {
panic("unimplemented")
}
// Prune implements MeshManager.
func (*MeshManagerStub) Prune() error {
return nil
}
func NewMeshManagerStub() MeshManager {
return &MeshManagerStub{meshes: make(map[string]MeshProvider)}
}
func (m *MeshManagerStub) CreateMesh(devName string, port int) (string, error) {
return "tim123", nil
}
func (m *MeshManagerStub) AddMesh(params *AddMeshParams) error {
m.meshes[params.MeshId] = &MeshProviderStub{
params.MeshId,
&MeshSnapshotStub{nodes: make(map[string]MeshNode)},
}
return nil
}
func (m *MeshManagerStub) HasChanges(meshId string) bool {
return false
}
func (m *MeshManagerStub) GetMesh(meshId string) MeshProvider {
return &MeshProviderStub{
meshId: meshId,
snapshot: &MeshSnapshotStub{nodes: make(map[string]MeshNode)}}
}
func (m *MeshManagerStub) EnableInterface(meshId string) error {
return nil
}
func (m *MeshManagerStub) GetPublicKey(meshId string) (*wgtypes.Key, error) {
key, _ := wgtypes.GenerateKey()
return &key, nil
}
func (m *MeshManagerStub) AddSelf(params *AddSelfParams) error {
return nil
}
func (m *MeshManagerStub) GetSelf(meshId string) (MeshNode, error) {
return nil, nil
}
func (m *MeshManagerStub) ApplyConfig() error {
return nil
}
func (m *MeshManagerStub) SetDescription(description string) error {
return nil
}
func (m *MeshManagerStub) UpdateTimeStamp() error {
return nil
}
func (m *MeshManagerStub) GetClient() *wgctrl.Client {
return nil
}
func (m *MeshManagerStub) GetMeshes() map[string]MeshProvider {
return m.meshes
}
func (m *MeshManagerStub) LeaveMesh(meshId string) error {
return nil
}

View File

@ -4,6 +4,7 @@ package mesh
import (
"net"
"slices"
"github.com/tim-beatham/wgmesh/pkg/conf"
"golang.zx2c4.com/wireguard/wgctrl"
@ -28,6 +29,51 @@ type MeshNode interface {
GetIdentifier() string
// GetDescription: returns the description for this node
GetDescription() string
// GetAlias: associates the node with an alias. Potentially used
// for DNS and so forth.
GetAlias() string
// GetServices: returns a list of services offered by the node
GetServices() map[string]string
}
// NodeEquals: determines if two mesh nodes are equivalent to one another
func NodeEquals(node1, node2 MeshNode) bool {
if node1.GetHostEndpoint() != node2.GetHostEndpoint() {
return false
}
node1Pub, _ := node1.GetPublicKey()
node2Pub, _ := node2.GetPublicKey()
if node1Pub != node2Pub {
return false
}
if node1.GetWgEndpoint() != node2.GetWgEndpoint() {
return false
}
if node1.GetWgHost() != node2.GetWgHost() {
return false
}
if !slices.Equal(node1.GetRoutes(), node2.GetRoutes()) {
return false
}
if node1.GetIdentifier() != node2.GetIdentifier() {
return false
}
if node1.GetDescription() != node2.GetDescription() {
return false
}
if node1.GetAlias() != node2.GetAlias() {
return false
}
return true
}
type MeshSnapshot interface {
@ -46,7 +92,7 @@ type MeshSyncer interface {
type MeshProvider interface {
// AddNode() adds a node to the mesh
AddNode(node MeshNode)
// GetMesh() returns a snapshot of the mesh provided by the mesh provider
// GetMesh() returns a snapshot of the mesh provided by the mesh provider.
GetMesh() (MeshSnapshot, error)
// GetMeshId() returns the ID of the mesh network
GetMeshId() string
@ -58,20 +104,37 @@ type MeshProvider interface {
GetDevice() (*wgtypes.Device, error)
// HasChanges returns true if we have changes since last time we synced
HasChanges() bool
// Record that we have changges and save the corresponding changes
// Record that we have changes and save the corresponding changes
SaveChanges()
// UpdateTimeStamp: update the timestamp of the given node
UpdateTimeStamp(nodeId string) error
// AddRoutes: adds routes to the given node
AddRoutes(nodeId string, route ...string) error
// DeleteRoutes: deletes the routes from the node
RemoveRoutes(nodeId string, route ...string) error
// GetSyncer: returns the automerge syncer for sync
GetSyncer() MeshSyncer
// GetNode get a particular not within the mesh
GetNode(string) (MeshNode, error)
// NodeExists: returns true if a particular node exists false otherwise
NodeExists(string) bool
// SetDescription: sets the description of this automerge data type
SetDescription(nodeId string, description string) error
// SetAlias: set the alias of the nodeId
SetAlias(nodeId string, alias string) error
// AddService: adds the service to the given node
AddService(nodeId, key, value string) error
// RemoveService: removes the service form the node. throws an error if the service does not exist
RemoveService(nodeId, key string) error
// Prune: prunes all nodes that have not updated their timestamp in
// pruneAmount seconds
Prune(pruneAmount int) error
GetNodeIds() []string
}
// HostParameters contains the IDs of a node
type HostParameters struct {
HostEndpoint string
// TODO: Contain the WireGungracefullyuard identifier in this
}
// MeshProviderFactoryParams parameters required to build a mesh provider

View File

@ -1,29 +0,0 @@
package middleware
import (
"context"
"errors"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/rpc"
)
// AuthRpcProvider implements the AuthRpcProvider service
type AuthRpcProvider struct {
rpc.UnimplementedAuthenticationServer
}
// JoinMesh handles a JoinMeshRequest. Succeeds by stating the node managed to join the mesh
// or returns an error if it failed
func (a *AuthRpcProvider) JoinMesh(ctx context.Context, in *rpc.JoinAuthMeshRequest) (*rpc.JoinAuthMeshReply, error) {
meshId := in.MeshId
if meshId == "" {
return nil, errors.New("Must specify the meshId")
}
logging.Log.WriteInfof("MeshID: " + in.MeshId)
var token string = ""
return &rpc.JoinAuthMeshReply{Success: true, Token: &token}, nil
}

View File

@ -16,7 +16,7 @@ type Querier interface {
}
type JmesQuerier struct {
manager *mesh.MeshManager
manager mesh.MeshManager
}
type QueryError struct {
@ -24,13 +24,15 @@ type QueryError struct {
}
type QueryNode struct {
HostEndpoint string `json:"hostEndpoint"`
PublicKey string `json:"publicKey"`
WgEndpoint string `json:"wgEndpoint"`
WgHost string `json:"wgHost"`
Timestamp int64 `json:"timestmap"`
Description string `json:"description"`
Routes []string `json:"routes"`
HostEndpoint string `json:"hostEndpoint"`
PublicKey string `json:"publicKey"`
WgEndpoint string `json:"wgEndpoint"`
WgHost string `json:"wgHost"`
Timestamp int64 `json:"timestmap"`
Description string `json:"description"`
Routes []string `json:"routes"`
Alias string `json:"alias"`
Services map[string]string `json:"services"`
}
func (m *QueryError) Error() string {
@ -39,7 +41,7 @@ func (m *QueryError) Error() string {
// Query: queries the data
func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
mesh, ok := j.manager.Meshes[meshId]
mesh, ok := j.manager.GetMeshes()[meshId]
if !ok {
return nil, &QueryError{msg: fmt.Sprintf("%s does not exist", meshId)}
@ -51,7 +53,7 @@ func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
return nil, err
}
nodes := lib.Map(lib.MapValues(snapshot.GetNodes()), meshNodeToQueryNode)
nodes := lib.Map(lib.MapValues(snapshot.GetNodes()), MeshNodeToQueryNode)
result, err := jmespath.Search(queryParams, nodes)
@ -63,7 +65,7 @@ func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
return bytes, err
}
func meshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
func MeshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode := new(QueryNode)
queryNode.HostEndpoint = node.GetHostEndpoint()
pubKey, _ := node.GetPublicKey()
@ -76,9 +78,12 @@ func meshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode.Timestamp = node.GetTimeStamp()
queryNode.Routes = node.GetRoutes()
queryNode.Description = node.GetDescription()
queryNode.Alias = node.GetAlias()
queryNode.Services = node.GetServices()
return queryNode
}
func NewJmesQuerier(manager *mesh.MeshManager) Querier {
func NewJmesQuerier(manager mesh.MeshManager) Querier {
return &JmesQuerier{manager: manager}
}

View File

@ -2,45 +2,49 @@ package robin
import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
"time"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query"
"github.com/tim-beatham/wgmesh/pkg/rpc"
)
type IpcHandler struct {
Server *ctrlserver.MeshCtrlServer
ipAllocator ip.IPAllocator
Server ctrlserver.CtrlServer
}
func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
meshId, err := n.Server.MeshManager.CreateMesh(args.IfName, args.WgPort)
meshId, err := n.Server.GetMeshManager().CreateMesh(args.IfName, args.WgPort)
if err != nil {
return err
}
err = n.Server.MeshManager.AddSelf(&mesh.AddSelfParams{
err = n.Server.GetMeshManager().AddSelf(&mesh.AddSelfParams{
MeshId: meshId,
WgPort: args.WgPort,
Endpoint: args.Endpoint,
})
if err != nil {
return err
}
*reply = meshId
return err
}
func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
meshNames := make([]string, len(n.Server.MeshManager.Meshes))
meshNames := make([]string, len(n.Server.GetMeshManager().GetMeshes()))
i := 0
for meshId, _ := range n.Server.MeshManager.Meshes {
for meshId, _ := range n.Server.GetMeshManager().GetMeshes() {
meshNames[i] = meshId
i++
}
@ -50,7 +54,7 @@ func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
}
func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
peerConnection, err := n.Server.ConnectionManager.GetConnection(args.IpAdress)
peerConnection, err := n.Server.GetConnectionManager().GetConnection(args.IpAdress)
if err != nil {
return err
@ -77,7 +81,7 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
return err
}
err = n.Server.MeshManager.AddMesh(&mesh.AddMeshParams{
err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{
MeshId: args.MeshId,
DevName: args.IfName,
WgPort: args.Port,
@ -88,7 +92,7 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
return err
}
err = n.Server.MeshManager.AddSelf(&mesh.AddSelfParams{
err = n.Server.GetMeshManager().AddSelf(&mesh.AddSelfParams{
MeshId: args.MeshId,
WgPort: args.Port,
Endpoint: args.Endpoint,
@ -104,7 +108,7 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
// LeaveMesh leaves a mesh network
func (n *IpcHandler) LeaveMesh(meshId string, reply *string) error {
err := n.Server.MeshManager.LeaveMesh(meshId)
err := n.Server.GetMeshManager().LeaveMesh(meshId)
if err == nil {
*reply = fmt.Sprintf("Left Mesh %s", meshId)
@ -114,7 +118,12 @@ func (n *IpcHandler) LeaveMesh(meshId string, reply *string) error {
}
func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
mesh := n.Server.MeshManager.GetMesh(meshId)
mesh := n.Server.GetMeshManager().GetMesh(meshId)
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
meshSnapshot, err := mesh.GetMesh()
if err != nil {
@ -124,6 +133,7 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
if mesh == nil {
return errors.New("mesh does not exist")
}
nodes := make([]ctrlserver.MeshNode, len(meshSnapshot.GetNodes()))
i := 0
@ -141,6 +151,9 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
WgHost: node.GetWgHost().String(),
Timestamp: node.GetTimeStamp(),
Routes: node.GetRoutes(),
Description: node.GetDescription(),
Alias: node.GetAlias(),
Services: node.GetServices(),
}
nodes[i] = node
@ -152,7 +165,7 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
}
func (n *IpcHandler) EnableInterface(meshId string, reply *string) error {
err := n.Server.MeshManager.EnableInterface(meshId)
err := n.Server.GetMeshManager().EnableInterface(meshId)
if err != nil {
*reply = err.Error()
@ -164,7 +177,7 @@ func (n *IpcHandler) EnableInterface(meshId string, reply *string) error {
}
func (n *IpcHandler) GetDOT(meshId string, reply *string) error {
g := mesh.NewMeshDotConverter(n.Server.MeshManager)
g := mesh.NewMeshDotConverter(n.Server.GetMeshManager())
result, err := g.Generate(meshId)
@ -177,7 +190,7 @@ func (n *IpcHandler) GetDOT(meshId string, reply *string) error {
}
func (n *IpcHandler) Query(params ipc.QueryMesh, reply *string) error {
queryResponse, err := n.Server.Querier.Query(params.MeshId, params.Query)
queryResponse, err := n.Server.GetQuerier().Query(params.MeshId, params.Query)
if err != nil {
return err
@ -188,7 +201,7 @@ func (n *IpcHandler) Query(params ipc.QueryMesh, reply *string) error {
}
func (n *IpcHandler) PutDescription(description string, reply *string) error {
err := n.Server.MeshManager.SetDescription(description)
err := n.Server.GetMeshManager().SetDescription(description)
if err != nil {
return err
@ -198,8 +211,62 @@ func (n *IpcHandler) PutDescription(description string, reply *string) error {
return nil
}
func (n *IpcHandler) PutAlias(alias string, reply *string) error {
err := n.Server.GetMeshManager().SetAlias(alias)
if err != nil {
return err
}
*reply = fmt.Sprintf("Set alias to %s", alias)
return nil
}
func (n *IpcHandler) PutService(service ipc.PutServiceArgs, reply *string) error {
err := n.Server.GetMeshManager().SetService(service.Service, service.Value)
if err != nil {
return err
}
*reply = "success"
return nil
}
func (n *IpcHandler) DeleteService(service string, reply *string) error {
err := n.Server.GetMeshManager().RemoveService(service)
if err != nil {
return err
}
*reply = "success"
return nil
}
func (n *IpcHandler) GetNode(args ipc.GetNodeArgs, reply *string) error {
node := n.Server.GetMeshManager().GetNode(args.MeshId, args.NodeId)
if node == nil {
*reply = "nil"
return nil
}
queryNode := query.MeshNodeToQueryNode(node)
bytes, err := json.Marshal(queryNode)
if err != nil {
*reply = err.Error()
return nil
}
*reply = string(bytes)
return nil
}
type RobinIpcParams struct {
CtrlServer *ctrlserver.MeshCtrlServer
CtrlServer ctrlserver.CtrlServer
}
func NewRobinIpc(ipcParams RobinIpcParams) IpcHandler {

View File

@ -0,0 +1,73 @@
package robin
import (
"testing"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/mesh"
)
func getRequester() *IpcHandler {
return &IpcHandler{Server: ctrlserver.NewCtrlServerStub()}
}
func TestCreateMeshRepliesMeshId(t *testing.T) {
var reply string
requester := getRequester()
err := requester.CreateMesh(&ipc.NewMeshArgs{
IfName: "wg0",
WgPort: 5000,
Endpoint: "abc.com",
}, &reply)
if err != nil {
t.Error(err)
}
if len(reply) == 0 {
t.Fatalf(`reply should have been returned`)
}
}
func TestListMeshesNoMeshesListsEmpty(t *testing.T) {
var reply ipc.ListMeshReply
requester := getRequester()
err := requester.ListMeshes("", &reply)
if err != nil {
t.Error(err)
}
if len(reply.Meshes) != 0 {
t.Fatalf(`meshes should be empty`)
}
}
func TestListMeshesMeshesNotEmpty(t *testing.T) {
var reply ipc.ListMeshReply
requester := getRequester()
requester.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{
MeshId: "tim123",
DevName: "wg0",
WgPort: 5000,
MeshBytes: make([]byte, 0),
})
err := requester.ListMeshes("", &reply)
if err != nil {
t.Error(err)
}
if len(reply.Meshes) != 1 {
t.Fatalf(`only only mesh exists`)
}
if reply.Meshes[0] != "tim123" {
t.Fatalf(`meshId was %s expected %s`, reply.Meshes[0], "tim123")
}
}

View File

@ -13,29 +13,6 @@ type WgRpc struct {
Server *ctrlserver.MeshCtrlServer
}
func nodeToRpcNode(node ctrlserver.MeshNode) *rpc.MeshNode {
return &rpc.MeshNode{
PublicKey: node.PublicKey,
WgEndpoint: node.WgEndpoint,
WgHost: node.WgHost,
Endpoint: node.HostEndpoint,
}
}
func nodesToRpcNodes(nodes map[string]ctrlserver.MeshNode) []*rpc.MeshNode {
n := len(nodes)
meshNodes := make([]*rpc.MeshNode, n)
var i int = 0
for _, v := range nodes {
meshNodes[i] = nodeToRpcNode(v)
i++
}
return meshNodes
}
func (m *WgRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.GetMeshReply, error) {
mesh := m.Server.MeshManager.GetMesh(request.MeshId)

View File

@ -0,0 +1 @@
package robin

View File

@ -1,7 +1,6 @@
package sync
import (
"errors"
"math/rand"
"sync"
"time"
@ -20,7 +19,7 @@ type Syncer interface {
}
type SyncerImpl struct {
manager *mesh.MeshManager
manager mesh.MeshManager
requester SyncRequester
infectionCount int
syncCount int
@ -30,44 +29,30 @@ type SyncerImpl struct {
// Sync: Sync random nodes
func (s *SyncerImpl) Sync(meshId string) error {
logging.Log.WriteInfof("UPDATING WG CONF")
s.manager.ApplyConfig()
if !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
logging.Log.WriteInfof("No changes for %s", meshId)
return nil
}
theMesh := s.manager.GetMesh(meshId)
logging.Log.WriteInfof("UPDATING WG CONF")
if theMesh == nil {
return errors.New("the provided mesh does not exist")
if s.manager.HasChanges(meshId) {
err := s.manager.ApplyConfig()
if err != nil {
logging.Log.WriteInfof("Failed to update config %w", err)
}
}
snapshot, err := theMesh.GetMesh()
nodeNames := s.manager.GetMesh(meshId).GetNodeIds()
self, err := s.manager.GetSelf(meshId)
if err != nil {
return err
}
nodes := snapshot.GetNodes()
if len(nodes) <= 1 {
return nil
}
excludedNodes := map[string]struct{}{
s.manager.HostParameters.HostEndpoint: {},
}
meshNodes := lib.MapValuesWithExclude(nodes, excludedNodes)
getNames := func(node mesh.MeshNode) string {
return node.GetHostEndpoint()
}
nodeNames := lib.Map(meshNodes, getNames)
neighbours := s.cluster.GetNeighbours(nodeNames, s.manager.HostParameters.HostEndpoint)
neighbours := s.cluster.GetNeighbours(nodeNames, self.GetHostEndpoint())
randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
for _, node := range randomSubset {
@ -76,9 +61,9 @@ func (s *SyncerImpl) Sync(meshId string) error {
before := time.Now()
if len(meshNodes) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
logging.Log.WriteInfof("Sending to random cluster")
interCluster := s.cluster.GetInterCluster(nodeNames, s.manager.HostParameters.HostEndpoint)
interCluster := s.cluster.GetInterCluster(nodeNames, self.GetHostEndpoint())
randomSubset = append(randomSubset, interCluster)
}
@ -97,17 +82,20 @@ func (s *SyncerImpl) Sync(meshId string) error {
waitGroup.Wait()
s.syncCount++
logging.Log.WriteInfof("SYNC TIME: %v", time.Now().Sub(before))
logging.Log.WriteInfof("SYNC TIME: %v", time.Since(before))
logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount)
s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount)
// Check if any changes have occurred and trigger callbacks
// if changes have occurred.
// return s.manager.GetMonitor().Trigger()
return nil
}
// SyncMeshes: Sync all meshes
func (s *SyncerImpl) SyncMeshes() error {
for meshId, _ := range s.manager.Meshes {
for meshId := range s.manager.GetMeshes() {
err := s.Sync(meshId)
if err != nil {
@ -118,7 +106,7 @@ func (s *SyncerImpl) SyncMeshes() error {
return nil
}
func NewSyncer(m *mesh.MeshManager, conf *conf.WgMeshConfiguration, r SyncRequester) Syncer {
func NewSyncer(m mesh.MeshManager, conf *conf.WgMeshConfiguration, r SyncRequester) Syncer {
cluster, _ := conn.NewConnCluster(conf.ClusterSize)
return &SyncerImpl{
manager: m,

View File

@ -14,7 +14,7 @@ type SyncErrorHandler interface {
// SyncErrorHandlerImpl Is an implementation of the SyncErrorHandler
type SyncErrorHandlerImpl struct {
meshManager *mesh.MeshManager
meshManager mesh.MeshManager
}
func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint string) bool {
@ -24,6 +24,13 @@ func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint stri
return false
}
// self, err := s.meshManager.GetSelf(meshId)
// if err != nil {
// return false
// }
// mesh.DecrementHealth(endpoint, self.GetHostEndpoint())
return true
}
@ -40,6 +47,6 @@ func (s *SyncErrorHandlerImpl) Handle(meshId string, endpoint string, err error)
return false
}
func NewSyncErrorHandler(m *mesh.MeshManager) SyncErrorHandler {
func NewSyncErrorHandler(m mesh.MeshManager) SyncErrorHandler {
return &SyncErrorHandlerImpl{meshManager: m}
}

View File

@ -89,10 +89,12 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
c := rpc.NewSyncServiceClient(client)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
syncTimeOut := s.server.Conf.SyncRate * float64(time.Second)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(syncTimeOut))
defer cancel()
err = syncMesh(mesh, ctx, c)
err = s.syncMesh(mesh, ctx, c)
if err != nil {
return s.handleErr(meshId, endpoint, err)
@ -102,7 +104,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
return nil
}
func syncMesh(mesh mesh.MeshProvider, ctx context.Context, client rpc.SyncServiceClient) error {
func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context, client rpc.SyncServiceClient) error {
stream, err := client.SyncMesh(ctx)
syncer := mesh.GetSyncer()

View File

@ -1,10 +1,8 @@
package sync
import (
"time"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
// SyncScheduler: Loops through all nodes in the mesh and runs a schedule to
@ -22,34 +20,13 @@ type SyncSchedulerImpl struct {
}
// Run implements SyncScheduler.
func (s *SyncSchedulerImpl) Run() error {
ticker := time.NewTicker(time.Duration(s.server.Conf.SyncRate) * time.Second)
quit := make(chan struct{})
s.quit = quit
for {
select {
case <-ticker.C:
err := s.syncer.SyncMeshes()
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
break
case <-quit:
break
}
func syncFunction(syncer Syncer) lib.TimerFunc {
return func() error {
return syncer.SyncMeshes()
}
}
// Stop implements SyncScheduler.
func (s *SyncSchedulerImpl) Stop() error {
close(s.quit)
return nil
}
func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester) SyncScheduler {
func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester) *lib.Timer {
syncer := NewSyncer(s.MeshManager, s.Conf, syncRequester)
return &SyncSchedulerImpl{server: s, syncer: syncer}
return lib.NewTimer(syncFunction(syncer), int(s.Conf.SyncRate))
}

View File

@ -1,48 +1,14 @@
package timestamp
import (
"time"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
type TimestampScheduler interface {
Run() error
Stop() error
}
type TimeStampSchedulerImpl struct {
meshManager *mesh.MeshManager
updateRate int
quit chan struct{}
}
func (s *TimeStampSchedulerImpl) Run() error {
ticker := time.NewTicker(time.Duration(s.updateRate) * time.Second)
s.quit = make(chan struct{})
for {
select {
case <-ticker.C:
err := s.meshManager.UpdateTimeStamp()
if err != nil {
logging.Log.WriteErrorf("Update Timestamp Error: %s", err.Error())
}
case <-s.quit:
break
}
func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
timerFunc := func() error {
return ctrlServer.MeshManager.UpdateTimeStamp()
}
}
func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) TimestampScheduler {
return &TimeStampSchedulerImpl{meshManager: ctrlServer.MeshManager, updateRate: ctrlServer.Conf.KeepAliveRate}
}
func (s *TimeStampSchedulerImpl) Stop() error {
close(s.quit)
return nil
return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime)
}

15
pkg/wg/stubs.go Normal file
View File

@ -0,0 +1,15 @@
package wg
type WgInterfaceManipulatorStub struct{}
func (i *WgInterfaceManipulatorStub) CreateInterface(params *CreateInterfaceParams) error {
return nil
}
func (i *WgInterfaceManipulatorStub) AddAddress(ifName string, addr string) error {
return nil
}
func (i *WgInterfaceManipulatorStub) RemoveInterface(ifName string) error {
return nil
}

View File

@ -16,7 +16,8 @@ type CreateInterfaceParams struct {
type WgInterfaceManipulator interface {
// CreateInterface creates a WireGuard interface
CreateInterface(params *CreateInterfaceParams) error
// Enable interface enables the given interface with
// the IP. It overrides the IP at the interface
EnableInterface(ifName string, ip string) error
// AddAddress adds an address to the given interface name
AddAddress(ifName string, addr string) error
// RemoveInterface removes the specified interface
RemoveInterface(ifName string) error
}

View File

@ -1,50 +1,37 @@
package wg
import (
"errors"
"fmt"
"net"
"os/exec"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// createInterface uses ip link to create an interface. If the interface exists
// it returns an error
func createInterface(ifName string) error {
_, err := net.InterfaceByName(ifName)
if err == nil {
err = flushInterface(ifName)
return err
}
// Check if the interface exists
cmd := exec.Command("/usr/bin/ip", "link", "add", "dev", ifName, "type", "wireguard")
if err := cmd.Run(); err != nil {
return err
}
return nil
}
type WgInterfaceManipulatorImpl struct {
client *wgctrl.Client
}
// CreateInterface creates a WireGuard interface
func (m *WgInterfaceManipulatorImpl) CreateInterface(params *CreateInterfaceParams) error {
err := createInterface(params.IfName)
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return err
return fmt.Errorf("failed to access link: %w", err)
}
defer rtnl.Close()
err = rtnl.CreateLink(params.IfName)
if err != nil {
return fmt.Errorf("failed to create link: %w", err)
}
privateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
return err
return fmt.Errorf("failed to create private key: %w", err)
}
var cfg wgtypes.Config = wgtypes.Config{
@ -52,59 +39,44 @@ func (m *WgInterfaceManipulatorImpl) CreateInterface(params *CreateInterfacePara
ListenPort: &params.Port,
}
m.client.ConfigureDevice(params.IfName, cfg)
err = m.client.ConfigureDevice(params.IfName, cfg)
if err != nil {
return fmt.Errorf("failed to configure dev: %w", err)
}
logging.Log.WriteInfof("ip link set up dev %s type wireguard", params.IfName)
return nil
}
// flushInterface flushes the specified interface
func flushInterface(ifName string) error {
_, err := net.InterfaceByName(ifName)
// Add an address to the given interface
func (m *WgInterfaceManipulatorImpl) AddAddress(ifName string, addr string) error {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return &WgError{msg: fmt.Sprintf("Interface %s does not exist cannot flush", ifName)}
return fmt.Errorf("failed to create config: %w", err)
}
defer rtnl.Close()
err = rtnl.AddAddress(ifName, addr)
if err != nil {
err = fmt.Errorf("failed to add address: %w", err)
}
cmd := exec.Command("/usr/bin/ip", "addr", "flush", "dev", ifName)
if err := cmd.Run(); err != nil {
logging.Log.WriteErrorf(fmt.Sprintf("%s error flushing interface %s", err.Error(), ifName))
return &WgError{msg: fmt.Sprintf("Failed to flush interface %s", ifName)}
}
return nil
return err
}
// EnableInterface flushes the interface and sets the ip address of the
// interface
func (m *WgInterfaceManipulatorImpl) EnableInterface(ifName string, ip string) error {
if len(ifName) == 0 {
return errors.New("ifName not provided")
}
err := flushInterface(ifName)
// RemoveInterface implements WgInterfaceManipulator.
func (*WgInterfaceManipulatorImpl) RemoveInterface(ifName string) error {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return err
return fmt.Errorf("failed to create config: %w", err)
}
defer rtnl.Close()
cmd := exec.Command("/usr/bin/ip", "link", "set", "up", "dev", ifName)
if err := cmd.Run(); err != nil {
return err
}
hostIp, _, err := net.ParseCIDR(ip)
if err != nil {
return err
}
cmd = exec.Command("/usr/bin/ip", "addr", "add", hostIp.String()+"/64", "dev", ifName)
if err := cmd.Run(); err != nil {
return err
}
return nil
return rtnl.DeleteLink(ifName)
}
func NewWgInterfaceManipulator(client *wgctrl.Client) WgInterfaceManipulator {

1
smegmesh-web Submodule

Submodule smegmesh-web added at c1128bcd98