1
0
forked from extern/smegmesh

Compare commits

...

41 Commits

Author SHA1 Message Date
4c54022f63 25-modify-code-to-use-public-api
Modify the code to use a public IP address by default if none is
specified
2023-11-22 10:41:54 +00:00
bf0724f6e5 Merge pull request #24 from tim-beatham/24-keepalive-holepunch
24 keepalive holepunch
2023-11-21 21:28:16 +00:00
624bd6e921 24-keepalive
Persistent keep alive working
2023-11-21 21:26:31 +00:00
7b939e0468 24-keepalive-holepunch
Added the ability to hole punch NAT
2023-11-21 20:42:43 +00:00
6e201ebaf5 24-keepalive-holepunch
Nodes acting as peers and nodes acting as clients
2023-11-21 16:42:49 +00:00
06542da03c main
Fixed problems with timestamp not updating
2023-11-21 13:31:34 +00:00
0d63cd6624 main
Adding words.txt for what words
2023-11-20 18:12:58 +00:00
f13319cfc1 Merge pull request #22 from tim-beatham/21-phonetic-words-ipv6
21 phonetic words ipv6
2023-11-20 18:08:49 +00:00
95f4495b0b 21-phonetic-words-ipv6
Simple what 8 words implementation
2023-11-20 18:07:52 +00:00
330fa74ef4 IPv6 What 8 Words
what 8 words for ipv6 started
2023-11-20 15:22:32 +00:00
3e5b57e41f Merge pull request #20 from tim-beatham/19-hash-wg-interface
Hashing the WireGuard interface
2023-11-20 13:04:19 +00:00
b179cd3cf4 Hashing the WireGuard interface
Hashing the interface and using ephmeral ports so that the admin doesn't
choose an interface and port combination. An administrator can alteranatively
decide to provide port but this isn't critical.
2023-11-20 13:03:42 +00:00
8f211aa116 Merge pull request #18 from tim-beatham/26-performance-testing
Stubbing out WireGuard components
2023-11-20 11:29:37 +00:00
388153e706 Stubbing out WireGuard components
Stubbing our WireGuard components so that I can use docker/podman
network_mode=host. This is much more efficient than the docker/podman
userspace network.
2023-11-20 11:28:12 +00:00
023565d985 Merge pull request #17 from tim-beatham/25-ability-to-aliases
25 ability to aliases
2023-11-17 22:20:57 +00:00
36c264b38e 25-ability-aliases
Fixed unit tests failing
2023-11-17 22:18:53 +00:00
68db795f47 Ability to specify aliases
Ability to specify aliases that automatically append to /etc/hosts
2023-11-17 22:13:51 +00:00
f6160fe138 Adding aliases that automatically gets added 2023-11-17 19:13:20 +00:00
2c5289afb0 Merge pull request #16 from tim-beatham/15-add-rest-api
Developed a rest API
2023-11-15 12:57:05 +00:00
7199d07a76 Added smegmesh submodule 2023-11-13 10:46:52 +00:00
5f176e731f Developed a rest API 2023-11-13 10:44:14 +00:00
44f119b45c Updating examples 2023-11-08 09:19:24 +00:00
5215d5d54d Merge pull request #14 from tim-beatham/13-netlink-api
Removed interface manipulation via os.Exec into
2023-11-07 19:53:39 +00:00
1a864b7c80 Removed interface manipulation via os.Exec into
rtnetlink calls
2023-11-07 19:48:53 +00:00
4c19ebd81f Merge pull request #12 from tim-beatham/11-health-system
11 health system
2023-11-06 13:40:04 +00:00
acbeb689b5 Prune nodes if they exceed their timeout time 2023-11-06 13:37:28 +00:00
bc6cd4fdd5 Modified syncer 2023-11-06 10:05:23 +00:00
c88012cf71 Added health system to count how many times a node
fails to conenct.
2023-11-06 09:54:06 +00:00
4dc85f3861 Merge pull request #10 from tim-beatham/9-add-ci-support
9 add ci support
2023-11-05 18:07:52 +00:00
ef614f5961 Add cert dependencies 2023-11-05 18:06:24 +00:00
9454d62417 Adding stubs and writing tests 2023-11-05 18:03:58 +00:00
bb07d35dcb Unit testing the automerge library and lib functions 2023-11-05 12:13:40 +00:00
76dda2cf6f Update go.mod 2023-11-05 10:54:38 +00:00
1b286dd3c1 Update go.yml 2023-11-05 10:53:57 +00:00
2d45c2d298 Run go mod tidy in workflow 2023-11-05 10:51:24 +00:00
900c67a121 Update go.mod 2023-11-05 10:49:18 +00:00
b2fa08a642 Reverted go version 2023-11-05 10:48:35 +00:00
a4e9a5cd0f Updated go version in workflow 2023-11-05 10:47:10 +00:00
275eb423fb Create GitHub hosted test runner go.yml 2023-11-05 10:45:39 +00:00
d17dce3b1e Added clustering and clean up 2023-11-03 15:26:09 +00:00
e2c6db3a4f Merge pull request #8 from tim-beatham/7-create-rotating-window-of-connections
Implemented clustering betweeen nodes
2023-11-03 15:25:30 +00:00
75 changed files with 4920 additions and 471 deletions

31
.github/workflows/go.yml vendored Normal file
View File

@ -0,0 +1,31 @@
# This workflow will build a golang project
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go
name: Go
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: '1.21'
- name: Tidy
run: go mod tidy
- name: Build
run: go build -v ./...
- name: Test
run: go test -v ./...

3
.gitmodules vendored Normal file
View File

@ -0,0 +1,3 @@
[submodule "smegmesh-web"]
path = smegmesh-web
url = git@github.com:tim-beatham/smegmesh-web.git

11
Containerfile Normal file
View File

@ -0,0 +1,11 @@
FROM docker.io/library/golang:bookworm
COPY ./ /wgmesh
RUN apt-get update && apt-get install -y \
wireguard \
wireguard-tools \
iproute2 \
iputils-ping \
tmux \
vim
WORKDIR /wgmesh
RUN go build -o /usr/local/bin ./...

1
Dockerfile Symbolic link
View File

@ -0,0 +1 @@
Containerfile

19
cmd/api/main.go Normal file
View File

@ -0,0 +1,19 @@
package main
import (
"log"
"github.com/tim-beatham/wgmesh/pkg/api"
)
func main() {
apiServer, err := api.NewSmegServer(api.ApiServerConf{
WordsFile: "./cmd/api/words.txt",
})
if err != nil {
log.Fatal(err.Error())
}
apiServer.Run(":8080")
}

257
cmd/api/words.txt Normal file
View File

@ -0,0 +1,257 @@
be
to
of
it
in
we
do
he
on
go
at
if
or
up
by
hi
the
and
you
not
for
but
say
get
she
one
all
can
out
who
now
see
way
how
lot
yes
use
any
day
try
put
let
why
new
off
big
too
ask
man
bit
end
may
own
run
pay
job
old
kid
bad
few
ago
far
buy
set
guy
car
sit
war
win
yet
top
law
cut
low
die
eat
age
hit
air
add
boy
act
tax
oil
eye
son
key
fun
dad
dog
arm
fly
box
gas
lie
hot
gun
per
art
red
fit
bed
fan
mix
mom
sex
bus
fix
bar
lay
ice
bet
bag
due
aid
tie
leg
ban
odd
cup
dry
cry
rid
pop
sir
cat
map
sad
sea
aim
sun
fat
row
egg
tea
god
wed
tip
ear
hat
net
ill
dig
fee
mad
gap
nor
bid
era
toy
sky
bin
owe
wet
tap
pro
ski
cow
pen
van
web
pot
sum
cap
log
pub
pig
joy
raw
rat
via
lip
two
six
ten
lab
ton
mid
bat
hip
gut
sin
non
rub
sub
par
pre
ray
cue
dye
fin
ion
neo
hey
wow
mum
bye
aye
jet
sue
pet
flu
cop
ooh
rip
spy
pie
bug
gum
wan
rap
nut
beg
pin
pit
jam
tag
fax
vet
fry
pad
lad
mud
bay
con
pan
gee
toe
dip
shy
gym
zoo
fox
bow
tin
hop
wee
kit
opt
vow
sew
cab
bee
rob
rig
yep
ego
rib
nod
hug
lap
ash
hum
dam
bum
yen
jar

View File

@ -16,7 +16,6 @@ const SockAddr = "/tmp/wgmesh_ipc.sock"
type CreateMeshParams struct {
Client *ipcRpc.Client
IfName string
WgPort int
Endpoint string
}
@ -24,7 +23,6 @@ type CreateMeshParams struct {
func createMesh(args *CreateMeshParams) string {
var reply string
newMeshParams := ipc.NewMeshArgs{
IfName: args.IfName,
WgPort: args.WgPort,
Endpoint: args.Endpoint,
}
@ -68,7 +66,6 @@ func joinMesh(params *JoinMeshParams) string {
args := ipc.JoinMeshArgs{
MeshId: params.MeshId,
IpAdress: params.IpAddress,
IfName: params.IfName,
Port: params.WgPort,
}
@ -171,6 +168,68 @@ func putDescription(client *ipcRpc.Client, description string) {
fmt.Println(reply)
}
// putAlias: puts an alias for the node
func putAlias(client *ipcRpc.Client, alias string) {
var reply string
err := client.Call("IpcHandler.PutAlias", &alias, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func setService(client *ipcRpc.Client, service, value string) {
var reply string
serviceArgs := &ipc.PutServiceArgs{
Service: service,
Value: value,
}
err := client.Call("IpcHandler.PutService", serviceArgs, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func deleteService(client *ipcRpc.Client, service string) {
var reply string
err := client.Call("IpcHandler.PutService", &service, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func getNode(client *ipcRpc.Client, nodeId, meshId string) {
var reply string
args := &ipc.GetNodeArgs{
NodeId: nodeId,
MeshId: meshId,
}
err := client.Call("IpcHandler.GetNode", &args, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func main() {
parser := argparse.NewParser("wg-mesh",
"wg-mesh Manipulate WireGuard meshes")
@ -178,25 +237,24 @@ func main() {
newMeshCmd := parser.NewCommand("new-mesh", "Create a new mesh")
listMeshCmd := parser.NewCommand("list-meshes", "List meshes the node is connected to")
joinMeshCmd := parser.NewCommand("join-mesh", "Join a mesh network")
// getMeshCmd := parser.NewCommand("get-mesh", "Get a mesh network")
enableInterfaceCmd := parser.NewCommand("enable-interface", "Enable A Specific Mesh Interface")
getGraphCmd := parser.NewCommand("get-graph", "Convert a mesh into DOT format")
leaveMeshCmd := parser.NewCommand("leave-mesh", "Leave a mesh network")
queryMeshCmd := parser.NewCommand("query-mesh", "Query a mesh network using JMESPath")
putDescriptionCmd := parser.NewCommand("put-description", "Place a description for the node")
putAliasCmd := parser.NewCommand("put-alias", "Place an alias for the node")
setServiceCmd := parser.NewCommand("set-service", "Place a service into your advertisements")
deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements")
getNodeCmd := parser.NewCommand("get-node", "Get a specific node from the mesh")
var newMeshIfName *string = newMeshCmd.String("f", "ifname", &argparse.Options{Required: true})
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{})
var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{})
var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true})
var joinMeshIfName *string = joinMeshCmd.String("f", "ifname", &argparse.Options{Required: true})
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{})
var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{})
// var getMeshId *string = getMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true})
var getGraphMeshId *string = getGraphCmd.String("m", "mesh", &argparse.Options{Required: true})
@ -208,6 +266,16 @@ func main() {
var description *string = putDescriptionCmd.String("d", "description", &argparse.Options{Required: true})
var alias *string = putAliasCmd.String("a", "alias", &argparse.Options{Required: true})
var serviceKey *string = setServiceCmd.String("s", "service", &argparse.Options{Required: true})
var serviceValue *string = setServiceCmd.String("v", "value", &argparse.Options{Required: true})
var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{Required: true})
var getNodeNodeId *string = getNodeCmd.String("n", "nodeid", &argparse.Options{Required: true})
var getNodeMeshId *string = getNodeCmd.String("m", "meshid", &argparse.Options{Required: true})
err := parser.Parse(os.Args)
if err != nil {
@ -224,7 +292,6 @@ func main() {
if newMeshCmd.Happened() {
fmt.Println(createMesh(&CreateMeshParams{
Client: client,
IfName: *newMeshIfName,
WgPort: *newMeshPort,
Endpoint: *newMeshEndpoint,
}))
@ -237,7 +304,6 @@ func main() {
if joinMeshCmd.Happened() {
fmt.Println(joinMesh(&JoinMeshParams{
Client: client,
IfName: *joinMeshIfName,
WgPort: *joinMeshPort,
IpAddress: *joinMeshIpAddress,
MeshId: *joinMeshId,
@ -245,10 +311,6 @@ func main() {
}))
}
// if getMeshCmd.Happened() {
// getMesh(client, *getMeshId)
// }
if getGraphCmd.Happened() {
getGraph(client, *getGraphMeshId)
}
@ -268,4 +330,20 @@ func main() {
if putDescriptionCmd.Happened() {
putDescription(client, *description)
}
if putAliasCmd.Happened() {
putAlias(client, *alias)
}
if setServiceCmd.Happened() {
setService(client, *serviceKey, *serviceValue)
}
if deleteServiceCmd.Happened() {
deleteService(client, *deleteServiceKey)
}
if getNodeCmd.Happened() {
getNode(client, *getNodeNodeId, *getNodeMeshId)
}
}

View File

@ -2,6 +2,7 @@ certificatePath: "/wgmesh/cert/cert.pem"
privateKeyPath: "/wgmesh/cert/priv.pem"
caCertificatePath: "/wgmesh/cert/cacert.pem"
skipCertVerification: true
timeout: 5
gRPCPort: "21906"
advertiseRoutes: true
clusterSize: 32
@ -9,4 +10,5 @@ syncRate: 1
interClusterChance: 0.15
branchRate: 3
infectionCount: 3
keepAliveRate: 60
keepAliveTime: 10
pruneTime: 20

View File

@ -1,7 +1,8 @@
package main
import (
"log"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
@ -9,7 +10,7 @@ import (
ctrlserver "github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/middleware"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/robin"
"github.com/tim-beatham/wgmesh/pkg/sync"
"github.com/tim-beatham/wgmesh/pkg/timestamp"
@ -35,14 +36,18 @@ func main() {
return
}
if conf.Profile {
go func() {
http.ListenAndServe("localhost:6060", nil)
}()
}
var robinRpc robin.WgRpc
var robinIpc robin.IpcHandler
var authProvider middleware.AuthRpcProvider
var syncProvider sync.SyncServiceImpl
ctrlServerParams := ctrlserver.NewCtrlServerParams{
Conf: conf,
AuthProvider: &authProvider,
CtrlProvider: &robinRpc,
SyncProvider: &syncProvider,
Client: client,
@ -53,6 +58,7 @@ func main() {
syncRequester := sync.NewSyncRequester(ctrlServer)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester)
timestampScheduler := timestamp.NewTimestampScheduler(ctrlServer)
pruneScheduler := mesh.NewPruner(ctrlServer.MeshManager, *conf)
robinIpcParams := robin.RobinIpcParams{
CtrlServer: ctrlServer,
@ -66,11 +72,12 @@ func main() {
return
}
log.Println("Running IPC Handler")
logging.Log.WriteInfof("Running IPC Handler")
go ipc.RunIpcHandler(&robinIpc)
go syncScheduler.Run()
go timestampScheduler.Run()
go pruneScheduler.Run()
closeResources := func() {
logging.Log.WriteInfof("Closing resources")

View File

@ -0,0 +1,95 @@
version: '3'
networks:
net-1:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.0.0/17
net-2:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.155.0/17
services:
wg-1:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-2:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-3:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-4:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
sysctls:
- net.ipv6.conf.all.forwarding=1
networks:
- net-1
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-5:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-6:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-7:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"

View File

@ -0,0 +1,14 @@
certificatePath: "/wgmesh/cert/cert.pem"
privateKeyPath: "/wgmesh/cert/priv.pem"
caCertificatePath: "/wgmesh/cert/cacert.pem"
skipCertVerification: true
timeout: 5
gRPCPort: "21906"
advertiseRoutes: true
clusterSize: 32
syncRate: 1
interClusterChance: 0.15
branchRate: 3
infectionCount: 3
keepAliveTime: 10
pruneTime: 20

View File

@ -0,0 +1,42 @@
version: '3'
networks:
net-1:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.0.0/17
services:
wg-1:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-2:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-3:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"

View File

@ -0,0 +1,14 @@
certificatePath: "/wgmesh/cert/cert.pem"
privateKeyPath: "/wgmesh/cert/priv.pem"
caCertificatePath: "/wgmesh/cert/cacert.pem"
skipCertVerification: true
timeout: 5
gRPCPort: "21906"
advertiseRoutes: true
clusterSize: 32
syncRate: 1
interClusterChance: 0.15
branchRate: 3
infectionCount: 3
keepAliveTime: 10
pruneTime: 20

24
go.mod
View File

@ -1,13 +1,16 @@
module github.com/tim-beatham/wgmesh
go 1.21.1
go 1.21.3
require (
github.com/akamensky/argparse v1.4.0
github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9
github.com/gin-gonic/gin v1.9.1
github.com/google/uuid v1.3.0
github.com/jmespath/go-jmespath v0.4.0
github.com/jsimonetti/rtnetlink v1.3.5
github.com/sirupsen/logrus v1.9.3
golang.org/x/sys v0.14.0
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
google.golang.org/grpc v1.58.1
google.golang.org/protobuf v1.31.0
@ -15,16 +18,33 @@ require (
)
require (
github.com/bytedance/sonic v1.9.1 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.13.0 // indirect
golang.org/x/net v0.15.0 // indirect
golang.org/x/sync v0.3.0 // indirect
golang.org/x/sys v0.12.0 // indirect
golang.org/x/text v0.13.0 // indirect
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect

238
pkg/api/apiserver.go Normal file
View File

@ -0,0 +1,238 @@
package api
import (
"fmt"
"net/http"
ipcRpc "net/rpc"
"github.com/gin-gonic/gin"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/what8words"
)
const SockAddr = "/tmp/wgmesh_ipc.sock"
type ApiServer interface {
GetMeshes(c *gin.Context)
Run(addr string) error
}
type SmegServer struct {
router *gin.Engine
client *ipcRpc.Client
words *what8words.What8Words
}
func (s *SmegServer) routeToApiRoute(meshNode ctrlserver.MeshNode) []Route {
routes := make([]Route, len(meshNode.Routes))
for index, route := range meshNode.Routes {
word, err := s.words.Convert(route)
if err != nil {
fmt.Println(err.Error())
}
routes[index] = Route{
Prefix: route,
RouteId: word,
}
}
return routes
}
func (s *SmegServer) meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNode {
if meshNode.Routes == nil {
meshNode.Routes = make([]string, 0)
}
alias := meshNode.Alias
if alias == "" {
alias, _ = s.words.ConvertIdentifier(meshNode.WgHost)
}
return &SmegNode{
WgHost: meshNode.WgHost,
WgEndpoint: meshNode.WgEndpoint,
Endpoint: meshNode.HostEndpoint,
Timestamp: int(meshNode.Timestamp),
Description: meshNode.Description,
Routes: s.routeToApiRoute(meshNode),
PublicKey: meshNode.PublicKey,
Alias: alias,
Services: meshNode.Services,
}
}
func (s *SmegServer) meshToAPIMesh(meshId string, nodes []ctrlserver.MeshNode) SmegMesh {
var smegMesh SmegMesh
smegMesh.MeshId = meshId
smegMesh.Nodes = make(map[string]SmegNode)
for _, node := range nodes {
smegMesh.Nodes[node.WgHost] = *s.meshNodeToAPIMeshNode(node)
}
return smegMesh
}
// CreateMesh: creates a mesh network
func (s *SmegServer) CreateMesh(c *gin.Context) {
var createMesh CreateMeshRequest
if err := c.ShouldBindJSON(&createMesh); err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
ipcRequest := ipc.NewMeshArgs{
WgPort: createMesh.WgPort,
}
var reply string
err := s.client.Call("IpcHandler.CreateMesh", &ipcRequest, &reply)
if err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
c.JSON(http.StatusOK, &gin.H{
"meshid": reply,
})
}
// JoinMesh: joins a mesh network
func (s *SmegServer) JoinMesh(c *gin.Context) {
var joinMesh JoinMeshRequest
if err := c.ShouldBindJSON(&joinMesh); err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
ipcRequest := ipc.JoinMeshArgs{
MeshId: joinMesh.MeshId,
IpAdress: joinMesh.Bootstrap,
Port: joinMesh.WgPort,
}
var reply string
err := s.client.Call("IpcHandler.JoinMesh", &ipcRequest, &reply)
if err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
c.JSON(http.StatusOK, &gin.H{
"status": "success",
})
}
// GetMesh: given a meshId returns the corresponding mesh
// network.
func (s *SmegServer) GetMesh(c *gin.Context) {
meshidParam := c.Param("meshid")
var meshid string = meshidParam
getMeshReply := new(ipc.GetMeshReply)
err := s.client.Call("IpcHandler.GetMesh", &meshid, &getMeshReply)
if err != nil {
c.JSON(http.StatusNotFound,
&gin.H{
"error": fmt.Sprintf("could not find mesh %s", meshidParam),
})
return
}
mesh := s.meshToAPIMesh(meshidParam, getMeshReply.Nodes)
c.JSON(http.StatusOK, mesh)
}
func (s *SmegServer) GetMeshes(c *gin.Context) {
listMeshesReply := new(ipc.ListMeshReply)
err := s.client.Call("IpcHandler.ListMeshes", "", &listMeshesReply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
c.JSON(http.StatusBadRequest, nil)
return
}
meshes := make([]SmegMesh, 0)
for _, mesh := range listMeshesReply.Meshes {
getMeshReply := new(ipc.GetMeshReply)
err := s.client.Call("IpcHandler.GetMesh", &mesh, &getMeshReply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
c.JSON(http.StatusBadRequest, nil)
return
}
meshes = append(meshes, s.meshToAPIMesh(mesh, getMeshReply.Nodes))
}
c.JSON(http.StatusOK, meshes)
}
func (s *SmegServer) Run(addr string) error {
logging.Log.WriteInfof("Running API server")
return s.router.Run(addr)
}
func NewSmegServer(conf ApiServerConf) (ApiServer, error) {
client, err := ipcRpc.DialHTTP("unix", SockAddr)
if err != nil {
return nil, err
}
words, err := what8words.NewWhat8Words(conf.WordsFile)
if err != nil {
return nil, err
}
router := gin.Default()
router.Use(gin.LoggerWithConfig(gin.LoggerConfig{
Output: logging.Log.Writer(),
}))
smegServer := &SmegServer{
router: router,
client: client,
words: words,
}
router.GET("/meshes", smegServer.GetMeshes)
router.GET("/mesh/:meshid", smegServer.GetMesh)
router.POST("/mesh/create", smegServer.CreateMesh)
router.POST("/mesh/join", smegServer.JoinMesh)
return smegServer, nil
}

37
pkg/api/types.go Normal file
View File

@ -0,0 +1,37 @@
package api
type Route struct {
RouteId string `json:"routeId"`
Prefix string `json:"prefix"`
}
type SmegNode struct {
Alias string `json:"alias"`
WgHost string `json:"wgHost"`
WgEndpoint string `json:"wgEndpoint"`
Endpoint string `json:"endpoint"`
Timestamp int `json:"timestamp"`
Description string `json:"description"`
PublicKey string `json:"publicKey"`
Routes []Route `json:"routes"`
Services map[string]string `json:"services"`
}
type SmegMesh struct {
MeshId string `json:"meshid"`
Nodes map[string]SmegNode `json:"nodes"`
}
type CreateMeshRequest struct {
WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"`
}
type JoinMeshRequest struct {
WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"`
Bootstrap string `json:"bootstrap" binding:"required"`
MeshId string `json:"meshid" binding:"required"`
}
type ApiServerConf struct {
WordsFile string
}

View File

@ -18,13 +18,14 @@ import (
// CrdtMeshManager manages nodes in the crdt mesh
type CrdtMeshManager struct {
MeshId string
IfName string
NodeId string
Client *wgctrl.Client
doc *automerge.Doc
LastHash automerge.ChangeHash
conf *conf.WgMeshConfiguration
MeshId string
IfName string
Client *wgctrl.Client
doc *automerge.Doc
LastHash automerge.ChangeHash
conf *conf.WgMeshConfiguration
cache *MeshCrdt
lastCacheHash automerge.ChangeHash
}
func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
@ -34,15 +35,59 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
panic("node must be of type *MeshNodeCrdt")
}
crdt.Routes = make(map[string]interface{})
crdt.Services = make(map[string]string)
crdt.Timestamp = time.Now().Unix()
c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt)
nodeVal, _ := c.doc.Path("nodes").Map().Get(crdt.HostEndpoint)
nodeVal.Map().Set("routes", automerge.NewMap())
}
func (c *CrdtMeshManager) isPeer(nodeId string) bool {
node, err := c.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return false
}
nodeType, err := node.Map().Get("type")
if err != nil || nodeType.Kind() != automerge.KindStr {
return false
}
return nodeType.Str() == string(conf.PEER_ROLE)
}
func (c *CrdtMeshManager) GetPeers() []string {
keys, _ := c.doc.Path("nodes").Map().Keys()
keys = lib.Filter(keys, func(s string) bool {
return c.isPeer(s)
})
return keys
}
// GetMesh(): Converts the document into a struct
func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
return automerge.As[*MeshCrdt](c.doc.Root())
changes, err := c.doc.Changes(c.lastCacheHash)
if err != nil {
return nil, err
}
if c.cache == nil || len(changes) > 0 {
c.lastCacheHash = c.LastHash
cache, err := automerge.As[*MeshCrdt](c.doc.Root())
if err != nil {
return nil, err
}
c.cache = cache
}
return c.cache, nil
}
// GetMeshId returns the meshid of the mesh
@ -82,23 +127,23 @@ func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, erro
manager.IfName = params.DevName
manager.Client = params.Client
manager.conf = &params.Conf
manager.cache = nil
return &manager, nil
}
func (c *CrdtMeshManager) removeNode(endpoint string) error {
err := c.doc.Path("nodes").Map().Delete(endpoint)
if err != nil {
return err
}
return nil
// NodeExists: returns true if the node exists. Returns false
func (m *CrdtMeshManager) NodeExists(key string) bool {
node, err := m.doc.Path("nodes").Map().Get(key)
return node.Kind() == automerge.KindMap && err == nil
}
// GetNode: returns a mesh node crdt.Close releases resources used by a Client.
func (m *CrdtMeshManager) GetNode(endpoint string) (*MeshNodeCrdt, error) {
func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
node, err := m.doc.Path("nodes").Map().Get(endpoint)
if node.Kind() != automerge.KindMap {
return nil, fmt.Errorf("GetNode: something went wrong %s is not a map type")
}
if err != nil {
return nil, err
}
@ -176,7 +221,7 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro
}
if node.Kind() != automerge.KindMap {
return errors.New(fmt.Sprintf("%s does not exist", nodeId))
return fmt.Errorf("%s does not exist", nodeId)
}
err = node.Map().Set("description", description)
@ -188,6 +233,72 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro
return err
}
func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
return err
}
if node.Kind() != automerge.KindMap {
return fmt.Errorf("%s does not exist", nodeId)
}
err = node.Map().Set("alias", alias)
if err == nil {
logging.Log.WriteInfof("Updated Alias for %s to %s", nodeId, alias)
}
return err
}
func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return fmt.Errorf("AddService: node %s does not exist", nodeId)
}
service, err := node.Map().Get("services")
if err != nil {
return err
}
if service.Kind() != automerge.KindMap {
return fmt.Errorf("AddService: services property does not exist in node")
}
return service.Map().Set(key, value)
}
func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return fmt.Errorf("RemoveService: node %s does not exist", nodeId)
}
service, err := node.Map().Get("services")
if err != nil {
return err
}
if service.Kind() != automerge.KindMap {
return fmt.Errorf("services property does not exist")
}
err = service.Map().Delete(key)
if err != nil {
return fmt.Errorf("service %s does not exist", key)
}
return nil
}
// AddRoutes: adds routes to the specific nodeId
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -197,6 +308,10 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
return err
}
if nodeVal.Kind() != automerge.KindMap {
return fmt.Errorf("node does not exist")
}
routeMap, err := nodeVal.Map().Get("routes")
if err != nil {
@ -210,14 +325,90 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
return err
}
}
return nil
}
// DeleteRoutes deletes the specified routes
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
return err
}
if nodeVal.Kind() != automerge.KindMap {
return fmt.Errorf("node is not a map")
}
routeMap, err := nodeVal.Map().Get("routes")
if err != nil {
return err
}
for _, route := range routes {
err = routeMap.Map().Delete(route)
}
return err
}
func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer {
return NewAutomergeSync(m)
}
func (m *CrdtMeshManager) Prune(pruneTime int) error {
nodes, err := m.doc.Path("nodes").Get()
if err != nil {
return err
}
if nodes.Kind() != automerge.KindMap {
return errors.New("node must be a map")
}
values, err := nodes.Map().Values()
if err != nil {
return err
}
deletionNodes := make([]string, 0)
for nodeId, node := range values {
if node.Kind() != automerge.KindMap {
return errors.New("node must be a map")
}
nodeMap := node.Map()
timeStamp, err := nodeMap.Get("timestamp")
if err != nil {
return err
}
if timeStamp.Kind() != automerge.KindInt64 {
return errors.New("timestamp is not int64")
}
timeValue := timeStamp.Int64()
nowValue := time.Now().Unix()
if nowValue-timeValue >= int64(pruneTime) {
deletionNodes = append(deletionNodes, nodeId)
}
}
for _, node := range deletionNodes {
logging.Log.WriteInfof("Pruning %s", node)
nodes.Map().Delete(node)
}
return nil
}
func (m1 *MeshNodeCrdt) Compare(m2 *MeshNodeCrdt) int {
return strings.Compare(m1.PublicKey, m2.PublicKey)
}
@ -266,6 +457,26 @@ func (m *MeshNodeCrdt) GetIdentifier() string {
return strings.Join(constituents, ":")
}
func (m *MeshNodeCrdt) GetAlias() string {
return m.Alias
}
func (m *MeshNodeCrdt) GetServices() map[string]string {
services := make(map[string]string)
for key, service := range m.Services {
services[key] = service
}
return services
}
// GetType refers to the type of the node. Peer means that the node is globally accessible
// Client means the node is only accessible through another peer
func (n *MeshNodeCrdt) GetType() conf.NodeType {
return conf.NodeType(n.Type)
}
func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
nodes := make(map[string]mesh.MeshNode)
@ -278,6 +489,9 @@ func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
Timestamp: node.Timestamp,
Routes: node.Routes,
Description: node.Description,
Alias: node.Alias,
Services: node.GetServices(),
Type: node.Type,
}
}

View File

@ -0,0 +1,366 @@
package crdt
import (
"slices"
"strings"
"testing"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type TestParams struct {
manager *CrdtMeshManager
}
func setUpTests() *TestParams {
manager, _ := NewCrdtNodeManager(&NewCrdtNodeMangerParams{
MeshId: "timsmesh123",
DevName: "wg0",
Port: 5000,
Client: nil,
Conf: conf.WgMeshConfiguration{},
})
return &TestParams{
manager: manager,
}
}
func getTestNode() mesh.MeshNode {
return &MeshNodeCrdt{
HostEndpoint: "public-endpoint:8080",
WgEndpoint: "public-endpoint:21906",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d218/128",
PublicKey: "AAAAAAAAAAAA",
Timestamp: time.Now().Unix(),
Description: "A node that we are adding",
}
}
func getTestNode2() mesh.MeshNode {
return &MeshNodeCrdt{
HostEndpoint: "public-endpoint:8081",
WgEndpoint: "public-endpoint:21907",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d219/128",
PublicKey: "BBBBBBBBB",
Timestamp: time.Now().Unix(),
Description: "A node that we are adding",
}
}
func TestAddNodeNodeExists(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getTestNode())
node, err := testParams.manager.GetNode("public-endpoint:8080")
if err != nil {
t.Error(err)
}
if node == nil {
t.Fatalf(`node not added to the mesh when it should be`)
}
}
func TestAddNodeAddRoute(t *testing.T) {
testParams := setUpTests()
testNode := getTestNode()
testParams.manager.AddNode(testNode)
testParams.manager.AddRoutes(testNode.GetHostEndpoint(), "fd:1c64:1d00::/48")
updatedNode, err := testParams.manager.GetNode(testNode.GetHostEndpoint())
if err != nil {
t.Error(err)
}
if updatedNode == nil {
t.Fatalf(`Node does not exist in the mesh`)
}
routes := updatedNode.GetRoutes()
if !slices.Contains(routes, "fd:1c64:1d00::/48") {
t.Fatal("Route node not added")
}
if len(routes) != 1 {
t.Fatal(`Route length mismatch`)
}
}
func TestGetMeshIdReturnsTheMeshId(t *testing.T) {
testParams := setUpTests()
if len(testParams.manager.GetMeshId()) == 0 {
t.Fatal(`Meshid is less than 0`)
}
}
// Add 2 nodes to the mesh and then get the mesh.s
// It should return the 2 nodes that have been added to the mesh
func TestAdd2NodesGetMesh(t *testing.T) {
testParams := setUpTests()
node1 := getTestNode()
node2 := getTestNode2()
testParams.manager.AddNode(node1)
testParams.manager.AddNode(node2)
mesh, err := testParams.manager.GetMesh()
if err != nil {
t.Error(err)
}
nodes := mesh.GetNodes()
if len(nodes) != 2 {
t.Fatalf(`Mismatch in node slice`)
}
for _, node := range nodes {
if node.GetHostEndpoint() != node1.GetHostEndpoint() && node.GetHostEndpoint() != node2.GetHostEndpoint() {
t.Fatalf(`Node should not exist`)
}
}
}
func TestSaveMeshReturnsMeshBytes(t *testing.T) {
testParams := setUpTests()
node1 := getTestNode()
testParams.manager.AddNode(node1)
bytes := testParams.manager.Save()
if len(bytes) <= 0 {
t.Fatalf(`bytes in the mesh is less than 0`)
}
}
func TestSaveMeshThenLoad(t *testing.T) {
testParams := setUpTests()
testParams2 := setUpTests()
node1 := getTestNode()
testParams.manager.AddNode(node1)
bytes := testParams.manager.Save()
err := testParams2.manager.Load(bytes)
if err != nil {
t.Error(err)
}
if len(bytes) <= 0 {
t.Fatalf(`bytes in the mesh is less than 0`)
}
mesh2, err := testParams2.manager.GetMesh()
if err != nil {
t.Error(err)
}
nodes := mesh2.GetNodes()
if lib.MapValues(nodes)[0].GetHostEndpoint() != node1.GetHostEndpoint() {
t.Fatalf(`Node should be in the list of nodes`)
}
}
func TestLengthNoNodes(t *testing.T) {
testParams := setUpTests()
if testParams.manager.Length() != 0 {
t.Fatalf(`Number of nodes should be 0`)
}
}
func TestLength1Node(t *testing.T) {
testParams := setUpTests()
node := getTestNode()
testParams.manager.AddNode(node)
if testParams.manager.Length() != 1 {
t.Fatalf(`Number of nodes should be 1`)
}
}
func TestLengthMultipleNodes(t *testing.T) {
testParams := setUpTests()
node := getTestNode()
node1 := getTestNode2()
testParams.manager.AddNode(node)
testParams.manager.AddNode(node1)
if testParams.manager.Length() != 2 {
t.Fatalf(`Number of nodes should be 2`)
}
}
func TestHasChangesNoChanges(t *testing.T) {
testParams := setUpTests()
if testParams.manager.HasChanges() {
t.Fatalf(`Should not have changes just created document`)
}
}
func TestHasChangesChanges(t *testing.T) {
testParams := setUpTests()
node := getTestNode()
testParams.manager.AddNode(node)
if !testParams.manager.HasChanges() {
t.Fatalf(`Should have changes just added node`)
}
}
func TestHasChangesSavedChanges(t *testing.T) {
testParams := setUpTests()
node := getTestNode()
testParams.manager.AddNode(node)
testParams.manager.SaveChanges()
if testParams.manager.HasChanges() {
t.Fatalf(`Should not have changes just saved document`)
}
}
func TestUpdateTimeStampNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.UpdateTimeStamp("AAAAAA")
if err == nil {
t.Fatalf(`Error should have returned`)
}
}
func TestUpdateTimeStampNodeExists(t *testing.T) {
testParams := setUpTests()
node := getTestNode()
testParams.manager.AddNode(node)
err := testParams.manager.UpdateTimeStamp(node.GetHostEndpoint())
if err != nil {
t.Error(err)
}
}
func TestSetDescriptionNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.SetDescription("AAAAA", "Bob 123")
if err == nil {
t.Fatalf(`Error should have returned`)
}
}
func TestSetDescriptionNodeExists(t *testing.T) {
testParams := setUpTests()
node := getTestNode()
err := testParams.manager.SetDescription(node.GetHostEndpoint(), "Bob 123")
if err == nil {
t.Fatalf(`Error should have returned`)
}
}
func TestAddRoutesNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.AddRoutes("AAAAA", "fd:1c64:1d00::/48")
if err == nil {
t.Error(err)
}
}
func TestCompareComparesByPublicKey(t *testing.T) {
node := getTestNode().(*MeshNodeCrdt)
node2 := getTestNode2().(*MeshNodeCrdt)
if node.Compare(node2) != -1 {
t.Fatalf(`node is alphabetically before node2`)
}
if node2.Compare(node) != 1 {
t.Fatalf(`node is alphabetical;y before node2`)
}
if node.Compare(node) != 0 {
t.Fatalf(`node is equal to node`)
}
}
func TestGetHostEndpoint(t *testing.T) {
node := getTestNode()
if (node.(*MeshNodeCrdt)).HostEndpoint != node.GetHostEndpoint() {
t.Fatalf(`get hostendpoint should get the host endpoint`)
}
}
func TestGetPublicKey(t *testing.T) {
key1, _ := wgtypes.GenerateKey()
node := getTestNode()
node.(*MeshNodeCrdt).PublicKey = key1.String()
pubKey, err := node.GetPublicKey()
if err != nil {
t.Error(err)
}
if pubKey.String() != key1.String() {
t.Fatalf(`Expected %s got %s`, key1.String(), pubKey.String())
}
}
func TestGetWgEndpoint(t *testing.T) {
node := getTestNode()
if node.(*MeshNodeCrdt).WgEndpoint != node.GetWgEndpoint() {
t.Fatal(`Did not return the correct wgEndpoint`)
}
}
func TestGetWgHost(t *testing.T) {
node := getTestNode()
ip := node.GetWgHost()
if node.(*MeshNodeCrdt).WgHost != ip.String() {
t.Fatal(`Did not parse WgHost correctly`)
}
}
func TestGetTimeStamp(t *testing.T) {
node := getTestNode()
if node.(*MeshNodeCrdt).Timestamp != node.GetTimeStamp() {
t.Fatal(`Did not return return the correct timestamp`)
}
}
func TestGetIdentifierDoesNotContainPrefix(t *testing.T) {
node := getTestNode()
if strings.Contains(node.GetIdentifier(), "/128") {
t.Fatal(`Identifier should not contain prefix`)
}
}

View File

@ -35,7 +35,10 @@ func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNod
WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()),
// Always set the routes as empty.
// Routes handled by external component
Routes: map[string]interface{}{},
Routes: map[string]interface{}{},
Description: "",
Alias: "",
Type: string(params.Role),
}
}
@ -48,7 +51,13 @@ func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string
} else if len(f.Config.Endpoint) != 0 {
hostName = f.Config.Endpoint
} else {
hostName = lib.GetOutboundIP().String()
ip, err := lib.GetPublicIP()
if err != nil {
return ""
}
hostName = ip.String()
}
return hostName

View File

@ -8,7 +8,10 @@ type MeshNodeCrdt struct {
WgHost string `automerge:"wgHost"`
Timestamp int64 `automerge:"timestamp"`
Routes map[string]interface{} `automerge:"routes"`
Alias string `automerge:"alias"`
Description string `automerge:"description"`
Services map[string]string `automerge:"services"`
Type string `automerge:"type"`
}
// MeshCrdt: Represents the mesh network as a whole

View File

@ -8,6 +8,21 @@ import (
"gopkg.in/yaml.v3"
)
type WgMeshConfigurationError struct {
msg string
}
func (m *WgMeshConfigurationError) Error() string {
return m.msg
}
type NodeType string
const (
PEER_ROLE NodeType = "peer"
CLIENT_ROLE NodeType = "client"
)
type WgMeshConfiguration struct {
// CertificatePath is the path to the certificate to use in mTLS
CertificatePath string `yaml:"certificatePath"`
@ -24,13 +39,119 @@ type WgMeshConfiguration struct {
AdvertiseRoutes bool `yaml:"advertiseRoutes"`
// Endpoint is the IP in which this computer is publicly reachable.
// usecase is when the node has multiple IP addresses
Endpoint string `yaml:"publicEndpoint"`
ClusterSize int `yaml:"clusterSize"`
SyncRate float64 `yaml:"syncRate"`
Endpoint string `yaml:"publicEndpoint"`
// ClusterSize size of the cluster to split on
ClusterSize int `yaml:"clusterSize"`
// SyncRate number of times per second to perform a sync
SyncRate float64 `yaml:"syncRate"`
// InterClusterChance proability of inter-cluster communication in a sync round
InterClusterChance float64 `yaml:"interClusterChance"`
BranchRate int `yaml:"branchRate"`
InfectionCount int `yaml:"infectionCount"`
KeepAliveRate int `yaml:"keepAliveRate"`
// BranchRate number of nodes to randomly communicate with
BranchRate int `yaml:"branchRate"`
// InfectionCount number of times we sync before we can no longer catch the udpate
InfectionCount int `yaml:"infectionCount"`
// KeepAliveTime number of seconds before we update node indicating that we are still alive
KeepAliveTime int `yaml:"keepAliveTime"`
// Timeout number of seconds before we consider the node as dead
Timeout int `yaml:"timeout"`
// PruneTime number of seconds before we consider the 'node' as dead
PruneTime int `yaml:"pruneTime"`
// Profile whether or not to include a http server that profiles the code
Profile bool `yaml:"profile"`
// StubWg whether or not to stub the WireGuard types
StubWg bool `yaml:"stubWg"`
// Role specifies whether or not the user is globally accessible.
// If the user is globaly accessible they specify themselves as a client.
Role NodeType `yaml:"role"`
// KeepAliveWg configures the implementation so that we send keep alive packets to peers.
// KeepAlive can only be set if role is type client
KeepAliveWg int `yaml:"keepAliveWg"`
}
func ValidateConfiguration(c *WgMeshConfiguration) error {
if len(c.CertificatePath) == 0 {
return &WgMeshConfigurationError{
msg: "A public certificate must be specified for mTLS",
}
}
if len(c.PrivateKeyPath) == 0 {
return &WgMeshConfigurationError{
msg: "A private key must be specified for mTLS",
}
}
if len(c.CaCertificatePath) == 0 {
return &WgMeshConfigurationError{
msg: "A ca certificate must be specified for mTLS",
}
}
if len(c.GrpcPort) == 0 {
return &WgMeshConfigurationError{
msg: "A grpc port must be specified",
}
}
if c.ClusterSize <= 0 {
return &WgMeshConfigurationError{
msg: "A cluster size must not be 0",
}
}
if c.SyncRate <= 0 {
return &WgMeshConfigurationError{
msg: "SyncRate cannot be negative",
}
}
if c.BranchRate <= 0 {
return &WgMeshConfigurationError{
msg: "Branch rate cannot be negative",
}
}
if c.InfectionCount <= 0 {
return &WgMeshConfigurationError{
msg: "Infection count cannot be less than 1",
}
}
if c.KeepAliveTime <= 0 {
return &WgMeshConfigurationError{
msg: "KeepAliveRate cannot be less than negative",
}
}
if c.InterClusterChance <= 0 {
return &WgMeshConfigurationError{
msg: "Intercluster chance cannot be less than 0",
}
}
if c.Timeout < 1 {
return &WgMeshConfigurationError{
msg: "Timeout should be greater than or equal to 1",
}
}
if c.PruneTime <= 1 {
return &WgMeshConfigurationError{
msg: "Prune time cannot be <= 1",
}
}
if c.KeepAliveTime <= 1 {
return &WgMeshConfigurationError{
msg: "Prune time cannot be less than keep alive time",
}
}
if c.Role == "" {
c.Role = PEER_ROLE
}
return nil
}
// ParseConfiguration parses the mesh configuration
@ -51,5 +172,5 @@ func ParseConfiguration(filePath string) (*WgMeshConfiguration, error) {
return nil, err
}
return &conf, nil
return &conf, ValidateConfiguration(&conf)
}

165
pkg/conf/conf_test.go Normal file
View File

@ -0,0 +1,165 @@
package conf
import "testing"
func getExampleConfiguration() *WgMeshConfiguration {
return &WgMeshConfiguration{
CertificatePath: "./cert/cert.pem",
PrivateKeyPath: "./cert/key.pem",
CaCertificatePath: "./cert/ca.pems",
SkipCertVerification: true,
GrpcPort: "8080",
AdvertiseRoutes: true,
Endpoint: "localhost",
ClusterSize: 1,
SyncRate: 1,
InterClusterChance: 0.1,
BranchRate: 2,
KeepAliveTime: 4,
InfectionCount: 1,
Timeout: 2,
PruneTime: 20,
}
}
func TestConfigurationCertificatePathEmpty(t *testing.T) {
conf := getExampleConfiguration()
conf.CertificatePath = ""
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestConfigurationPrivateKeyPathEmpty(t *testing.T) {
conf := getExampleConfiguration()
conf.PrivateKeyPath = ""
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestConfigurationCaCertificatePathEmpty(t *testing.T) {
conf := getExampleConfiguration()
conf.CaCertificatePath = ""
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestConfigurationGrpcPortEmpty(t *testing.T) {
conf := getExampleConfiguration()
conf.GrpcPort = ""
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestClusterSizeZero(t *testing.T) {
conf := getExampleConfiguration()
conf.ClusterSize = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func SyncRateZero(t *testing.T) {
conf := getExampleConfiguration()
conf.SyncRate = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func BranchRateZero(t *testing.T) {
conf := getExampleConfiguration()
conf.BranchRate = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func InfectionCountZero(t *testing.T) {
conf := getExampleConfiguration()
conf.InfectionCount = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func KeepAliveRateZero(t *testing.T) {
conf := getExampleConfiguration()
conf.KeepAliveTime = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestValidCOnfiguration(t *testing.T) {
conf := getExampleConfiguration()
err := ValidateConfiguration(conf)
if err != nil {
t.Error(err)
}
}
func TestTimeout(t *testing.T) {
conf := getExampleConfiguration()
conf.Timeout = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestPruneTimeZero(t *testing.T) {
conf := getExampleConfiguration()
conf.PruneTime = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatalf(`Error should be thrown`)
}
}
func TestPruneTimeLessThanKeepAliveTime(t *testing.T) {
conf := getExampleConfiguration()
conf.PruneTime = 1
err := ValidateConfiguration(conf)
if err == nil {
t.Fatalf(`Error should be thrown`)
}
}

75
pkg/conn/cluster.go Normal file
View File

@ -0,0 +1,75 @@
package conn
import (
"errors"
"math"
"math/rand"
"slices"
)
// ConnCluster splits nodes into clusters where nodes in a cluster communicate
// frequently and nodes outside of a cluster communicate infrequently
type ConnCluster interface {
GetNeighbours(global []string, selfId string) []string
GetInterCluster(global []string, selfId string) string
}
type ConnClusterImpl struct {
clusterSize int
}
func binarySearch(global []string, selfId string, groupSize int) (int, int) {
slices.Sort(global)
lower := 0
higher := len(global) - 1
mid := (lower + higher) / 2
for (higher+1)-lower > groupSize {
if global[mid] < selfId {
lower = mid + 1
} else if global[mid] > selfId {
higher = mid - 1
} else {
break
}
mid = (lower + higher) / 2
}
return lower, int(math.Min(float64(lower+groupSize), float64(len(global))))
}
// GetNeighbours return the neighbours 'nearest' to you. In this implementation the
// neighbours aren't actually the ones nearest to you but just the ones nearest
// to you alphabetically. Perform binary search to get the total group
func (i *ConnClusterImpl) GetNeighbours(global []string, selfId string) []string {
slices.Sort(global)
lower, higher := binarySearch(global, selfId, i.clusterSize)
// slice the list to get the neighbours
return global[lower:higher]
}
// GetInterCluster get nodes not in your cluster. Every round there is a given chance
// you will communicate with a random node that is not in your cluster.
func (i *ConnClusterImpl) GetInterCluster(global []string, selfId string) string {
// Doesn't matter if not in it. Get index of where the node 'should' be
index, _ := binarySearch(global, selfId, 1)
numClusters := math.Ceil(float64(len(global)) / float64(i.clusterSize))
randomCluster := rand.Intn(int(numClusters)-1) + 1
neighbourIndex := (index + randomCluster) % len(global)
return global[neighbourIndex]
}
func NewConnCluster(clusterSize int) (ConnCluster, error) {
log2Cluster := math.Log2(float64(clusterSize))
if float64((log2Cluster))-log2Cluster != 0 {
return nil, errors.New("cluster must be a power of 2")
}
return &ConnClusterImpl{clusterSize: clusterSize}, nil
}

116
pkg/conn/cluster_test.go Normal file
View File

@ -0,0 +1,116 @@
package conn
import (
"math/rand"
"slices"
"testing"
)
func TestGetNeighboursClusterSizeTwo(t *testing.T) {
cluster := &ConnClusterImpl{
clusterSize: 2,
}
neighbours := []string{
"a",
"b",
"c",
"d",
}
result := cluster.GetNeighbours(neighbours, "b")
if len(result) != 2 {
t.Fatalf(`neighbour length should be 2`)
}
if result[0] != "a" && result[1] != "b" {
t.Fatalf(`Expected value b`)
}
}
func TestGetNeighboursGlobalListLessThanClusterSize(t *testing.T) {
cluster := &ConnClusterImpl{
clusterSize: 4,
}
neighbours := []string{
"a",
"b",
"c",
}
result := cluster.GetNeighbours(neighbours, "a")
if len(result) != 3 {
t.Fatalf(`neighbour length should be 3`)
}
slices.Sort(result)
if !slices.Equal(result, neighbours) {
t.Fatalf(`Cluster and neighbours should be equal`)
}
}
func TestGetNeighboursClusterSize4(t *testing.T) {
cluster := &ConnClusterImpl{
clusterSize: 4,
}
neighbours := []string{
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
"k", "l", "m", "n", "o",
}
result := cluster.GetNeighbours(neighbours, "k")
if len(result) != 4 {
t.Fatalf(`cluster size must be 4`)
}
slices.Sort(result)
if !slices.Equal(neighbours[8:12], result) {
t.Fatalf(`Cluster should be i, j, k, l`)
}
}
func TestGetNeighboursClusterSize4OneReturned(t *testing.T) {
cluster := &ConnClusterImpl{
clusterSize: 4,
}
neighbours := []string{
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
"k", "l", "m", "n", "o",
}
result := cluster.GetNeighbours(neighbours, "o")
if len(result) != 3 {
t.Fatalf(`Cluster should be of length 3`)
}
if !slices.Equal(neighbours[12:15], result) {
t.Fatalf(`Cluster should be m, n, o`)
}
}
func TestInterClusterNotInCluster(t *testing.T) {
rand.Seed(1)
cluster := &ConnClusterImpl{
clusterSize: 4,
}
global := []string{
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
"k", "l", "m", "n", "o",
}
neighbours := cluster.GetNeighbours(global, "c")
interCluster := cluster.GetInterCluster(global, "c")
if slices.Contains(neighbours, interCluster) {
t.Fatalf(`intercluster cannot be in your cluster`)
}
}

View File

@ -18,6 +18,8 @@ type PeerConnection interface {
GetClient() (*grpc.ClientConn, error)
}
type PeerConnectionFactory = func(clientConfig *tls.Config, server string) (PeerConnection, error)
// WgCtrlConnection implements PeerConnection.
type WgCtrlConnection struct {
clientConfig *tls.Config
@ -26,12 +28,12 @@ type WgCtrlConnection struct {
}
// NewWgCtrlConnection creates a new instance of a WireGuard control connection
func NewWgCtrlConnection(clientConfig *tls.Config, server string) (*WgCtrlConnection, error) {
func NewWgCtrlConnection(clientConfig *tls.Config, server string) (PeerConnection, error) {
var conn WgCtrlConnection
conn.clientConfig = clientConfig
conn.endpoint = server
if err := conn.createGrpcConn(); err != nil {
if err := conn.CreateGrpcConnection(); err != nil {
return nil, err
}
@ -39,7 +41,7 @@ func NewWgCtrlConnection(clientConfig *tls.Config, server string) (*WgCtrlConnec
}
// ConnectWithToken: Connects to a new gRPC peer given the address of the other server.
func (c *WgCtrlConnection) createGrpcConn() error {
func (c *WgCtrlConnection) CreateGrpcConnection() error {
conn, err := grpc.Dial(c.endpoint,
grpc.WithTransportCredentials(credentials.NewTLS(c.clientConfig)))
@ -62,7 +64,7 @@ func (c *WgCtrlConnection) GetClient() (*grpc.ClientConn, error) {
var err error = nil
if c.conn == nil {
err = errors.New("The client's config does not exist")
err = errors.New("the client's config does not exist")
}
return c.conn, err

View File

@ -34,10 +34,11 @@ type ConnectionManagerImpl struct {
clientConnections map[string]PeerConnection
serverConfig *tls.Config
clientConfig *tls.Config
connFactory PeerConnectionFactory
}
// Create a new instance of a connection manager.
type NewConnectionManageParams struct {
type NewConnectionManagerParams struct {
// The path to the certificate
CertificatePath string
// The private key of the node
@ -45,11 +46,12 @@ type NewConnectionManageParams struct {
// Whether or not to skip certificate verification
SkipCertVerification bool
CaCert string
ConnFactory PeerConnectionFactory
}
// NewConnectionManager: Creates a new instance of a ConnectionManager or an error
// if something went wrong.
func NewConnectionManager(params *NewConnectionManageParams) (ConnectionManager, error) {
func NewConnectionManager(params *NewConnectionManagerParams) (ConnectionManager, error) {
cert, err := tls.LoadX509KeyPair(params.CertificatePath, params.PrivateKey)
if err != nil {
@ -94,10 +96,12 @@ func NewConnectionManager(params *NewConnectionManageParams) (ConnectionManager,
}
connections := make(map[string]PeerConnection)
connMgr := ConnectionManagerImpl{sync.RWMutex{},
connMgr := ConnectionManagerImpl{
sync.RWMutex{},
connections,
serverConfig,
clientConfig,
params.ConnFactory,
}
return &connMgr, nil
@ -127,7 +131,7 @@ func (m *ConnectionManagerImpl) AddConnection(endPoint string) (PeerConnection,
return conn, nil
}
connections, err := NewWgCtrlConnection(m.clientConfig, endPoint)
connections, err := m.connFactory(m.clientConfig, endPoint)
if err != nil {
return nil, err

View File

@ -0,0 +1,145 @@
package conn
import (
"crypto/tls"
"errors"
"log"
"testing"
)
func getConnectionManagerParams() *NewConnectionManagerParams {
return &NewConnectionManagerParams{
CertificatePath: "./test/cert.pem",
PrivateKey: "./test/priv.pem",
CaCert: "./test/cacert.pem",
SkipCertVerification: false,
ConnFactory: MockFactory,
}
}
func TestNewConnectionManagerCertificatePathDoesNotExist(t *testing.T) {
params := getConnectionManagerParams()
params.CertificatePath = "./cert/sdfjdskjdsjkd.pem"
_, err := NewConnectionManager(params)
if err == nil {
t.Fatalf(`Expected error as certificate does not exist`)
}
}
func TestNewConnectionManagerPrivateKeyDoesNotExist(t *testing.T) {
params := getConnectionManagerParams()
params.PrivateKey = "./cert/sdjdjdks.pem"
_, err := NewConnectionManager(params)
if err == nil {
t.Fatalf(`Expected error as private key does not exist`)
}
}
func TestNewConnectionManagerCACertDoesNotExistAndVerify(t *testing.T) {
params := getConnectionManagerParams()
params.CaCert = "./cert/sdjdsjdksjdks.pem"
params.SkipCertVerification = false
_, err := NewConnectionManager(params)
if err == nil {
t.Fatal(`Expected error as ca cert does not exist and skip is false`)
}
}
func TestNewConnectionManagerCACertDoesNotExistAndNotVerify(t *testing.T) {
params := getConnectionManagerParams()
params.CaCert = ""
params.SkipCertVerification = true
_, err := NewConnectionManager(params)
if err != nil {
t.Fatal(`an error should not be thrown`)
}
}
func TestGetConnectionConnectionDoesNotExistAddsConnection(t *testing.T) {
params := getConnectionManagerParams()
m, _ := NewConnectionManager(params)
conn, err := m.GetConnection("abc-123.com")
if err != nil {
t.Error(err)
}
if conn == nil {
t.Fatal(`the connection should not be nil`)
}
conn2, _ := m.GetConnection("abc-123.com")
if conn != conn2 {
log.Fatalf(`should return the same connection instance`)
}
}
func TestAddConnectionThrowsAnErrorIfFactoryThrowsError(t *testing.T) {
params := getConnectionManagerParams()
params.ConnFactory = func(clientConfig *tls.Config, server string) (PeerConnection, error) {
return nil, errors.New("this is an error")
}
m, _ := NewConnectionManager(params)
_, err := m.AddConnection("abc-123.com")
if err == nil || err.Error() != "this is an error" {
t.Error(err)
}
}
func TestAddConnectionConnectionDoesNotExist(t *testing.T) {
params := getConnectionManagerParams()
m, _ := NewConnectionManager(params)
conn, err := m.AddConnection("abc-123.com")
if err != nil {
t.Error(err)
}
if conn == nil {
t.Fatal(`connection should not be nil`)
}
conn1, _ := m.GetConnection("abc-123.com")
if conn != conn1 {
t.Fatal(`underlying connections should be the same`)
}
}
func TestHasConnectionConnectionDoesNotExist(t *testing.T) {
params := getConnectionManagerParams()
m, _ := NewConnectionManager(params)
if m.HasConnection("abc-123.com") {
t.Fatal(`should return that the connection does not exist`)
}
}
func TestHasConnectionConnectionExists(t *testing.T) {
params := getConnectionManagerParams()
m, _ := NewConnectionManager(params)
m.AddConnection("abc-123.com")
if !m.HasConnection("abc-123.com") {
t.Fatal(`should return that the connection exists`)
}
}

View File

@ -16,9 +16,7 @@ type ConnectionServer struct {
// tlsConfiguration of the server
serverConfig *tls.Config
// server an instance of the grpc server
server *grpc.Server
// the authentication service to authenticate nodes
authProvider rpc.AuthenticationServer
server *grpc.Server // the authentication service to authenticate nodes
// the ctrl service to manage node
ctrlProvider rpc.MeshCtrlServerServer
// the sync service to synchronise nodes
@ -30,7 +28,6 @@ type ConnectionServer struct {
// NewConnectionServerParams contains params for creating a new connection server
type NewConnectionServerParams struct {
Conf *conf.WgMeshConfiguration
AuthProvider rpc.AuthenticationServer
CtrlProvider rpc.MeshCtrlServerServer
SyncProvider rpc.SyncServiceServer
}
@ -59,14 +56,12 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
grpc.Creds(credentials.NewTLS(serverConfig)),
)
authProvider := params.AuthProvider
ctrlProvider := params.CtrlProvider
syncProvider := params.SyncProvider
connServer := ConnectionServer{
serverConfig: serverConfig,
server: server,
authProvider: authProvider,
ctrlProvider: ctrlProvider,
syncProvider: syncProvider,
Conf: params.Conf,
@ -78,7 +73,6 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
// Listen for incoming requests. Returns an error if something went wrong.
func (s *ConnectionServer) Listen() error {
rpc.RegisterMeshCtrlServerServer(s.server, s.ctrlProvider)
rpc.RegisterAuthenticationServer(s.server, s.authProvider)
rpc.RegisterSyncServiceServer(s.server, s.syncProvider)

51
pkg/conn/stub.go Normal file
View File

@ -0,0 +1,51 @@
package conn
import (
"crypto/tls"
"google.golang.org/grpc"
)
type ConnectionManagerStub struct {
Endpoints map[string]PeerConnection
}
func (s *ConnectionManagerStub) AddConnection(endPoint string) (PeerConnection, error) {
mock := &PeerConnectionMock{}
s.Endpoints[endPoint] = mock
return mock, nil
}
func (s *ConnectionManagerStub) GetConnection(endPoint string) (PeerConnection, error) {
endpoint, ok := s.Endpoints[endPoint]
if !ok {
return s.AddConnection(endPoint)
}
return endpoint, nil
}
func (s *ConnectionManagerStub) HasConnection(endPoint string) bool {
_, ok := s.Endpoints[endPoint]
return ok
}
func (s *ConnectionManagerStub) Close() error {
return nil
}
type PeerConnectionMock struct {
}
func (c *PeerConnectionMock) Close() error {
return nil
}
func (c *PeerConnectionMock) GetClient() (*grpc.ClientConn, error) {
return &grpc.ClientConn{}, nil
}
var MockFactory PeerConnectionFactory = func(clientConfig *tls.Config, server string) (PeerConnection, error) {
return &PeerConnectionMock{}, nil
}

21
pkg/conn/test/cacert.pem Normal file
View File

@ -0,0 +1,21 @@
-----BEGIN CERTIFICATE-----
MIIDazCCAlOgAwIBAgIUDRIRI8UnHU2a4znsun0gxFwlrFQwDQYJKoZIhvcNAQEL
BQAwRTELMAkGA1UEBhMCR0IxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMzEwMjcxNTIzMDZaFw0yNDEw
MjYxNTIzMDZaMEUxCzAJBgNVBAYTAkdCMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw
HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB
AQUAA4IBDwAwggEKAoIBAQDJ5hOmzilimA/zM5hYP7CQf4iRmICtSbVLgt6/rTDP
p3JsGGQWZ4pZNofzGnGa7aEMoXS2Ztl7GzZbr1p4+rd6MBbVt8XZ/hP+X4zasCXi
/YubG0TYyBuAt+JrcYb0cbsTBkMXXnFcNIXDfeYFsNq+pfyJwq2ElMUUZ6SQmVhH
ovn1Wk9Fv4t2GJMhmUcObrSIoYdgo4Vf9CfQnn0PCaRf+RjspY/Kz33oyqDI6xJx
I0rfJR7f9B6ZKosfAkt4oTTfT9P8w/d1I95oBENhDkalgkdJCuNJ/AwKGxZrYf/P
aefcc91HheauObjBYPFrSn6bUj3LMJEfj4IeBK+fOZCfAgMBAAGjUzBRMB0GA1Ud
DgQWBBSpcF7jtpd9n73VM3xhPmI1GMEkFjAfBgNVHSMEGDAWgBSpcF7jtpd9n73V
M3xhPmI1GMEkFjAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQCK
GplAveP9nVo9zmg+/mkDpyVoo5rp64oJh4DFtm6X+EI31FmH6Cb71Kn2ZzXhQvSq
qrP7+VoGeBDxk4guJtAs/fhnuDupJG2SjsctjiFnDbSrJjWJjGhC0kuL0wcjLU5G
qUpCEJu13GkDlYHKKw0z+oLUOw+OHmvE5/sD23sKl2KxBWKItx0hwSCkGtm0RQld
8mfjOsHqJ2V/FOcHK6X2DSV1728PAhu4l/PRSB0drBA+7kdeCuWIRZw5RA/OyxvU
CuC5dfUh75MrK7KL6sZsXklsoXo8BZp4rRRUt/v1D3r/SMBJPULSGXh6QDjXQX1D
km71c3DEDyKznHTpGxPt
-----END CERTIFICATE-----

View File

@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDJ5hOmzilimA/z
M5hYP7CQf4iRmICtSbVLgt6/rTDPp3JsGGQWZ4pZNofzGnGa7aEMoXS2Ztl7GzZb
r1p4+rd6MBbVt8XZ/hP+X4zasCXi/YubG0TYyBuAt+JrcYb0cbsTBkMXXnFcNIXD
feYFsNq+pfyJwq2ElMUUZ6SQmVhHovn1Wk9Fv4t2GJMhmUcObrSIoYdgo4Vf9CfQ
nn0PCaRf+RjspY/Kz33oyqDI6xJxI0rfJR7f9B6ZKosfAkt4oTTfT9P8w/d1I95o
BENhDkalgkdJCuNJ/AwKGxZrYf/Paefcc91HheauObjBYPFrSn6bUj3LMJEfj4Ie
BK+fOZCfAgMBAAECggEADqAjoUxC9Dj2wtPkf9QRSs5qSr3E6Iiz4OX4k+MMa6aC
I/F6YqMagw7vtz0dqK75ISybA1GdBI16mRaxU5056FiOdunqo7mDokQytG7ZN8HN
OK23hYqtb1wiw0zEjXWlqyGjf5BgXuERJZG7tYLTvcbRbftTzYxnYGyHn8/z9LBp
GsTJ5X8XMLM5+bTvg1Ovv5s0q31FCeqAuw+auHH4pBNP+ylV6dF5XOWq4HO3TJ2b
grHxWB94JZChZnDC/K+HxQ6aHJfbZ5XCoXfIaIVkoXfnyPzgjvgK+/IpHEF8f/3I
uT/NBiArTpRl29pX5flEO4R121VaW93eM1tuzL32VQKBgQD6Trctx9SYuhzgfiO7
kdefvR43Kl9SFyEw3hN3HW1cxSNGCCFotjmdem+QdtMBtUd27UJ9tuiKJC0lcCER
t3WRz4kVd/cb0eC1DPzpGHA81o1rUUR3nMr1o7aBfvQ06VAxFUrFAOPpF8nD7tI4
0CiOh7/sL1ElThA3bOPUpXkYHQKBgQDOfYbP8dppIkC8pRTnHWe0qUY0G4YXxg7r
UtTo4GYOLJeKH/MKoK8MjBDS5VN5n5TAHJ8yUVzhpWXZIPIGzNEhIRDMa56sRPgI
9mLJNs5z/ZIxd/7ZQbDHrD4T3PKeTjzVUtjXrhLowokPlPB/RMQL6ZT+qMao+3bS
fDITSfLG6wKBgBpbcZSDh1JxvpqxDagxqkfqzSS39IObZeZUbC5NzfdH1vgH4SS6
k4SOoPLQYFW8tgLC5w5/1Sq+tnZLwV+xNtMczG2TTVUDm6rU7EjLRv5RBWE4lIIX
45NMIuqt6J8ttkEE4fOurVEdLSTRoBdVa//eMYp4TQ4lkzWS5Ma+ierNAoGAYO3z
1rFFQYzerq8ffM4E3H2JgvRYodhLMJQVdavAvG6aRDBzOk3rXgxx6U3VPYZ3oSbO
ZCRlYVbu1FnuwtpqYQ7Qf+UU+vD1Ld/ax3F+wFwLwET/0KRRg6mLCm/xQ/ad/9WA
DN6d6b1H8ZSMwHFbRexEELbRaomAYZYDO6K+4DkCgYEAv5De85hPnWtAvKhPzwQi
9mtyWo/cfQgtwL8IKNu6hBHl5RXDpPgX/+pNbXLJfBPwVR3H62x1CMYJDkWVuE6/
ZjtF7FSucZMz/mR6r1GhSOXy3YLwQ6JLPjjKzvnEjahGlKwALJNL0O2ZucjsZxHE
PM4rmhRZT9opiapiltEhRm0=
-----END PRIVATE KEY-----

19
pkg/conn/test/cert.pem Normal file
View File

@ -0,0 +1,19 @@
-----BEGIN CERTIFICATE-----
MIIDCjCCAfICFB/Vd2eOXWdNdrakThJhFIRtZmhUMA0GCSqGSIb3DQEBCwUAMEUx
CzAJBgNVBAYTAkdCMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRl
cm5ldCBXaWRnaXRzIFB0eSBMdGQwHhcNMjMxMDI3MTUzNDM1WhcNMjMxMTI2MTUz
NDM1WjA+MQswCQYDVQQGEwJHQjENMAsGA1UECAwERmlmZTENMAsGA1UEBwwEY2l0
eTERMA8GA1UECgwITWVzaCBMdGQwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEK
AoIBAQDVgcLtNU5AYfPML/mE5PyC7YYKvZn2mt6vEiJ7M/6EzYeTXFeYexD5ZqHg
ewGEd1fwiQWQsATsWd+EM4OnCAXAaNiOH6gGY7FR8CThfT+k8yIGPrl1BovzHHYS
Orekna17UFeIyFMHDPIjl4d2WiJPvmNn5PhLEppPHPBWPhl3J3sMrSbqyRuYbtta
oFIzN8mFcikixLg0SnBPtwlLC72ah9G+MF5CwEcU/E0bYbLQZXv+WhG5aw5JEzes
K2GLxVNgM0xXB7hSyLoX1wBc8DdQyLCMkOp55Hl04UKTxtVE82MiuAOVqMUuKFjR
u2a1C+/Gbk/PS5SHgenGjdZ8sZGpAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAHMc
jIFG5Rn9KaVmo7E+/UAq+3ld/3y2yMHg5wq7oG8b7/z0mlSGErHdFMzo75AFLN4r
kOuiF5ItF6dRLNrG8IUFSNMGVH3b3ukw1EI8E89L8ak3CM+wpLT6GVP3BfV8ah+X
4RRix40Tmx4C81l+Lf5W10rHIdlXBCanJy/Fa0ae+S+oXFc9jeXHlK9qlgszrECT
Pa3VCR95LAIc6o9pDL2Z8tpEkSbyzvIWhp53fnC80PyXpSsFMfIw657shagBc/Ov
e7/aPpPf3V3CafJlEIraQp24MDI5ZM59lT5vhRq2AC50gelL6UPV16mVVUlGVhWE
vYyejod5i5ZbuLFOy2g=
-----END CERTIFICATE-----

28
pkg/conn/test/priv.pem Normal file
View File

@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDVgcLtNU5AYfPM
L/mE5PyC7YYKvZn2mt6vEiJ7M/6EzYeTXFeYexD5ZqHgewGEd1fwiQWQsATsWd+E
M4OnCAXAaNiOH6gGY7FR8CThfT+k8yIGPrl1BovzHHYSOrekna17UFeIyFMHDPIj
l4d2WiJPvmNn5PhLEppPHPBWPhl3J3sMrSbqyRuYbttaoFIzN8mFcikixLg0SnBP
twlLC72ah9G+MF5CwEcU/E0bYbLQZXv+WhG5aw5JEzesK2GLxVNgM0xXB7hSyLoX
1wBc8DdQyLCMkOp55Hl04UKTxtVE82MiuAOVqMUuKFjRu2a1C+/Gbk/PS5SHgenG
jdZ8sZGpAgMBAAECggEARJNAggLYhtpPPVp9WJ9ZsU3L+0AppujYND/tXkf1bD89
V+nVYq7IZWp+/MRVWPAiCSphZLb8ZdN59JK9KtVrT4D9aSymwaKcjfZFSj15xyem
Wn4j///hzGxsSe+dE1znnw9PhindbQrN7Pua8TsDATzj3bdPvoETmexwDysz765i
u4zXvxP+xAessz1OYa5IUaDXdlWOf0e1zNXWwanjRggzCeWR3lTofG49GX087oVC
Sb9ASy+AScnOlwpTdQ8sKy1r9gXmE5ey4AULVb0nJ8LDvrCoBBhKBtVE5mJHepE6
bdC9l6poL6roGvHfMAo3SmiUUT5XceqUxBtHcyHX3wKBgQD1uh+Dv0PrH3CTW9cF
bwHL1rmQNJrbDzDAaounGBe9mcot1RrBhyQAoGw1no4c+QWDAwYRuBP2+Rp6JLU/
XnEXSyN85rJN6LajlrLEr+BNmKw6ghNsnAFUZBLaJ7epRi6OjACUwmtvH6hRIef8
aMg4WiOyDT+Z4Xe81pdXb91HXwKBgQDebs3idgVEau3LCKGYnqvmUhzv8iiQiJmD
R29o2G5Xrahf3r1O5gJdGLO1DaCBtdrI7J4xUOlM935KaEYFe5B7RVGXg23tNWgb
2M+YQqu5qz61bDxhg7dGkegHrdvKNcSkV6GUSm5w9rdxJlY8+l45p/7QpSkatcbd
IRiVzMNr9wKBgQC/+Z5fbpFgYxqvdaPicdxkZShqOj71f8OlwFfEvrTlgv4KmqAh
rDP7bVm89leu2PpuZXFbbIXkgK8n1//mNyGBgkmCbjXFWlc+LSETOxixZuK/fxov
0x3S0bBM0ZTSYatD4KsfjVkj4wa8BBJbB33NUNbsZx9WWGkUlk58mD+3XwKBgQDV
mgR+n6WJQUIfwqckH+Ol517AkYSg33zEE9qKDaVQ74QMpKKY3MqSSkFw8agcR93V
K1zysOeJsPYHUEFFzJY/up6S6HSs4aebbkZUylmMkEVFBa6qWkmrLDxs+2lgsuem
hjy1YhDSzCn3L8CLCEdqCMjr5l8ltkBFZB3u5NcZmwKBgHE9ODedQm783JfvDNBb
lB/IoUjMhMR0J2vHC3zxgTU4nIK+MR0vXvA7fmZebpaQNwYrHY9gvrL0/QevOrmG
PtXlkQ9GITMxTlqfHWV5jXZuRBIGTqh1QW3tKbVAhUhNlM0XDNBmBvjKIFjxUIo3
zMRw/o4R4cIaazyVxguZbsa2
-----END PRIVATE KEY-----

84
pkg/conn/window.go Normal file
View File

@ -0,0 +1,84 @@
package conn
import (
"errors"
"slices"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
// ConnectionWindow maintains a sliding window of connections between users
type ConnectionWindow interface {
// GetWindow is a list of connections to choose from
GetWindow() []string
// SlideConnection removes a node from the window and adds a random node
// not already in the window. connList represents the list of possible
// connections to choose from
SlideConnection(connList []string) error
// PushConneciton is used when connection list less than window size.
PutConnection(conn []string) error
// IsFull returns true if the window is full. In which case we must slide the window
IsFull() bool
}
type ConnectionWindowImpl struct {
window []string
windowSize int
}
// GetWindow gets the current list of active connections in
// the window
func (c *ConnectionWindowImpl) GetWindow() []string {
return c.window
}
// SlideConnection slides the connection window by one shuffling items
// in the windows
func (c *ConnectionWindowImpl) SlideConnection(connList []string) error {
// If the number of peer connections is less than the length of the window
// then exit early. Can't slide the window it should contain all nodes!
if len(c.window) < c.windowSize {
return nil
}
filter := func(node string) bool {
return !slices.Contains(c.window, node)
}
pool := lib.Filter(connList, filter)
newNode := lib.RandomSubsetOfLength(pool, 1)
if len(newNode) == 0 {
return errors.New("could not slide window")
}
for i := len(c.window) - 1; i >= 1; i-- {
c.window[i] = c.window[i-1]
}
c.window[0] = newNode[0]
return nil
}
// PutConnection put random connections in the connection
func (c *ConnectionWindowImpl) PutConnection(connList []string) error {
if len(c.window) >= c.windowSize {
return errors.New("cannot place connection. Window full need to slide")
}
c.window = lib.RandomSubsetOfLength(connList, c.windowSize)
return nil
}
func (c *ConnectionWindowImpl) IsFull() bool {
return len(c.window) >= c.windowSize
}
func NewConnectionWindow(windowLength int) ConnectionWindow {
window := &ConnectionWindowImpl{
window: make([]string, 0),
windowSize: windowLength,
}
return window
}

View File

@ -18,7 +18,6 @@ import (
type NewCtrlServerParams struct {
Conf *conf.WgMeshConfiguration
Client *wgctrl.Client
AuthProvider rpc.AuthenticationServer
CtrlProvider rpc.MeshCtrlServerServer
SyncProvider rpc.SyncServiceServer
Querier query.Querier
@ -36,6 +35,8 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
ipAllocator := &ip.ULABuilder{}
interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client)
configApplyer := mesh.NewWgMeshConfigApplyer(params.Conf)
meshManagerParams := &mesh.NewMeshManagerParams{
Conf: *params.Conf,
Client: params.Client,
@ -44,16 +45,19 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
IdGenerator: idGenerator,
IPAllocator: ipAllocator,
InterfaceManipulator: interfaceManipulator,
ConfigApplyer: configApplyer,
}
ctrlServer.MeshManager = mesh.NewMeshManager(meshManagerParams)
configApplyer.SetMeshManager(ctrlServer.MeshManager)
ctrlServer.Conf = params.Conf
connManagerParams := conn.NewConnectionManageParams{
connManagerParams := conn.NewConnectionManagerParams{
CertificatePath: params.Conf.CertificatePath,
PrivateKey: params.Conf.PrivateKeyPath,
SkipCertVerification: params.Conf.SkipCertVerification,
CaCert: params.Conf.CaCertificatePath,
ConnFactory: conn.NewWgCtrlConnection,
}
connMgr, err := conn.NewConnectionManager(&connManagerParams)
@ -65,7 +69,6 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
ctrlServer.ConnectionManager = connMgr
connServerParams := conn.NewConnectionServerParams{
Conf: params.Conf,
AuthProvider: params.AuthProvider,
CtrlProvider: params.CtrlProvider,
SyncProvider: params.SyncProvider,
}
@ -82,12 +85,36 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
return ctrlServer, nil
}
func (s *MeshCtrlServer) GetConfiguration() *conf.WgMeshConfiguration {
return s.Conf
}
func (s *MeshCtrlServer) GetClient() *wgctrl.Client {
return s.Client
}
func (s *MeshCtrlServer) GetQuerier() query.Querier {
return s.Querier
}
func (s *MeshCtrlServer) GetMeshManager() mesh.MeshManager {
return s.MeshManager
}
func (s *MeshCtrlServer) GetConnectionManager() conn.ConnectionManager {
return s.ConnectionManager
}
// Close closes the ctrl server tearing down any connections that exist
func (s *MeshCtrlServer) Close() error {
if err := s.ConnectionManager.Close(); err != nil {
logging.Log.WriteErrorf(err.Error())
}
if err := s.MeshManager.Close(); err != nil {
logging.Log.WriteErrorf(err.Error())
}
if err := s.ConnectionServer.Close(); err != nil {
logging.Log.WriteErrorf(err.Error())
}

View File

@ -17,6 +17,9 @@ type MeshNode struct {
WgHost string
Timestamp int64
Routes []string
Description string
Alias string
Services map[string]string
}
// Represents a WireGuard Mesh
@ -25,10 +28,19 @@ type Mesh struct {
Nodes map[string]MeshNode
}
type CtrlServer interface {
GetConfiguration() *conf.WgMeshConfiguration
GetClient() *wgctrl.Client
GetQuerier() query.Querier
GetMeshManager() mesh.MeshManager
Close() error
GetConnectionManager() conn.ConnectionManager
}
// Represents a ctrlserver to be used in WireGuard
type MeshCtrlServer struct {
Client *wgctrl.Client
MeshManager *mesh.MeshManager
MeshManager mesh.MeshManager
ConnectionManager conn.ConnectionManager
ConnectionServer *conn.ConnectionServer
Conf *conf.WgMeshConfiguration

51
pkg/ctrlserver/stub.go Normal file
View File

@ -0,0 +1,51 @@
package ctrlserver
import (
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query"
"golang.zx2c4.com/wireguard/wgctrl"
)
type CtrlServerStub struct {
manager mesh.MeshManager
querier query.Querier
connectionManager conn.ConnectionManager
}
func NewCtrlServerStub() *CtrlServerStub {
var manager mesh.MeshManager = mesh.NewMeshManagerStub()
return &CtrlServerStub{
manager: manager,
querier: query.NewJmesQuerier(manager),
connectionManager: &conn.ConnectionManagerStub{},
}
}
func (c *CtrlServerStub) GetConfiguration() *conf.WgMeshConfiguration {
return &conf.WgMeshConfiguration{
GrpcPort: "8080",
Endpoint: "abc.com",
}
}
func (c *CtrlServerStub) GetClient() *wgctrl.Client {
return &wgctrl.Client{}
}
func (c *CtrlServerStub) GetQuerier() query.Querier {
return c.querier
}
func (c *CtrlServerStub) GetMeshManager() mesh.MeshManager {
return c.manager
}
func (c *CtrlServerStub) Close() error {
return nil
}
func (c *CtrlServerStub) GetConnectionManager() conn.ConnectionManager {
return c.connectionManager
}

132
pkg/hosts/hosts.go Normal file
View File

@ -0,0 +1,132 @@
// hosts: utility for modifying the /etc/hosts file
package hosts
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"os"
"strings"
)
// HOSTS_FILE is the hosts file location
const HOSTS_FILE = "/etc/hosts"
const DOMAIN_HEADER = "#WG AUTO GENERATED HOSTS"
const DOMAIN_TRAILER = "#WG AUTO GENERATED HOSTS END"
type HostsEntry struct {
Alias string
Ip net.IP
}
// Generic interface to manipulate /etc/hosts file
type HostsManipulator interface {
// AddrAddr associates an aliasd with a given IP address
AddAddr(hosts ...HostsEntry)
// Remove deletes the entry from /etc/hosts
Remove(hosts ...HostsEntry)
// Writes the changes to /etc/hosts file
Write() error
}
type HostsManipulatorImpl struct {
hosts map[string]HostsEntry
}
// AddAddr implements HostsManipulator.
func (m *HostsManipulatorImpl) AddAddr(hosts ...HostsEntry) {
changed := false
for _, host := range hosts {
prev, ok := m.hosts[host.Ip.String()]
if !ok || prev.Alias != host.Alias {
changed = true
}
m.hosts[host.Ip.String()] = host
}
if changed {
m.Write()
}
}
// Remove implements HostsManipulator.
func (m *HostsManipulatorImpl) Remove(hosts ...HostsEntry) {
lenBefore := len(m.hosts)
for _, host := range hosts {
delete(m.hosts, host.Alias)
}
if lenBefore != len(m.hosts) {
m.Write()
}
}
func (m *HostsManipulatorImpl) removeHosts() string {
hostsFile, err := os.ReadFile(HOSTS_FILE)
if err != nil {
return ""
}
var contents strings.Builder
scanner := bufio.NewScanner(bytes.NewReader(hostsFile))
hostsSection := false
for scanner.Scan() {
line := scanner.Text()
if err == io.EOF {
break
} else if err != nil {
return ""
}
if !hostsSection && strings.Contains(line, DOMAIN_HEADER) {
hostsSection = true
}
if !hostsSection {
contents.WriteString(line + "\n")
}
if hostsSection && strings.Contains(line, DOMAIN_TRAILER) {
hostsSection = false
}
}
if scanner.Err() != nil && scanner.Err() != io.EOF {
return ""
}
return contents.String()
}
// Write implements HostsManipulator
func (m *HostsManipulatorImpl) Write() error {
contents := m.removeHosts()
var nextHosts strings.Builder
nextHosts.WriteString(contents)
nextHosts.WriteString(DOMAIN_HEADER + "\n")
for _, host := range m.hosts {
nextHosts.WriteString(fmt.Sprintf("%s\t%s\n", host.Ip.String(), host.Alias))
}
nextHosts.WriteString(DOMAIN_TRAILER + "\n")
return os.WriteFile(HOSTS_FILE, []byte(nextHosts.String()), 0644)
}
func NewHostsManipulator() HostsManipulator {
return &HostsManipulatorImpl{hosts: make(map[string]HostsEntry)}
}

View File

@ -11,8 +11,6 @@ import (
)
type NewMeshArgs struct {
// IfName is the interface that the mesh instance will run on
IfName string
// WgPort is the WireGuard port to expose
WgPort int
// Endpoint is the routable alias of the machine. Can be an IP
@ -25,13 +23,19 @@ type JoinMeshArgs struct {
MeshId string
// IpAddress is a routable IP in another mesh
IpAdress string
// IfName is the interface name of the mesh
IfName string
// Port is the WireGuard port to expose
Port int
// Endpoint is the routable address of this machine. If not provided
// defaults to the default address
Endpoint string
// Client specifies whether we should join as a client of the peer
// we are connecting to
Client bool
}
type PutServiceArgs struct {
Service string
Value string
}
type GetMeshReply struct {
@ -47,6 +51,11 @@ type QueryMesh struct {
Query string
}
type GetNodeArgs struct {
NodeId string
MeshId string
}
type MeshIpc interface {
CreateMesh(args *NewMeshArgs, reply *string) error
ListMeshes(name string, reply *ListMeshReply) error
@ -57,13 +66,17 @@ type MeshIpc interface {
GetDOT(meshId string, reply *string) error
Query(query QueryMesh, reply *string) error
PutDescription(description string, reply *string) error
PutAlias(alias string, reply *string) error
PutService(args PutServiceArgs, reply *string) error
GetNode(args GetNodeArgs, reply *string) error
DeleteService(service string, reply *string) error
}
const SockAddr = "/tmp/wgmesh_ipc.sock"
func RunIpcHandler(server MeshIpc) error {
if err := os.RemoveAll(SockAddr); err != nil {
return errors.New("Could not find to address")
return errors.New("could not find to address")
}
rpc.Register(server)

View File

@ -30,7 +30,7 @@ func MapKeys[K comparable, V any](m map[K]V) []K {
values := make([]K, len(m))
i := 0
for k, _ := range m {
for k := range m {
values[i] = k
i++
}
@ -58,7 +58,7 @@ type filterFunc[V any] func(V) bool
func Filter[V any](list []V, f filterFunc[V]) []V {
newList := make([]V, 0)
for _, elem := range newList {
for _, elem := range list {
if f(elem) {
newList = append(newList, elem)
}

144
pkg/lib/conv_test.go Normal file
View File

@ -0,0 +1,144 @@
package lib
import (
"slices"
"testing"
)
func stringToInt(input string) int {
return len(input)
}
func intDiv(input int) int {
return input / 2
}
func TestMapValuesMapsValues(t *testing.T) {
values := []int{1, 4, 11, 92}
var theMap map[string]int = map[string]int{
"mynameisjeff": values[0],
"tim": values[1],
"bob": values[2],
"derek": values[3],
}
mapValues := MapValues(theMap)
for _, elem := range mapValues {
if !slices.Contains(values, elem) {
t.Fatalf(`%d is not an expected value`, elem)
}
}
if len(mapValues) != len(theMap) {
t.Fatalf(`Expected length %d got %d`, len(theMap), len(mapValues))
}
}
func TestMapValuesWithExcludeExcludesValues(t *testing.T) {
values := []int{1, 9, 22}
var theMap map[string]int = map[string]int{
"mynameisbob": values[0],
"tim": values[1],
"bob": values[2],
}
exclude := map[string]struct{}{
"tim": {},
}
mapValues := MapValuesWithExclude(theMap, exclude)
if slices.Contains(mapValues, values[1]) {
t.Fatalf(`Failed to exclude expected value`)
}
if len(mapValues) != 2 {
t.Fatalf(`Incorrect expected length`)
}
for _, value := range theMap {
if !slices.Contains(values, value) {
t.Fatalf(`Element does not exist in the list of
expected values`)
}
}
}
func TestMapKeys(t *testing.T) {
keys := []string{"1", "2", "3"}
theMap := map[string]int{
keys[0]: 1,
keys[1]: 2,
keys[2]: 3,
}
mapKeys := MapKeys(theMap)
for _, elem := range mapKeys {
if !slices.Contains(keys, elem) {
t.Fatalf(`%s elem is not an expected key`, elem)
}
}
if len(mapKeys) != len(theMap) {
t.Fatalf(`Missing expected values`)
}
}
func TestMapValues(t *testing.T) {
array := []string{"mynameisjeff", "tim", "bob", "derek"}
intArray := Map(array, stringToInt)
for index, elem := range intArray {
if len(array[index]) != elem {
t.Fatalf(`Have %d want %d`, elem, len(array[index]))
}
}
}
func TestFilterFilterAll(t *testing.T) {
values := []int{1, 2, 3, 4, 5}
filterFunc := func(n int) bool {
return false
}
newValues := Filter(values, filterFunc)
if len(newValues) != 0 {
t.Fatalf(`Expected value was 0`)
}
}
func TestFilterFilterNone(t *testing.T) {
values := []int{1, 2, 3, 4, 5}
filterFunc := func(n int) bool {
return true
}
newValues := Filter(values, filterFunc)
if !slices.Equal(values, newValues) {
t.Fatalf(`Expected lists to be the same`)
}
}
func TestFilterFilterSome(t *testing.T) {
values := []int{1, 2, 3, 4, 5}
filterFunc := func(n int) bool {
return n < 3
}
expected := []int{1, 2}
actual := Filter(values, filterFunc)
if !slices.Equal(expected, actual) {
t.Fatalf(`Expected expected and actual to be the same`)
}
}

46
pkg/lib/hashing.go Normal file
View File

@ -0,0 +1,46 @@
package lib
import (
"hash/fnv"
"sort"
)
type consistentHashRecord[V any] struct {
record V
value int
}
func HashString(value string) int {
f := fnv.New32a()
f.Write([]byte(value))
return int(f.Sum32())
}
// ConsistentHash implementation. Traverse the values until we find a key
// less than ours.
func ConsistentHash[V any](values []V, client V, keyFunc func(V) int) V {
if len(values) == 0 {
panic("values is empty")
}
vs := Map(values, func(v V) consistentHashRecord[V] {
return consistentHashRecord[V]{
v,
keyFunc(v),
}
})
sort.SliceStable(vs, func(i, j int) bool {
return vs[i].value < vs[j].value
})
ourKey := keyFunc(client)
for _, record := range vs {
if ourKey < record.value {
return record.record
}
}
return vs[0].record
}

View File

@ -1,8 +1,11 @@
package lib
import (
"encoding/json"
"io"
"log"
"net"
"net/http"
)
// GetOutboundIP: gets the oubound IP of this packet
@ -15,3 +18,44 @@ func GetOutboundIP() net.IP {
localAddr := conn.LocalAddr().(*net.UDPAddr)
return localAddr.IP
}
const IP_SERVICE = "https://api.ipify.org?format=json"
type IpResponse struct {
Ip string `json:"ip"`
}
func (i *IpResponse) GetIP() net.IP {
return net.ParseIP(i.Ip)
}
// GetPublicIP: get the nodes public IP address. For when a node is behind NAT
func GetPublicIP() (net.IP, error) {
req, err := http.NewRequest(http.MethodGet, IP_SERVICE, nil)
if err != nil {
return nil, err
}
res, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
resBody, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
var jsonResponse IpResponse
err = json.Unmarshal([]byte(resBody), &jsonResponse)
if err != nil {
return nil, err
}
return jsonResponse.GetIP(), nil
}

46
pkg/lib/random_test.go Normal file
View File

@ -0,0 +1,46 @@
package lib
import (
"slices"
"testing"
)
// Test that a random subset of length 0 produces a zero length
// list
func TestRandomSubsetOfLength0(t *testing.T) {
values := []int{1, 2, 3, 4, 5, 6, 7, 8}
randomValues := RandomSubsetOfLength(values, 0)
if len(randomValues) != 0 {
t.Fatalf(`Expected length to be 0`)
}
}
func TestRandomSubsetOfLength1(t *testing.T) {
values := []int{1, 2, 3, 4, 5}
randomValues := RandomSubsetOfLength(values, 1)
if len(randomValues) != 1 {
t.Fatalf(`Expected length to be 1`)
}
if !slices.Contains(values, randomValues[0]) {
t.Fatalf(`Expected length to be 1`)
}
}
func TestRandomSubsetEntireList(t *testing.T) {
values := []int{1, 2, 3, 4, 5}
randomValues := RandomSubsetOfLength(values, len(values))
if len(randomValues) != len(values) {
t.Fatalf(`Expected length to be %d was %d`, len(values), len(randomValues))
}
slices.Sort(randomValues)
if !slices.Equal(values, randomValues) {
t.Fatalf(`Expected slices to be equal`)
}
}

300
pkg/lib/rtnetlink.go Normal file
View File

@ -0,0 +1,300 @@
package lib
import (
"encoding/binary"
"fmt"
"net"
"github.com/jsimonetti/rtnetlink"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"golang.org/x/sys/unix"
)
type RtNetlinkConfig struct {
conn *rtnetlink.Conn
}
func NewRtNetlinkConfig() (*RtNetlinkConfig, error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return nil, err
}
return &RtNetlinkConfig{conn: conn}, nil
}
const WIREGUARD_MTU = 1420
// Create a netlink interface if it does not exist. ifName is the name of the netlink interface
func (c *RtNetlinkConfig) CreateLink(ifName string) error {
_, err := net.InterfaceByName(ifName)
if err == nil {
return fmt.Errorf("interface %s already exists", ifName)
}
err = c.conn.Link.New(&rtnetlink.LinkMessage{
Family: unix.AF_UNSPEC,
Flags: unix.IFF_UP,
Attributes: &rtnetlink.LinkAttributes{
Name: ifName,
Info: &rtnetlink.LinkInfo{Kind: "wireguard"},
MTU: uint32(WIREGUARD_MTU),
},
})
if err != nil {
return fmt.Errorf("failed to create wireguard interface: %w", err)
}
return nil
}
// Delete link delete the specified interface
func (c *RtNetlinkConfig) DeleteLink(ifName string) error {
iface, err := net.InterfaceByName(ifName)
if err != nil {
return fmt.Errorf("failed to get interface %s %w", ifName, err)
}
err = c.conn.Link.Delete(uint32(iface.Index))
if err != nil {
return fmt.Errorf("failed to delete wg interface %w", err)
}
return nil
}
// AddAddress adds an address to the given interface.
func (c *RtNetlinkConfig) AddAddress(ifName string, address string) error {
iface, err := net.InterfaceByName(ifName)
if err != nil {
return fmt.Errorf("failed to get interface %s error: %w", ifName, err)
}
addr, cidr, err := net.ParseCIDR(address)
if err != nil {
return fmt.Errorf("failed to parse CIDR %s error: %w", addr, err)
}
family := unix.AF_INET6
ipv4 := cidr.IP.To4()
if ipv4 != nil {
family = unix.AF_INET
}
// Calculate the prefix length
ones, _ := cidr.Mask.Size()
// Calculate the broadcast IP
// Only used when family is AF_INET
var brd net.IP
if ipv4 != nil {
brd = make(net.IP, len(ipv4))
binary.BigEndian.PutUint32(brd, binary.BigEndian.Uint32(ipv4)|^binary.BigEndian.Uint32(net.IP(cidr.Mask).To4()))
}
err = c.conn.Address.New(&rtnetlink.AddressMessage{
Family: uint8(family),
PrefixLength: uint8(ones),
Scope: unix.RT_SCOPE_UNIVERSE,
Index: uint32(iface.Index),
Attributes: &rtnetlink.AddressAttributes{
Address: addr,
Local: addr,
Broadcast: brd,
},
})
if err != nil {
err = fmt.Errorf("failed to add address to link %w", err)
}
return err
}
// AddRoute: adds a route to the routing table.
// ifName is the intrface to add the route to
// gateway is the IP of the gateway device to hop to
// dst is the network prefix of the advertised destination
func (c *RtNetlinkConfig) AddRoute(ifName string, route Route) error {
iface, err := net.InterfaceByName(ifName)
if err != nil {
return fmt.Errorf("failed accessing interface %s error %w", ifName, err)
}
gw := route.Gateway
dst := route.Destination
var family uint8 = unix.AF_INET6
if dst.IP.To4() != nil {
family = unix.AF_INET
}
attr := rtnetlink.RouteAttributes{
Dst: dst.IP,
OutIface: uint32(iface.Index),
Gateway: gw,
}
ones, _ := dst.Mask.Size()
err = c.conn.Route.Replace(&rtnetlink.RouteMessage{
Family: family,
Table: unix.RT_TABLE_MAIN,
Protocol: unix.RTPROT_BOOT,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
DstLength: uint8(ones),
Attributes: attr,
})
if err != nil {
return fmt.Errorf("failed to add route %w", err)
}
return nil
}
// DeleteRoute deletes routes with the gateway and destination
func (c *RtNetlinkConfig) DeleteRoute(ifName string, route Route) error {
iface, err := net.InterfaceByName(ifName)
if err != nil {
return fmt.Errorf("failed accessing interface %s error %w", ifName, err)
}
gw := route.Gateway
dst := route.Destination
var family uint8 = unix.AF_INET6
if dst.IP.To4() != nil {
family = unix.AF_INET
}
attr := rtnetlink.RouteAttributes{
Dst: dst.IP,
OutIface: uint32(iface.Index),
Gateway: gw,
}
ones, _ := dst.Mask.Size()
err = c.conn.Route.Delete(&rtnetlink.RouteMessage{
Family: family,
Table: unix.RT_TABLE_MAIN,
Protocol: unix.RTPROT_BOOT,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
DstLength: uint8(ones),
Attributes: attr,
})
if err != nil {
return fmt.Errorf("failed to delete route %w", err)
}
return nil
}
type Route struct {
Gateway net.IP
Destination net.IPNet
}
func (r1 Route) equal(r2 Route) bool {
return r1.Gateway.String() == r2.Gateway.String() &&
r1.Destination.String() == r2.Destination.String()
}
// DeleteRoutes deletes all routes not in exclude
func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...Route) error {
routes := make([]rtnetlink.RouteMessage, 0)
if len(exclude) != 0 {
lRoutes, err := c.listRoutes(ifName, family, exclude[0].Gateway)
if err != nil {
return err
}
routes = lRoutes
}
ifRoutes := make([]Route, 0)
for _, rtRoute := range routes {
logging.Log.WriteInfof("Routes: %s", rtRoute.Attributes.Dst.String())
maskSize := 128
if family == unix.AF_INET {
maskSize = 32
}
cidr := net.CIDRMask(int(rtRoute.DstLength), maskSize)
route := Route{
Gateway: rtRoute.Attributes.Gateway,
Destination: net.IPNet{IP: rtRoute.Attributes.Dst, Mask: cidr},
}
ifRoutes = append(ifRoutes, route)
}
shouldExclude := func(r Route) bool {
for _, route := range exclude {
if route.equal(r) {
return false
}
}
return true
}
toDelete := Filter(ifRoutes, shouldExclude)
for _, route := range toDelete {
logging.Log.WriteInfof("Deleting route %s", route.Destination.String())
err := c.DeleteRoute(ifName, route)
if err != nil {
return err
}
}
return nil
}
// listRoutes lists all routes on the interface
func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8, gateway net.IP) ([]rtnetlink.RouteMessage, error) {
iface, err := net.InterfaceByName(ifName)
if err != nil {
return nil, fmt.Errorf("failed accessing interface %s error %w", ifName, err)
}
routes, err := c.conn.Route.List()
if err != nil {
return nil, fmt.Errorf("failed to get route %w", err)
}
filterFunc := func(r rtnetlink.RouteMessage) bool {
return r.Attributes.Gateway.Equal(gateway) && r.Attributes.OutIface == uint32(iface.Index)
}
routes = Filter(routes, filterFunc)
return routes, nil
}
func (c *RtNetlinkConfig) Close() error {
return c.conn.Close()
}

42
pkg/lib/timer.go Normal file
View File

@ -0,0 +1,42 @@
package lib
import "time"
type TimerFunc = func() error
type Timer struct {
f TimerFunc
quit chan struct{}
updateRate int
}
func (t *Timer) Run() error {
ticker := time.NewTicker(time.Duration(t.updateRate) * time.Second)
t.quit = make(chan struct{})
for {
select {
case <-ticker.C:
err := t.f()
if err != nil {
return err
}
case <-t.quit:
break
}
}
}
func (t *Timer) Stop() error {
close(t.quit)
return nil
}
func NewTimer(f TimerFunc, updateRate int) *Timer {
return &Timer{
f: f,
updateRate: updateRate,
}
}

View File

@ -2,6 +2,7 @@
package logging
import (
"io"
"os"
"github.com/sirupsen/logrus"
@ -15,6 +16,7 @@ type Logger interface {
WriteInfof(msg string, args ...interface{})
WriteErrorf(msg string, args ...interface{})
WriteWarnf(msg string, args ...interface{})
Writer() io.Writer
}
type LogrusLogger struct {
@ -33,6 +35,10 @@ func (l *LogrusLogger) WriteWarnf(msg string, args ...interface{}) {
l.logger.Warnf(msg, args...)
}
func (l *LogrusLogger) Writer() io.Writer {
return l.logger.Writer()
}
func NewLogrusLogger() *LogrusLogger {
logger := logrus.New()
logger.SetFormatter(&logrus.TextFormatter{FullTimestamp: true})

46
pkg/mesh/alias.go Normal file
View File

@ -0,0 +1,46 @@
package mesh
import (
"fmt"
"github.com/tim-beatham/wgmesh/pkg/hosts"
)
type MeshAliasManager interface {
AddAliases(nodes []MeshNode)
RemoveAliases(node []MeshNode)
}
type AliasManager struct {
hosts hosts.HostsManipulator
}
// AddAliases: on node update or change add aliases to the hosts file
func (a *AliasManager) AddAliases(nodes []MeshNode) {
for _, node := range nodes {
if node.GetAlias() != "" {
a.hosts.AddAddr(hosts.HostsEntry{
Alias: fmt.Sprintf("%s.smeg", node.GetAlias()),
Ip: node.GetWgHost().IP,
})
}
}
}
// RemoveAliases: on node remove remove aliases from the hosts file
func (a *AliasManager) RemoveAliases(nodes []MeshNode) {
for _, node := range nodes {
if node.GetAlias() != "" {
a.hosts.Remove(hosts.HostsEntry{
Alias: fmt.Sprintf("%s.smeg", node.GetAlias()),
Ip: node.GetWgHost().IP,
})
}
}
}
func NewAliasManager() MeshAliasManager {
return &AliasManager{
hosts: hosts.NewHostsManipulator(),
}
}

View File

@ -1,10 +1,13 @@
package mesh
import (
"errors"
"fmt"
"net"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/route"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@ -12,14 +15,17 @@ import (
type MeshConfigApplyer interface {
ApplyConfig() error
RemovePeers(meshId string) error
SetMeshManager(manager MeshManager)
}
// WgMeshConfigApplyer applies WireGuard configuration
type WgMeshConfigApplyer struct {
meshManager *MeshManager
meshManager MeshManager
config *conf.WgMeshConfiguration
routeInstaller route.RouteInstaller
}
func convertMeshNode(node MeshNode) (*wgtypes.PeerConfig, error) {
func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, peerToClients map[string][]net.IPNet) (*wgtypes.PeerConfig, error) {
endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint())
if err != nil {
@ -40,10 +46,19 @@ func convertMeshNode(node MeshNode) (*wgtypes.PeerConfig, error) {
allowedips = append(allowedips, *ipnet)
}
clients, ok := peerToClients[node.GetWgHost().String()]
if ok {
allowedips = append(allowedips, clients...)
}
keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second
peerConfig := wgtypes.PeerConfig{
PublicKey: pubKey,
Endpoint: endpoint,
AllowedIPs: allowedips,
PublicKey: pubKey,
Endpoint: endpoint,
AllowedIPs: allowedips,
PersistentKeepaliveInterval: &keepAlive,
}
return &peerConfig, nil
@ -56,13 +71,66 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
return err
}
nodes := snap.GetNodes()
nodes := lib.MapValues(snap.GetNodes())
peerConfigs := make([]wgtypes.PeerConfig, len(nodes))
peers := lib.Filter(nodes, func(mn MeshNode) bool {
return mn.GetType() == conf.PEER_ROLE
})
var count int = 0
self, err := m.meshManager.GetSelf(mesh.GetMeshId())
if err != nil {
return err
}
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return err
}
peerToClients := make(map[string][]net.IPNet)
for _, n := range nodes {
peer, err := convertMeshNode(n)
if NodeEquals(n, self) {
continue
}
if n.GetType() == conf.CLIENT_ROLE && len(peers) > 0 && self.GetType() == conf.CLIENT_ROLE {
peer := lib.ConsistentHash(peers, n, func(mn MeshNode) int {
return lib.HashString(mn.GetWgHost().String())
})
dev, err := mesh.GetDevice()
if err != nil {
return err
}
rtnl.AddRoute(dev.Name, lib.Route{
Gateway: peer.GetWgHost().IP,
Destination: *n.GetWgHost(),
})
if err != nil {
return err
}
clients, ok := peerToClients[peer.GetWgHost().String()]
if !ok {
clients = make([]net.IPNet, 0)
peerToClients[peer.GetWgHost().String()] = clients
}
peerToClients[peer.GetWgHost().String()] = append(clients, *n.GetWgHost())
continue
}
peer, err := m.convertMeshNode(n, peerToClients)
if err != nil {
return err
@ -82,11 +150,11 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
return err
}
return m.meshManager.Client.ConfigureDevice(dev.Name, cfg)
return m.meshManager.GetClient().ConfigureDevice(dev.Name, cfg)
}
func (m *WgMeshConfigApplyer) ApplyConfig() error {
for _, mesh := range m.meshManager.Meshes {
for _, mesh := range m.meshManager.GetMeshes() {
err := m.updateWgConf(mesh)
if err != nil {
@ -101,7 +169,7 @@ func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error {
mesh := m.meshManager.GetMesh(meshId)
if mesh == nil {
return errors.New(fmt.Sprintf("mesh %s does not exist", meshId))
return fmt.Errorf("mesh %s does not exist", meshId)
}
dev, err := mesh.GetDevice()
@ -110,7 +178,7 @@ func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error {
return err
}
m.meshManager.Client.ConfigureDevice(dev.Name, wgtypes.Config{
m.meshManager.GetClient().ConfigureDevice(dev.Name, wgtypes.Config{
ReplacePeers: true,
Peers: make([]wgtypes.PeerConfig, 1),
})
@ -118,6 +186,13 @@ func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error {
return nil
}
func NewWgMeshConfigApplyer(manager *MeshManager) MeshConfigApplyer {
return &WgMeshConfigApplyer{meshManager: manager}
func (m *WgMeshConfigApplyer) SetMeshManager(manager MeshManager) {
m.meshManager = manager
}
func NewWgMeshConfigApplyer(config *conf.WgMeshConfiguration) MeshConfigApplyer {
return &WgMeshConfigApplyer{
config: config,
routeInstaller: route.NewRouteInstaller(),
}
}

View File

@ -15,7 +15,7 @@ type MeshGraphConverter interface {
}
type MeshDOTConverter struct {
manager *MeshManager
manager MeshManager
}
func (c *MeshDOTConverter) Generate(meshId string) (string, error) {
@ -34,7 +34,7 @@ func (c *MeshDOTConverter) Generate(meshId string) (string, error) {
}
for _, node := range snapshot.GetNodes() {
c.graphNode(g, node)
c.graphNode(g, node, meshId)
}
nodes := lib.MapValues(snapshot.GetNodes())
@ -55,11 +55,13 @@ func (c *MeshDOTConverter) Generate(meshId string) (string, error) {
}
// graphNode: graphs a node within the mesh
func (c *MeshDOTConverter) graphNode(g *graph.Graph, node MeshNode) {
func (c *MeshDOTConverter) graphNode(g *graph.Graph, node MeshNode, meshId string) {
nodeId := fmt.Sprintf("\"%s\"", node.GetIdentifier())
g.PutNode(nodeId, graph.CIRCLE)
if node.GetHostEndpoint() == c.manager.HostParameters.HostEndpoint {
self, _ := c.manager.GetSelf(meshId)
if node.GetHostEndpoint() == self.GetHostEndpoint() {
return
}
@ -70,6 +72,6 @@ func (c *MeshDOTConverter) graphNode(g *graph.Graph, node MeshNode) {
}
}
func NewMeshDotConverter(m *MeshManager) MeshGraphConverter {
func NewMeshDotConverter(m MeshManager) MeshGraphConverter {
return &MeshDOTConverter{manager: m}
}

View File

@ -13,7 +13,31 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type MeshManager struct {
type MeshManager interface {
CreateMesh(port int) (string, error)
AddMesh(params *AddMeshParams) error
HasChanges(meshid string) bool
GetMesh(meshId string) MeshProvider
EnableInterface(meshId string) error
GetPublicKey(meshId string) (*wgtypes.Key, error)
AddSelf(params *AddSelfParams) error
LeaveMesh(meshId string) error
GetSelf(meshId string) (MeshNode, error)
ApplyConfig() error
SetDescription(description string) error
SetAlias(alias string) error
SetService(service string, value string) error
RemoveService(service string) error
UpdateTimeStamp() error
GetClient() *wgctrl.Client
GetMeshes() map[string]MeshProvider
Prune() error
Close() error
GetMonitor() MeshMonitor
GetNode(string, string) MeshNode
}
type MeshManagerImpl struct {
Meshes map[string]MeshProvider
RouteManager RouteManager
Client *wgctrl.Client
@ -27,18 +51,89 @@ type MeshManager struct {
idGenerator lib.IdGenerator
ipAllocator ip.IPAllocator
interfaceManipulator wg.WgInterfaceManipulator
Monitor MeshMonitor
}
// RemoveService implements MeshManager.
func (m *MeshManagerImpl) RemoveService(service string) error {
for _, mesh := range m.Meshes {
err := mesh.RemoveService(m.HostParameters.HostEndpoint, service)
if err != nil {
return err
}
}
return nil
}
// SetService implements MeshManager.
func (m *MeshManagerImpl) SetService(service string, value string) error {
for _, mesh := range m.Meshes {
err := mesh.AddService(m.HostParameters.HostEndpoint, service, value)
if err != nil {
return err
}
}
return nil
}
func (m *MeshManagerImpl) GetNode(meshid, nodeId string) MeshNode {
mesh, ok := m.Meshes[meshid]
if !ok {
return nil
}
node, err := mesh.GetNode(nodeId)
if err != nil {
return nil
}
return node
}
// GetMonitor implements MeshManager.
func (m *MeshManagerImpl) GetMonitor() MeshMonitor {
return m.Monitor
}
// Prune implements MeshManager.
func (m *MeshManagerImpl) Prune() error {
for _, mesh := range m.Meshes {
err := mesh.Prune(m.conf.PruneTime)
if err != nil {
return err
}
}
return nil
}
// CreateMesh: Creates a new mesh, stores it and returns the mesh id
func (m *MeshManager) CreateMesh(devName string, port int) (string, error) {
func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
meshId, err := m.idGenerator.GetId()
var ifName string = ""
if err != nil {
return "", err
}
if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(port)
if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err)
}
}
nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: devName,
DevName: ifName,
Port: port,
Conf: m.conf,
Client: m.Client,
@ -46,39 +141,34 @@ func (m *MeshManager) CreateMesh(devName string, port int) (string, error) {
})
if err != nil {
return "", err
}
err = m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: devName,
Port: port,
})
if err != nil {
return "", nil
return "", fmt.Errorf("error creating mesh: %w", err)
}
m.Meshes[meshId] = nodeManager
err = m.configApplyer.RemovePeers(meshId)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
return meshId, nil
}
type AddMeshParams struct {
MeshId string
DevName string
WgPort int
MeshBytes []byte
}
// AddMesh: Add the mesh to the list of meshes
func (m *MeshManager) AddMesh(params *AddMeshParams) error {
func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
var ifName string
var err error
if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(params.WgPort)
if err != nil {
return err
}
}
meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: params.DevName,
DevName: ifName,
Port: params.WgPort,
Conf: m.conf,
Client: m.Client,
@ -96,55 +186,44 @@ func (m *MeshManager) AddMesh(params *AddMeshParams) error {
}
m.Meshes[params.MeshId] = meshProvider
return m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: params.DevName,
Port: params.WgPort,
})
return nil
}
// HasChanges returns true if the mesh has changes
func (m *MeshManager) HasChanges(meshId string) bool {
func (m *MeshManagerImpl) HasChanges(meshId string) bool {
return m.Meshes[meshId].HasChanges()
}
// GetMesh returns the mesh with the given meshid
func (m *MeshManager) GetMesh(meshId string) MeshProvider {
theMesh, _ := m.Meshes[meshId]
func (m *MeshManagerImpl) GetMesh(meshId string) MeshProvider {
theMesh := m.Meshes[meshId]
return theMesh
}
// EnableInterface: Enables the given WireGuard interface.
func (s *MeshManager) EnableInterface(meshId string) error {
func (s *MeshManagerImpl) EnableInterface(meshId string) error {
err := s.configApplyer.ApplyConfig()
if err != nil {
return err
}
meshNode, err := s.GetSelf(meshId)
err = s.RouteManager.InstallRoutes()
if err != nil {
return err
}
mesh := s.GetMesh(meshId)
if err != nil {
return err
}
dev, err := mesh.GetDevice()
if err != nil {
return err
}
return s.interfaceManipulator.EnableInterface(dev.Name, meshNode.GetWgHost().String())
return nil
}
// GetPublicKey: Gets the public key of the WireGuard mesh
func (s *MeshManager) GetPublicKey(meshId string) (*wgtypes.Key, error) {
func (s *MeshManagerImpl) GetPublicKey(meshId string) (*wgtypes.Key, error) {
if s.conf.StubWg {
zeroedKey := make([]byte, wgtypes.KeyLen)
return (*wgtypes.Key)(zeroedKey), nil
}
mesh, ok := s.Meshes[meshId]
if !ok {
@ -166,12 +245,27 @@ type AddSelfParams struct {
// WgPort is the WireGuard port to advertise
WgPort int
// Endpoint is the alias of the machine to send routable packets
// to
Endpoint string
}
// AddSelf adds this host to the mesh
func (s *MeshManager) AddSelf(params *AddSelfParams) error {
func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
mesh := s.GetMesh(params.MeshId)
if mesh == nil {
return fmt.Errorf("addself: mesh %s does not exist", params.MeshId)
}
if params.WgPort == 0 && !s.conf.StubWg {
device, err := mesh.GetDevice()
if err != nil {
return err
}
params.WgPort = device.ListenPort
}
pubKey, err := s.GetPublicKey(params.MeshId)
if err != nil {
@ -189,80 +283,154 @@ func (s *MeshManager) AddSelf(params *AddSelfParams) error {
NodeIP: nodeIP,
WgPort: params.WgPort,
Endpoint: params.Endpoint,
Role: s.conf.Role,
})
if !s.conf.StubWg {
device, err := mesh.GetDevice()
if err != nil {
return fmt.Errorf("failed to get device %w", err)
}
err = s.interfaceManipulator.AddAddress(device.Name, fmt.Sprintf("%s/64", nodeIP))
if err != nil {
return fmt.Errorf("addSelf: failed to add address to dev %w", err)
}
}
s.Meshes[params.MeshId].AddNode(node)
return s.RouteManager.UpdateRoutes()
}
// LeaveMesh leaves the mesh network
func (s *MeshManager) LeaveMesh(meshId string) error {
_, exists := s.Meshes[meshId]
func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
mesh, exists := s.Meshes[meshId]
if !exists {
return errors.New(fmt.Sprintf("mesh %s does not exist", meshId))
return fmt.Errorf("mesh %s does not exist", meshId)
}
err := s.RouteManager.RemoveRoutes(meshId)
if err != nil {
return err
}
if !s.conf.StubWg {
device, e := mesh.GetDevice()
if e != nil {
return err
}
err = s.interfaceManipulator.RemoveInterface(device.Name)
}
// For now just delete the mesh with the ID.
delete(s.Meshes, meshId)
return nil
return err
}
func (s *MeshManager) GetSelf(meshId string) (MeshNode, error) {
func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
meshInstance, ok := s.Meshes[meshId]
if !ok {
return nil, errors.New(fmt.Sprintf("mesh %s does not exist", meshId))
return nil, fmt.Errorf("mesh %s does not exist", meshId)
}
snapshot, err := meshInstance.GetMesh()
logging.Log.WriteInfof(s.HostParameters.HostEndpoint)
node, err := meshInstance.GetNode(s.HostParameters.HostEndpoint)
if err != nil {
return nil, err
}
node, ok := snapshot.GetNodes()[s.HostParameters.HostEndpoint]
if !ok {
return nil, errors.New("the node doesn't exist in the mesh")
}
return node, nil
}
func (s *MeshManager) ApplyConfig() error {
return s.configApplyer.ApplyConfig()
func (s *MeshManagerImpl) ApplyConfig() error {
if s.conf.StubWg {
return nil
}
err := s.configApplyer.ApplyConfig()
if err != nil {
return err
}
return nil
}
func (s *MeshManager) SetDescription(description string) error {
func (s *MeshManagerImpl) SetDescription(description string) error {
for _, mesh := range s.Meshes {
err := mesh.SetDescription(s.HostParameters.HostEndpoint, description)
if mesh.NodeExists(s.HostParameters.HostEndpoint) {
err := mesh.SetDescription(s.HostParameters.HostEndpoint, description)
if err != nil {
return err
if err != nil {
return err
}
}
}
return nil
}
// UpdateTimeStamp updates the timestamp of this node in all meshes
func (s *MeshManager) UpdateTimeStamp() error {
// SetAlias implements MeshManager.
func (s *MeshManagerImpl) SetAlias(alias string) error {
for _, mesh := range s.Meshes {
snapshot, err := mesh.GetMesh()
if mesh.NodeExists(s.HostParameters.HostEndpoint) {
err := mesh.SetAlias(s.HostParameters.HostEndpoint, alias)
if err != nil {
return err
}
}
}
return nil
}
// UpdateTimeStamp updates the timestamp of this node in all meshes
func (s *MeshManagerImpl) UpdateTimeStamp() error {
for _, mesh := range s.Meshes {
if mesh.NodeExists(s.HostParameters.HostEndpoint) {
err := mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint)
if err != nil {
return err
}
}
}
return nil
}
func (s *MeshManagerImpl) GetClient() *wgctrl.Client {
return s.Client
}
func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider {
return s.Meshes
}
// Close the mesh manager
func (s *MeshManagerImpl) Close() error {
if s.conf.StubWg {
return nil
}
for _, mesh := range s.Meshes {
dev, err := mesh.GetDevice()
if err != nil {
return err
}
_, exists := snapshot.GetNodes()[s.HostParameters.HostEndpoint]
err = s.interfaceManipulator.RemoveInterface(dev.Name)
if exists {
err = mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint)
if err != nil {
return err
}
if err != nil {
return err
}
}
@ -278,22 +446,25 @@ type NewMeshManagerParams struct {
IdGenerator lib.IdGenerator
IPAllocator ip.IPAllocator
InterfaceManipulator wg.WgInterfaceManipulator
ConfigApplyer MeshConfigApplyer
RouteManager RouteManager
}
// Creates a new instance of a mesh manager with the given parameters
func NewMeshManager(params *NewMeshManagerParams) *MeshManager {
func NewMeshManager(params *NewMeshManagerParams) MeshManager {
hostParams := HostParameters{}
switch params.Conf.Endpoint {
case "":
hostParams.HostEndpoint = fmt.Sprintf("%s:%s", lib.GetOutboundIP().String(), params.Conf.GrpcPort)
ip, _ := lib.GetPublicIP()
hostParams.HostEndpoint = fmt.Sprintf("%s:%s", ip.String(), params.Conf.GrpcPort)
default:
hostParams.HostEndpoint = fmt.Sprintf("%s:%s", params.Conf.Endpoint, params.Conf.GrpcPort)
}
logging.Log.WriteInfof("Endpoint %s", hostParams.HostEndpoint)
m := &MeshManager{
m := &MeshManagerImpl{
Meshes: make(map[string]MeshProvider),
HostParameters: &hostParams,
meshProviderFactory: params.MeshProvider,
@ -301,10 +472,22 @@ func NewMeshManager(params *NewMeshManagerParams) *MeshManager {
Client: params.Client,
conf: &params.Conf,
}
m.configApplyer = NewWgMeshConfigApplyer(m)
m.RouteManager = NewRouteManager(m)
m.configApplyer = params.ConfigApplyer
m.RouteManager = params.RouteManager
if m.RouteManager == nil {
m.RouteManager = NewRouteManager(m)
}
m.idGenerator = params.IdGenerator
m.ipAllocator = params.IPAllocator
m.interfaceManipulator = params.InterfaceManipulator
m.Monitor = NewMeshMonitor(m)
aliasManager := NewAliasManager()
m.Monitor.AddUpdateCallback(aliasManager.AddAliases)
m.Monitor.AddRemoveCallback(aliasManager.RemoveAliases)
return m
}

247
pkg/mesh/manager_test.go Normal file
View File

@ -0,0 +1,247 @@
package mesh
import (
"testing"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/wg"
)
func getMeshConfiguration() *conf.WgMeshConfiguration {
return &conf.WgMeshConfiguration{
GrpcPort: "8080",
Endpoint: "abc.com",
ClusterSize: 64,
SyncRate: 4,
BranchRate: 3,
InterClusterChance: 0.15,
InfectionCount: 2,
KeepAliveTime: 60,
}
}
func getMeshManager() MeshManager {
manager := NewMeshManager(&NewMeshManagerParams{
Conf: *getMeshConfiguration(),
Client: nil,
MeshProvider: &StubMeshProviderFactory{},
NodeFactory: &StubNodeFactory{Config: getMeshConfiguration()},
IdGenerator: &lib.UUIDGenerator{},
IPAllocator: &ip.ULABuilder{},
InterfaceManipulator: &wg.WgInterfaceManipulatorStub{},
ConfigApplyer: &MeshConfigApplyerStub{},
RouteManager: &RouteManagerStub{},
})
return manager
}
func TestCreateMeshCreatesANewMeshProvider(t *testing.T) {
manager := getMeshManager()
meshId, err := manager.CreateMesh("wg0", 5000)
if err != nil {
t.Error(err)
}
if len(meshId) == 0 {
t.Fatal(`meshId should not be empty`)
}
_, exists := manager.GetMeshes()[meshId]
if !exists {
t.Fatal(`mesh was not created when it should be`)
}
}
func TestAddMeshAddsAMesh(t *testing.T) {
manager := getMeshManager()
meshId := "meshid123"
manager.AddMesh(&AddMeshParams{
MeshId: meshId,
DevName: "wg0",
WgPort: 6000,
MeshBytes: make([]byte, 0),
})
mesh := manager.GetMesh(meshId)
if mesh == nil || mesh.GetMeshId() != meshId {
t.Fatalf(`mesh has not been added to the list of meshes`)
}
}
func TestAddMeshMeshAlreadyExistsReplacesIt(t *testing.T) {
manager := getMeshManager()
meshId := "meshid123"
for i := 0; i < 2; i++ {
err := manager.AddMesh(&AddMeshParams{
MeshId: meshId,
DevName: "wg0",
WgPort: 6000,
MeshBytes: make([]byte, 0),
})
if err != nil {
t.Error(err)
}
}
mesh := manager.GetMesh(meshId)
if mesh == nil || mesh.GetMeshId() != meshId {
t.Fatalf(`mesh has not been added to the list of meshes`)
}
}
func TestAddSelfAddsSelfToTheMesh(t *testing.T) {
manager := getMeshManager()
meshId := "meshid123"
err := manager.AddMesh(&AddMeshParams{
MeshId: meshId,
DevName: "wg0",
WgPort: 6000,
MeshBytes: make([]byte, 0),
})
if err != nil {
t.Error(err)
}
err = manager.AddSelf(&AddSelfParams{
MeshId: meshId,
WgPort: 5000,
Endpoint: "abc.com",
})
if err != nil {
t.Error(err)
}
mesh, err := manager.GetMesh(meshId).GetMesh()
if err != nil {
t.Error(err)
}
_, ok := mesh.GetNodes()["abc.com"]
if !ok {
t.Fatalf(`node has not been added`)
}
}
func TestAddSelfToMeshAlreadyInMesh(t *testing.T) {
TestAddSelfAddsSelfToTheMesh(t)
TestAddSelfAddsSelfToTheMesh(t)
}
func TestAddSelfToMeshMeshDoesNotExist(t *testing.T) {
manager := getMeshManager()
meshId := "meshid123"
err := manager.AddSelf(&AddSelfParams{
MeshId: meshId,
WgPort: 5000,
Endpoint: "abc.com",
})
if err == nil {
t.Fatalf(`Expected error to be thrown`)
}
}
func TestLeaveMeshMeshDoesNotExist(t *testing.T) {
manager := getMeshManager()
meshId := "meshid123"
err := manager.LeaveMesh(meshId)
if err == nil {
t.Fatalf(`Expected error to be thrown`)
}
}
func TestLeaveMeshDeletesMesh(t *testing.T) {
manager := getMeshManager()
meshId := "meshid123"
err := manager.AddMesh(&AddMeshParams{
MeshId: meshId,
DevName: "wg0",
WgPort: 6000,
MeshBytes: make([]byte, 0),
})
if err != nil {
t.Error(err)
}
err = manager.LeaveMesh(meshId)
if err != nil {
t.Fatalf("%s", err.Error())
}
_, exists := manager.GetMeshes()[meshId]
if exists {
t.Fatalf(`expected mesh to have been deleted`)
}
}
func TestSetDescription(t *testing.T) {
manager := getMeshManager()
description := "wooooo"
meshId1, _ := manager.CreateMesh(5000)
meshId2, _ := manager.CreateMesh(5001)
manager.AddSelf(&AddSelfParams{
MeshId: meshId1,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId2,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
err := manager.SetDescription(description)
if err != nil {
t.Fatalf(`failed to set the descriptions`)
}
}
func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) {
manager := getMeshManager()
meshId1, _ := manager.CreateMesh(5000)
meshId2, _ := manager.CreateMesh(5001)
manager.AddSelf(&AddSelfParams{
MeshId: meshId1,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId2,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
err := manager.UpdateTimeStamp()
if err != nil {
t.Fatalf(`failed to update the timestamp`)
}
}

81
pkg/mesh/monitor.go Normal file
View File

@ -0,0 +1,81 @@
package mesh
type OnChange = func([]MeshNode)
type MeshMonitor interface {
AddUpdateCallback(cb OnChange)
AddRemoveCallback(cb OnChange)
Trigger() error
}
type MeshMonitorImpl struct {
updateCbs []OnChange
removeCbs []OnChange
nodes map[string]MeshNode
manager MeshManager
}
// Trigger causes the mesh monitor to trigger all of
// the callbacks.
func (m *MeshMonitorImpl) Trigger() error {
changedNodes := make([]MeshNode, 0)
removedNodes := make([]MeshNode, 0)
nodes := make(map[string]MeshNode)
for _, mesh := range m.manager.GetMeshes() {
snapshot, err := mesh.GetMesh()
if err != nil {
return err
}
for _, node := range snapshot.GetNodes() {
previous, exists := m.nodes[node.GetWgHost().String()]
if !exists || !NodeEquals(previous, node) {
changedNodes = append(changedNodes, node)
}
nodes[node.GetWgHost().String()] = node
}
}
for _, previous := range m.nodes {
_, ok := nodes[previous.GetWgHost().String()]
if !ok {
removedNodes = append(removedNodes, previous)
}
}
if len(removedNodes) > 0 {
for _, cb := range m.removeCbs {
cb(removedNodes)
}
}
if len(changedNodes) > 0 {
for _, cb := range m.updateCbs {
cb(changedNodes)
}
}
return nil
}
func (m *MeshMonitorImpl) AddUpdateCallback(cb OnChange) {
m.updateCbs = append(m.updateCbs, cb)
}
func (m *MeshMonitorImpl) AddRemoveCallback(cb OnChange) {
m.removeCbs = append(m.removeCbs, cb)
}
func NewMeshMonitor(manager MeshManager) MeshMonitor {
return &MeshMonitorImpl{
updateCbs: make([]OnChange, 0),
nodes: make(map[string]MeshNode),
manager: manager,
}
}

16
pkg/mesh/pruner.go Normal file
View File

@ -0,0 +1,16 @@
package mesh
import (
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
func pruneFunction(m MeshManager) lib.TimerFunc {
return func() error {
return m.Prune()
}
}
func NewPruner(m MeshManager, conf conf.WgMeshConfiguration) *lib.Timer {
return lib.NewTimer(pruneFunction(m), conf.PruneTime/2)
}

View File

@ -1,22 +1,29 @@
package mesh
import (
"fmt"
"net"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/route"
"golang.org/x/sys/unix"
)
type RouteManager interface {
UpdateRoutes() error
InstallRoutes() error
RemoveRoutes(meshId string) error
}
type RouteManagerImpl struct {
meshManager *MeshManager
meshManager MeshManager
routeInstaller route.RouteInstaller
}
func (r *RouteManagerImpl) UpdateRoutes() error {
meshes := r.meshManager.Meshes
meshes := r.meshManager.GetMeshes()
ulaBuilder := new(ip.ULABuilder)
for _, mesh1 := range meshes {
@ -32,7 +39,13 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
return err
}
err = mesh1.AddRoutes(r.meshManager.HostParameters.HostEndpoint, ipNet.String())
self, err := r.meshManager.GetSelf(mesh1.GetMeshId())
if err != nil {
return err
}
err = mesh1.AddRoutes(self.GetHostEndpoint(), ipNet.String())
if err != nil {
return err
@ -43,6 +56,129 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
return nil
}
func NewRouteManager(m *MeshManager) RouteManager {
// removeRoutes: removes all meshes we are no longer a part of
func (r *RouteManagerImpl) RemoveRoutes(meshId string) error {
ulaBuilder := new(ip.ULABuilder)
meshes := r.meshManager.GetMeshes()
ipNet, err := ulaBuilder.GetIPNet(meshId)
if err != nil {
return err
}
for _, mesh1 := range meshes {
self, err := r.meshManager.GetSelf(meshId)
if err != nil {
return err
}
mesh1.RemoveRoutes(self.GetHostEndpoint(), ipNet.String())
}
return nil
}
// AddRoute adds a route to the given interface
func (m *RouteManagerImpl) addRoute(ifName string, meshPrefix string, routes ...lib.Route) error {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return fmt.Errorf("failed to create config: %w", err)
}
defer rtnl.Close()
// Delete any routes that may be vacant
err = rtnl.DeleteRoutes(ifName, unix.AF_INET6, routes...)
if err != nil {
return err
}
for _, route := range routes {
if route.Destination.String() == meshPrefix {
continue
}
err = rtnl.AddRoute(ifName, route)
if err != nil {
return err
}
}
return nil
}
func (m *RouteManagerImpl) installRoute(ifName string, meshid string, node MeshNode) error {
routeMapFunc := func(route string) lib.Route {
_, cidr, _ := net.ParseCIDR(route)
r := lib.Route{
Destination: *cidr,
Gateway: node.GetWgHost().IP,
}
return r
}
ipBuilder := &ip.ULABuilder{}
ipNet, err := ipBuilder.GetIPNet(meshid)
if err != nil {
return err
}
routes := lib.Map(append(node.GetRoutes(), ipNet.String()), routeMapFunc)
return m.addRoute(ifName, ipNet.String(), routes...)
}
func (m *RouteManagerImpl) installRoutes(meshProvider MeshProvider) error {
mesh, err := meshProvider.GetMesh()
if err != nil {
return err
}
dev, err := meshProvider.GetDevice()
if err != nil {
return err
}
self, err := m.meshManager.GetSelf(meshProvider.GetMeshId())
if err != nil {
return err
}
for _, node := range mesh.GetNodes() {
if self.GetHostEndpoint() == node.GetHostEndpoint() {
continue
}
err = m.installRoute(dev.Name, meshProvider.GetMeshId(), node)
if err != nil {
return err
}
}
return nil
}
// InstallRoutes installs all routes to the RIB
func (r *RouteManagerImpl) InstallRoutes() error {
for _, mesh := range r.meshManager.GetMeshes() {
err := r.installRoutes(mesh)
if err != nil {
return err
}
}
return nil
}
func NewRouteManager(m MeshManager) RouteManager {
return &RouteManagerImpl{meshManager: m, routeInstaller: route.NewRouteInstaller()}
}

16
pkg/mesh/route_stub.go Normal file
View File

@ -0,0 +1,16 @@
package mesh
type RouteManagerStub struct {
}
func (r *RouteManagerStub) UpdateRoutes() error {
return nil
}
func (r *RouteManagerStub) InstallRoutes() error {
return nil
}
func (r *RouteManagerStub) RemoveRoutes(meshId string) error {
return nil
}

320
pkg/mesh/stub_types.go Normal file
View File

@ -0,0 +1,320 @@
package mesh
import (
"fmt"
"net"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type MeshNodeStub struct {
hostEndpoint string
publicKey wgtypes.Key
wgEndpoint string
wgHost *net.IPNet
timeStamp int64
routes []string
identifier string
description string
}
// GetType implements MeshNode.
func (*MeshNodeStub) GetType() conf.NodeType {
return PEER
}
// GetServices implements MeshNode.
func (*MeshNodeStub) GetServices() map[string]string {
return make(map[string]string)
}
// GetAlias implements MeshNode.
func (*MeshNodeStub) GetAlias() string {
return ""
}
func (m *MeshNodeStub) GetHostEndpoint() string {
return m.hostEndpoint
}
func (m *MeshNodeStub) GetPublicKey() (wgtypes.Key, error) {
return m.publicKey, nil
}
func (m *MeshNodeStub) GetWgEndpoint() string {
return m.wgEndpoint
}
func (m *MeshNodeStub) GetWgHost() *net.IPNet {
return m.wgHost
}
func (m *MeshNodeStub) GetTimeStamp() int64 {
return m.timeStamp
}
func (m *MeshNodeStub) GetRoutes() []string {
return m.routes
}
func (m *MeshNodeStub) GetIdentifier() string {
return m.identifier
}
func (m *MeshNodeStub) GetDescription() string {
return m.description
}
type MeshSnapshotStub struct {
nodes map[string]MeshNode
}
func (s *MeshSnapshotStub) GetNodes() map[string]MeshNode {
return s.nodes
}
type MeshProviderStub struct {
meshId string
snapshot *MeshSnapshotStub
}
// GetNodeIds implements MeshProvider.
func (*MeshProviderStub) GetPeers() []string {
return make([]string, 0)
}
// GetNode implements MeshProvider.
func (*MeshProviderStub) GetNode(string) (MeshNode, error) {
return nil, nil
}
// NodeExists implements MeshProvider.
func (*MeshProviderStub) NodeExists(string) bool {
return false
}
// AddService implements MeshProvider.
func (*MeshProviderStub) AddService(nodeId string, key string, value string) error {
return nil
}
// RemoveService implements MeshProvider.
func (*MeshProviderStub) RemoveService(nodeId string, key string) error {
return nil
}
// SetAlias implements MeshProvider.
func (*MeshProviderStub) SetAlias(nodeId string, alias string) error {
panic("unimplemented")
}
// RemoveRoutes implements MeshProvider.
func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error {
return nil
}
// Prune implements MeshProvider.
func (*MeshProviderStub) Prune(pruneAmount int) error {
return nil
}
// UpdateTimeStamp implements MeshProvider.
func (*MeshProviderStub) UpdateTimeStamp(nodeId string) error {
return nil
}
func (s *MeshProviderStub) AddNode(node MeshNode) {
s.snapshot.nodes[node.GetHostEndpoint()] = node
}
func (s *MeshProviderStub) GetMesh() (MeshSnapshot, error) {
return s.snapshot, nil
}
func (s *MeshProviderStub) GetMeshId() string {
return s.meshId
}
func (s *MeshProviderStub) Save() []byte {
return make([]byte, 0)
}
func (s *MeshProviderStub) Load(bytes []byte) error {
return nil
}
func (s *MeshProviderStub) GetDevice() (*wgtypes.Device, error) {
pubKey, _ := wgtypes.GenerateKey()
return &wgtypes.Device{
PublicKey: pubKey,
}, nil
}
func (s *MeshProviderStub) SaveChanges() {}
func (s *MeshProviderStub) HasChanges() bool {
return false
}
func (s *MeshProviderStub) AddRoutes(nodeId string, route ...string) error {
return nil
}
func (s *MeshProviderStub) GetSyncer() MeshSyncer {
return nil
}
func (s *MeshProviderStub) SetDescription(nodeId string, description string) error {
return nil
}
type StubMeshProviderFactory struct{}
func (s *StubMeshProviderFactory) CreateMesh(params *MeshProviderFactoryParams) (MeshProvider, error) {
return &MeshProviderStub{
meshId: params.MeshId,
snapshot: &MeshSnapshotStub{nodes: make(map[string]MeshNode)},
}, nil
}
type StubNodeFactory struct {
Config *conf.WgMeshConfiguration
}
func (s *StubNodeFactory) Build(params *MeshNodeFactoryParams) MeshNode {
_, wgHost, _ := net.ParseCIDR(fmt.Sprintf("%s/128", params.NodeIP.String()))
return &MeshNodeStub{
hostEndpoint: params.Endpoint,
publicKey: *params.PublicKey,
wgEndpoint: fmt.Sprintf("%s:%s", params.Endpoint, s.Config.GrpcPort),
wgHost: wgHost,
timeStamp: time.Now().Unix(),
routes: make([]string, 0),
identifier: "abc",
description: "A Mesh Node Stub",
}
}
type MeshConfigApplyerStub struct{}
func (a *MeshConfigApplyerStub) ApplyConfig() error {
return nil
}
func (a *MeshConfigApplyerStub) RemovePeers(meshId string) error {
return nil
}
func (a *MeshConfigApplyerStub) SetMeshManager(manager MeshManager) {
}
type MeshManagerStub struct {
meshes map[string]MeshProvider
}
// GetNode implements MeshManager.
func (*MeshManagerStub) GetNode(string, string) MeshNode {
panic("unimplemented")
}
// RemoveService implements MeshManager.
func (*MeshManagerStub) RemoveService(service string) error {
panic("unimplemented")
}
// SetService implements MeshManager.
func (*MeshManagerStub) SetService(service string, value string) error {
panic("unimplemented")
}
// GetMonitor implements MeshManager.
func (*MeshManagerStub) GetMonitor() MeshMonitor {
panic("unimplemented")
}
// SetAlias implements MeshManager.
func (*MeshManagerStub) SetAlias(alias string) error {
panic("unimplemented")
}
// Close implements MeshManager.
func (*MeshManagerStub) Close() error {
panic("unimplemented")
}
// Prune implements MeshManager.
func (*MeshManagerStub) Prune() error {
return nil
}
func NewMeshManagerStub() MeshManager {
return &MeshManagerStub{meshes: make(map[string]MeshProvider)}
}
func (m *MeshManagerStub) CreateMesh(port int) (string, error) {
return "tim123", nil
}
func (m *MeshManagerStub) AddMesh(params *AddMeshParams) error {
m.meshes[params.MeshId] = &MeshProviderStub{
params.MeshId,
&MeshSnapshotStub{nodes: make(map[string]MeshNode)},
}
return nil
}
func (m *MeshManagerStub) HasChanges(meshId string) bool {
return false
}
func (m *MeshManagerStub) GetMesh(meshId string) MeshProvider {
return &MeshProviderStub{
meshId: meshId,
snapshot: &MeshSnapshotStub{nodes: make(map[string]MeshNode)}}
}
func (m *MeshManagerStub) EnableInterface(meshId string) error {
return nil
}
func (m *MeshManagerStub) GetPublicKey(meshId string) (*wgtypes.Key, error) {
key, _ := wgtypes.GenerateKey()
return &key, nil
}
func (m *MeshManagerStub) AddSelf(params *AddSelfParams) error {
return nil
}
func (m *MeshManagerStub) GetSelf(meshId string) (MeshNode, error) {
return nil, nil
}
func (m *MeshManagerStub) ApplyConfig() error {
return nil
}
func (m *MeshManagerStub) SetDescription(description string) error {
return nil
}
func (m *MeshManagerStub) UpdateTimeStamp() error {
return nil
}
func (m *MeshManagerStub) GetClient() *wgctrl.Client {
return nil
}
func (m *MeshManagerStub) GetMeshes() map[string]MeshProvider {
return m.meshes
}
func (m *MeshManagerStub) LeaveMesh(meshId string) error {
return nil
}

View File

@ -10,6 +10,13 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
const (
// Data Exchanged Between Peers
PEER conf.NodeType = "peer"
// Data Exchanged Between Clients
CLIENT conf.NodeType = "client"
)
// MeshNode represents an implementation of a node in a mesh
type MeshNode interface {
// GetHostEndpoint: gets the gRPC endpoint of the node
@ -28,6 +35,17 @@ type MeshNode interface {
GetIdentifier() string
// GetDescription: returns the description for this node
GetDescription() string
// GetAlias: associates the node with an alias. Potentially used
// for DNS and so forth.
GetAlias() string
// GetServices: returns a list of services offered by the node
GetServices() map[string]string
GetType() conf.NodeType
}
// NodeEquals: determines if two mesh nodes are equivalent to one another
func NodeEquals(node1, node2 MeshNode) bool {
return node1.GetHostEndpoint() == node2.GetHostEndpoint()
}
type MeshSnapshot interface {
@ -46,7 +64,7 @@ type MeshSyncer interface {
type MeshProvider interface {
// AddNode() adds a node to the mesh
AddNode(node MeshNode)
// GetMesh() returns a snapshot of the mesh provided by the mesh provider
// GetMesh() returns a snapshot of the mesh provided by the mesh provider.
GetMesh() (MeshSnapshot, error)
// GetMeshId() returns the ID of the mesh network
GetMeshId() string
@ -58,20 +76,37 @@ type MeshProvider interface {
GetDevice() (*wgtypes.Device, error)
// HasChanges returns true if we have changes since last time we synced
HasChanges() bool
// Record that we have changges and save the corresponding changes
// Record that we have changes and save the corresponding changes
SaveChanges()
// UpdateTimeStamp: update the timestamp of the given node
UpdateTimeStamp(nodeId string) error
// AddRoutes: adds routes to the given node
AddRoutes(nodeId string, route ...string) error
// DeleteRoutes: deletes the routes from the node
RemoveRoutes(nodeId string, route ...string) error
// GetSyncer: returns the automerge syncer for sync
GetSyncer() MeshSyncer
// GetNode get a particular not within the mesh
GetNode(string) (MeshNode, error)
// NodeExists: returns true if a particular node exists false otherwise
NodeExists(string) bool
// SetDescription: sets the description of this automerge data type
SetDescription(nodeId string, description string) error
// SetAlias: set the alias of the nodeId
SetAlias(nodeId string, alias string) error
// AddService: adds the service to the given node
AddService(nodeId, key, value string) error
// RemoveService: removes the service form the node. throws an error if the service does not exist
RemoveService(nodeId, key string) error
// Prune: prunes all nodes that have not updated their timestamp in
// pruneAmount seconds
Prune(pruneAmount int) error
GetPeers() []string
}
// HostParameters contains the IDs of a node
type HostParameters struct {
HostEndpoint string
// TODO: Contain the WireGungracefullyuard identifier in this
}
// MeshProviderFactoryParams parameters required to build a mesh provider
@ -95,6 +130,7 @@ type MeshNodeFactoryParams struct {
NodeIP net.IP
WgPort int
Endpoint string
Role conf.NodeType
}
// MeshBuilder build the hosts mesh node for it to be added to the mesh

View File

@ -1,29 +0,0 @@
package middleware
import (
"context"
"errors"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/rpc"
)
// AuthRpcProvider implements the AuthRpcProvider service
type AuthRpcProvider struct {
rpc.UnimplementedAuthenticationServer
}
// JoinMesh handles a JoinMeshRequest. Succeeds by stating the node managed to join the mesh
// or returns an error if it failed
func (a *AuthRpcProvider) JoinMesh(ctx context.Context, in *rpc.JoinAuthMeshRequest) (*rpc.JoinAuthMeshReply, error) {
meshId := in.MeshId
if meshId == "" {
return nil, errors.New("Must specify the meshId")
}
logging.Log.WriteInfof("MeshID: " + in.MeshId)
var token string = ""
return &rpc.JoinAuthMeshReply{Success: true, Token: &token}, nil
}

View File

@ -5,6 +5,7 @@ import (
"fmt"
"github.com/jmespath/go-jmespath"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
)
@ -16,7 +17,7 @@ type Querier interface {
}
type JmesQuerier struct {
manager *mesh.MeshManager
manager mesh.MeshManager
}
type QueryError struct {
@ -24,13 +25,16 @@ type QueryError struct {
}
type QueryNode struct {
HostEndpoint string `json:"hostEndpoint"`
PublicKey string `json:"publicKey"`
WgEndpoint string `json:"wgEndpoint"`
WgHost string `json:"wgHost"`
Timestamp int64 `json:"timestmap"`
Description string `json:"description"`
Routes []string `json:"routes"`
HostEndpoint string `json:"hostEndpoint"`
PublicKey string `json:"publicKey"`
WgEndpoint string `json:"wgEndpoint"`
WgHost string `json:"wgHost"`
Timestamp int64 `json:"timestamp"`
Description string `json:"description"`
Routes []string `json:"routes"`
Alias string `json:"alias"`
Services map[string]string `json:"services"`
Type conf.NodeType `json:"type"`
}
func (m *QueryError) Error() string {
@ -39,7 +43,7 @@ func (m *QueryError) Error() string {
// Query: queries the data
func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
mesh, ok := j.manager.Meshes[meshId]
mesh, ok := j.manager.GetMeshes()[meshId]
if !ok {
return nil, &QueryError{msg: fmt.Sprintf("%s does not exist", meshId)}
@ -51,7 +55,7 @@ func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
return nil, err
}
nodes := lib.Map(lib.MapValues(snapshot.GetNodes()), meshNodeToQueryNode)
nodes := lib.Map(lib.MapValues(snapshot.GetNodes()), MeshNodeToQueryNode)
result, err := jmespath.Search(queryParams, nodes)
@ -63,7 +67,7 @@ func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
return bytes, err
}
func meshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
func MeshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode := new(QueryNode)
queryNode.HostEndpoint = node.GetHostEndpoint()
pubKey, _ := node.GetPublicKey()
@ -76,9 +80,13 @@ func meshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode.Timestamp = node.GetTimeStamp()
queryNode.Routes = node.GetRoutes()
queryNode.Description = node.GetDescription()
queryNode.Alias = node.GetAlias()
queryNode.Services = node.GetServices()
queryNode.Type = node.GetType()
return queryNode
}
func NewJmesQuerier(manager *mesh.MeshManager) Querier {
func NewJmesQuerier(manager mesh.MeshManager) Querier {
return &JmesQuerier{manager: manager}
}

View File

@ -2,45 +2,49 @@ package robin
import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
"time"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query"
"github.com/tim-beatham/wgmesh/pkg/rpc"
)
type IpcHandler struct {
Server *ctrlserver.MeshCtrlServer
ipAllocator ip.IPAllocator
Server ctrlserver.CtrlServer
}
func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
meshId, err := n.Server.MeshManager.CreateMesh(args.IfName, args.WgPort)
meshId, err := n.Server.GetMeshManager().CreateMesh(args.WgPort)
if err != nil {
return err
}
err = n.Server.MeshManager.AddSelf(&mesh.AddSelfParams{
err = n.Server.GetMeshManager().AddSelf(&mesh.AddSelfParams{
MeshId: meshId,
WgPort: args.WgPort,
Endpoint: args.Endpoint,
})
if err != nil {
return err
}
*reply = meshId
return err
}
func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
meshNames := make([]string, len(n.Server.MeshManager.Meshes))
meshNames := make([]string, len(n.Server.GetMeshManager().GetMeshes()))
i := 0
for meshId, _ := range n.Server.MeshManager.Meshes {
for meshId, _ := range n.Server.GetMeshManager().GetMeshes() {
meshNames[i] = meshId
i++
}
@ -50,7 +54,7 @@ func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
}
func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
peerConnection, err := n.Server.ConnectionManager.GetConnection(args.IpAdress)
peerConnection, err := n.Server.GetConnectionManager().GetConnection(args.IpAdress)
if err != nil {
return err
@ -77,9 +81,8 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
return err
}
err = n.Server.MeshManager.AddMesh(&mesh.AddMeshParams{
err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{
MeshId: args.MeshId,
DevName: args.IfName,
WgPort: args.Port,
MeshBytes: meshReply.Mesh,
})
@ -88,7 +91,7 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
return err
}
err = n.Server.MeshManager.AddSelf(&mesh.AddSelfParams{
err = n.Server.GetMeshManager().AddSelf(&mesh.AddSelfParams{
MeshId: args.MeshId,
WgPort: args.Port,
Endpoint: args.Endpoint,
@ -104,7 +107,7 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
// LeaveMesh leaves a mesh network
func (n *IpcHandler) LeaveMesh(meshId string, reply *string) error {
err := n.Server.MeshManager.LeaveMesh(meshId)
err := n.Server.GetMeshManager().LeaveMesh(meshId)
if err == nil {
*reply = fmt.Sprintf("Left Mesh %s", meshId)
@ -114,7 +117,12 @@ func (n *IpcHandler) LeaveMesh(meshId string, reply *string) error {
}
func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
mesh := n.Server.MeshManager.GetMesh(meshId)
mesh := n.Server.GetMeshManager().GetMesh(meshId)
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
meshSnapshot, err := mesh.GetMesh()
if err != nil {
@ -124,6 +132,7 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
if mesh == nil {
return errors.New("mesh does not exist")
}
nodes := make([]ctrlserver.MeshNode, len(meshSnapshot.GetNodes()))
i := 0
@ -141,6 +150,9 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
WgHost: node.GetWgHost().String(),
Timestamp: node.GetTimeStamp(),
Routes: node.GetRoutes(),
Description: node.GetDescription(),
Alias: node.GetAlias(),
Services: node.GetServices(),
}
nodes[i] = node
@ -152,7 +164,7 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
}
func (n *IpcHandler) EnableInterface(meshId string, reply *string) error {
err := n.Server.MeshManager.EnableInterface(meshId)
err := n.Server.GetMeshManager().EnableInterface(meshId)
if err != nil {
*reply = err.Error()
@ -164,7 +176,7 @@ func (n *IpcHandler) EnableInterface(meshId string, reply *string) error {
}
func (n *IpcHandler) GetDOT(meshId string, reply *string) error {
g := mesh.NewMeshDotConverter(n.Server.MeshManager)
g := mesh.NewMeshDotConverter(n.Server.GetMeshManager())
result, err := g.Generate(meshId)
@ -177,7 +189,7 @@ func (n *IpcHandler) GetDOT(meshId string, reply *string) error {
}
func (n *IpcHandler) Query(params ipc.QueryMesh, reply *string) error {
queryResponse, err := n.Server.Querier.Query(params.MeshId, params.Query)
queryResponse, err := n.Server.GetQuerier().Query(params.MeshId, params.Query)
if err != nil {
return err
@ -188,7 +200,7 @@ func (n *IpcHandler) Query(params ipc.QueryMesh, reply *string) error {
}
func (n *IpcHandler) PutDescription(description string, reply *string) error {
err := n.Server.MeshManager.SetDescription(description)
err := n.Server.GetMeshManager().SetDescription(description)
if err != nil {
return err
@ -198,8 +210,62 @@ func (n *IpcHandler) PutDescription(description string, reply *string) error {
return nil
}
func (n *IpcHandler) PutAlias(alias string, reply *string) error {
err := n.Server.GetMeshManager().SetAlias(alias)
if err != nil {
return err
}
*reply = fmt.Sprintf("Set alias to %s", alias)
return nil
}
func (n *IpcHandler) PutService(service ipc.PutServiceArgs, reply *string) error {
err := n.Server.GetMeshManager().SetService(service.Service, service.Value)
if err != nil {
return err
}
*reply = "success"
return nil
}
func (n *IpcHandler) DeleteService(service string, reply *string) error {
err := n.Server.GetMeshManager().RemoveService(service)
if err != nil {
return err
}
*reply = "success"
return nil
}
func (n *IpcHandler) GetNode(args ipc.GetNodeArgs, reply *string) error {
node := n.Server.GetMeshManager().GetNode(args.MeshId, args.NodeId)
if node == nil {
*reply = "nil"
return nil
}
queryNode := query.MeshNodeToQueryNode(node)
bytes, err := json.Marshal(queryNode)
if err != nil {
*reply = err.Error()
return nil
}
*reply = string(bytes)
return nil
}
type RobinIpcParams struct {
CtrlServer *ctrlserver.MeshCtrlServer
CtrlServer ctrlserver.CtrlServer
}
func NewRobinIpc(ipcParams RobinIpcParams) IpcHandler {

View File

@ -0,0 +1,73 @@
package robin
import (
"testing"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/mesh"
)
func getRequester() *IpcHandler {
return &IpcHandler{Server: ctrlserver.NewCtrlServerStub()}
}
func TestCreateMeshRepliesMeshId(t *testing.T) {
var reply string
requester := getRequester()
err := requester.CreateMesh(&ipc.NewMeshArgs{
IfName: "wg0",
WgPort: 5000,
Endpoint: "abc.com",
}, &reply)
if err != nil {
t.Error(err)
}
if len(reply) == 0 {
t.Fatalf(`reply should have been returned`)
}
}
func TestListMeshesNoMeshesListsEmpty(t *testing.T) {
var reply ipc.ListMeshReply
requester := getRequester()
err := requester.ListMeshes("", &reply)
if err != nil {
t.Error(err)
}
if len(reply.Meshes) != 0 {
t.Fatalf(`meshes should be empty`)
}
}
func TestListMeshesMeshesNotEmpty(t *testing.T) {
var reply ipc.ListMeshReply
requester := getRequester()
requester.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{
MeshId: "tim123",
DevName: "wg0",
WgPort: 5000,
MeshBytes: make([]byte, 0),
})
err := requester.ListMeshes("", &reply)
if err != nil {
t.Error(err)
}
if len(reply.Meshes) != 1 {
t.Fatalf(`only only mesh exists`)
}
if reply.Meshes[0] != "tim123" {
t.Fatalf(`meshId was %s expected %s`, reply.Meshes[0], "tim123")
}
}

View File

@ -13,29 +13,6 @@ type WgRpc struct {
Server *ctrlserver.MeshCtrlServer
}
func nodeToRpcNode(node ctrlserver.MeshNode) *rpc.MeshNode {
return &rpc.MeshNode{
PublicKey: node.PublicKey,
WgEndpoint: node.WgEndpoint,
WgHost: node.WgHost,
Endpoint: node.HostEndpoint,
}
}
func nodesToRpcNodes(nodes map[string]ctrlserver.MeshNode) []*rpc.MeshNode {
n := len(nodes)
meshNodes := make([]*rpc.MeshNode, n)
var i int = 0
for _, v := range nodes {
meshNodes[i] = nodeToRpcNode(v)
i++
}
return meshNodes
}
func (m *WgRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.GetMeshReply, error) {
mesh := m.Server.MeshManager.GetMesh(request.MeshId)

View File

@ -0,0 +1 @@
package robin

View File

@ -1,7 +1,6 @@
package sync
import (
"errors"
"math/rand"
"sync"
"time"
@ -20,7 +19,7 @@ type Syncer interface {
}
type SyncerImpl struct {
manager *mesh.MeshManager
manager mesh.MeshManager
requester SyncRequester
infectionCount int
syncCount int
@ -30,44 +29,30 @@ type SyncerImpl struct {
// Sync: Sync random nodes
func (s *SyncerImpl) Sync(meshId string) error {
logging.Log.WriteInfof("UPDATING WG CONF")
s.manager.ApplyConfig()
if !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
logging.Log.WriteInfof("No changes for %s", meshId)
return nil
}
theMesh := s.manager.GetMesh(meshId)
logging.Log.WriteInfof("UPDATING WG CONF")
if theMesh == nil {
return errors.New("the provided mesh does not exist")
if s.manager.HasChanges(meshId) {
err := s.manager.ApplyConfig()
if err != nil {
logging.Log.WriteInfof("Failed to update config %w", err)
}
}
snapshot, err := theMesh.GetMesh()
nodeNames := s.manager.GetMesh(meshId).GetPeers()
self, err := s.manager.GetSelf(meshId)
if err != nil {
return err
}
nodes := snapshot.GetNodes()
if len(nodes) <= 1 {
return nil
}
excludedNodes := map[string]struct{}{
s.manager.HostParameters.HostEndpoint: {},
}
meshNodes := lib.MapValuesWithExclude(nodes, excludedNodes)
getNames := func(node mesh.MeshNode) string {
return node.GetHostEndpoint()
}
nodeNames := lib.Map(meshNodes, getNames)
neighbours := s.cluster.GetNeighbours(nodeNames, s.manager.HostParameters.HostEndpoint)
neighbours := s.cluster.GetNeighbours(nodeNames, self.GetHostEndpoint())
randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
for _, node := range randomSubset {
@ -76,9 +61,9 @@ func (s *SyncerImpl) Sync(meshId string) error {
before := time.Now()
if len(meshNodes) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
logging.Log.WriteInfof("Sending to random cluster")
interCluster := s.cluster.GetInterCluster(nodeNames, s.manager.HostParameters.HostEndpoint)
interCluster := s.cluster.GetInterCluster(nodeNames, self.GetHostEndpoint())
randomSubset = append(randomSubset, interCluster)
}
@ -97,17 +82,20 @@ func (s *SyncerImpl) Sync(meshId string) error {
waitGroup.Wait()
s.syncCount++
logging.Log.WriteInfof("SYNC TIME: %v", time.Now().Sub(before))
logging.Log.WriteInfof("SYNC TIME: %v", time.Since(before))
logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount)
s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount)
// Check if any changes have occurred and trigger callbacks
// if changes have occurred.
// return s.manager.GetMonitor().Trigger()
return nil
}
// SyncMeshes: Sync all meshes
func (s *SyncerImpl) SyncMeshes() error {
for meshId, _ := range s.manager.Meshes {
for meshId := range s.manager.GetMeshes() {
err := s.Sync(meshId)
if err != nil {
@ -118,7 +106,7 @@ func (s *SyncerImpl) SyncMeshes() error {
return nil
}
func NewSyncer(m *mesh.MeshManager, conf *conf.WgMeshConfiguration, r SyncRequester) Syncer {
func NewSyncer(m mesh.MeshManager, conf *conf.WgMeshConfiguration, r SyncRequester) Syncer {
cluster, _ := conn.NewConnCluster(conf.ClusterSize)
return &SyncerImpl{
manager: m,

View File

@ -14,7 +14,7 @@ type SyncErrorHandler interface {
// SyncErrorHandlerImpl Is an implementation of the SyncErrorHandler
type SyncErrorHandlerImpl struct {
meshManager *mesh.MeshManager
meshManager mesh.MeshManager
}
func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint string) bool {
@ -24,6 +24,13 @@ func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint stri
return false
}
// self, err := s.meshManager.GetSelf(meshId)
// if err != nil {
// return false
// }
// mesh.DecrementHealth(endpoint, self.GetHostEndpoint())
return true
}
@ -40,6 +47,6 @@ func (s *SyncErrorHandlerImpl) Handle(meshId string, endpoint string, err error)
return false
}
func NewSyncErrorHandler(m *mesh.MeshManager) SyncErrorHandler {
func NewSyncErrorHandler(m mesh.MeshManager) SyncErrorHandler {
return &SyncErrorHandlerImpl{meshManager: m}
}

View File

@ -50,7 +50,6 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, ifName string, port int, endP
err = s.server.MeshManager.AddMesh(&mesh.AddMeshParams{
MeshId: meshId,
DevName: ifName,
WgPort: port,
MeshBytes: reply.Mesh,
})
@ -89,10 +88,12 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
c := rpc.NewSyncServiceClient(client)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
syncTimeOut := s.server.Conf.SyncRate * float64(time.Second)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(syncTimeOut))
defer cancel()
err = syncMesh(mesh, ctx, c)
err = s.syncMesh(mesh, ctx, c)
if err != nil {
return s.handleErr(meshId, endpoint, err)
@ -102,7 +103,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
return nil
}
func syncMesh(mesh mesh.MeshProvider, ctx context.Context, client rpc.SyncServiceClient) error {
func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context, client rpc.SyncServiceClient) error {
stream, err := client.SyncMesh(ctx)
syncer := mesh.GetSyncer()

View File

@ -1,55 +1,18 @@
package sync
import (
"time"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
// SyncScheduler: Loops through all nodes in the mesh and runs a schedule to
// sync each event
type SyncScheduler interface {
Run() error
Stop() error
}
// SyncSchedulerImpl scheduler for sync scheduling
type SyncSchedulerImpl struct {
quit chan struct{}
server *ctrlserver.MeshCtrlServer
syncer Syncer
}
// Run implements SyncScheduler.
func (s *SyncSchedulerImpl) Run() error {
ticker := time.NewTicker(time.Duration(s.server.Conf.SyncRate) * time.Second)
quit := make(chan struct{})
s.quit = quit
for {
select {
case <-ticker.C:
err := s.syncer.SyncMeshes()
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
break
case <-quit:
break
}
func syncFunction(syncer Syncer) lib.TimerFunc {
return func() error {
return syncer.SyncMeshes()
}
}
// Stop implements SyncScheduler.
func (s *SyncSchedulerImpl) Stop() error {
close(s.quit)
return nil
}
func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester) SyncScheduler {
func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester) *lib.Timer {
syncer := NewSyncer(s.MeshManager, s.Conf, syncRequester)
return &SyncSchedulerImpl{server: s, syncer: syncer}
return lib.NewTimer(syncFunction(syncer), int(s.Conf.SyncRate))
}

View File

@ -64,11 +64,11 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error
syncer = mesh.GetSyncer()
} else if meshId != in.MeshId {
return errors.New("Differing MeshIDs")
return errors.New("differing meshids")
}
if syncer == nil {
return errors.New("Syncer should not be nil")
return errors.New("syncer should not be nil")
}
msg, moreMessages := syncer.GenerateMessage()

View File

@ -1,48 +1,14 @@
package timestamp
import (
"time"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
type TimestampScheduler interface {
Run() error
Stop() error
}
type TimeStampSchedulerImpl struct {
meshManager *mesh.MeshManager
updateRate int
quit chan struct{}
}
func (s *TimeStampSchedulerImpl) Run() error {
ticker := time.NewTicker(time.Duration(s.updateRate) * time.Second)
s.quit = make(chan struct{})
for {
select {
case <-ticker.C:
err := s.meshManager.UpdateTimeStamp()
if err != nil {
logging.Log.WriteErrorf("Update Timestamp Error: %s", err.Error())
}
case <-s.quit:
break
}
func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
timerFunc := func() error {
return ctrlServer.MeshManager.UpdateTimeStamp()
}
}
func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) TimestampScheduler {
return &TimeStampSchedulerImpl{meshManager: ctrlServer.MeshManager, updateRate: ctrlServer.Conf.KeepAliveRate}
}
func (s *TimeStampSchedulerImpl) Stop() error {
close(s.quit)
return nil
return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime)
}

15
pkg/wg/stubs.go Normal file
View File

@ -0,0 +1,15 @@
package wg
type WgInterfaceManipulatorStub struct{}
func (i *WgInterfaceManipulatorStub) CreateInterface(port int) (string, error) {
return "", nil
}
func (i *WgInterfaceManipulatorStub) AddAddress(ifName string, addr string) error {
return nil
}
func (i *WgInterfaceManipulatorStub) RemoveInterface(ifName string) error {
return nil
}

View File

@ -8,15 +8,11 @@ func (m *WgError) Error() string {
return m.msg
}
type CreateInterfaceParams struct {
IfName string
Port int
}
type WgInterfaceManipulator interface {
// CreateInterface creates a WireGuard interface
CreateInterface(params *CreateInterfaceParams) error
// Enable interface enables the given interface with
// the IP. It overrides the IP at the interface
EnableInterface(ifName string, ip string) error
CreateInterface(port int) (string, error)
// AddAddress adds an address to the given interface name
AddAddress(ifName string, addr string) error
// RemoveInterface removes the specified interface
RemoveInterface(ifName string) error
}

View File

@ -1,110 +1,99 @@
package wg
import (
"errors"
"crypto"
"crypto/rand"
"encoding/base64"
"fmt"
"net"
"os/exec"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// createInterface uses ip link to create an interface. If the interface exists
// it returns an error
func createInterface(ifName string) error {
_, err := net.InterfaceByName(ifName)
if err == nil {
err = flushInterface(ifName)
return err
}
// Check if the interface exists
cmd := exec.Command("/usr/bin/ip", "link", "add", "dev", ifName, "type", "wireguard")
if err := cmd.Run(); err != nil {
return err
}
return nil
}
type WgInterfaceManipulatorImpl struct {
client *wgctrl.Client
}
func (m *WgInterfaceManipulatorImpl) CreateInterface(params *CreateInterfaceParams) error {
err := createInterface(params.IfName)
const hashLength = 6
// CreateInterface creates a WireGuard interface
func (m *WgInterfaceManipulatorImpl) CreateInterface(port int) (string, error) {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return err
return "", fmt.Errorf("failed to access link: %w", err)
}
defer rtnl.Close()
randomBuf := make([]byte, 32)
_, err = rand.Read(randomBuf)
if err != nil {
return "", err
}
md5 := crypto.MD5.New().Sum(randomBuf)
md5Str := fmt.Sprintf("wg%s", base64.StdEncoding.EncodeToString(md5)[:hashLength])
err = rtnl.CreateLink(md5Str)
if err != nil {
return "", fmt.Errorf("failed to create link: %w", err)
}
privateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
return err
return "", fmt.Errorf("failed to create private key: %w", err)
}
var cfg wgtypes.Config = wgtypes.Config{
PrivateKey: &privateKey,
ListenPort: &params.Port,
ListenPort: &port,
}
m.client.ConfigureDevice(params.IfName, cfg)
return nil
err = m.client.ConfigureDevice(md5Str, cfg)
if err != nil {
m.RemoveInterface(md5Str)
return "", fmt.Errorf("failed to configure dev: %w", err)
}
logging.Log.WriteInfof("ip link set up dev %s type wireguard", md5Str)
return md5Str, nil
}
// flushInterface flushes the specified interface
func flushInterface(ifName string) error {
_, err := net.InterfaceByName(ifName)
// Add an address to the given interface
func (m *WgInterfaceManipulatorImpl) AddAddress(ifName string, addr string) error {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return &WgError{msg: fmt.Sprintf("Interface %s does not exist cannot flush", ifName)}
return fmt.Errorf("failed to create config: %w", err)
}
defer rtnl.Close()
err = rtnl.AddAddress(ifName, addr)
if err != nil {
err = fmt.Errorf("failed to add address: %w", err)
}
cmd := exec.Command("/usr/bin/ip", "addr", "flush", "dev", ifName)
if err := cmd.Run(); err != nil {
logging.Log.WriteErrorf(fmt.Sprintf("%s error flushing interface %s", err.Error(), ifName))
return &WgError{msg: fmt.Sprintf("Failed to flush interface %s", ifName)}
}
return nil
return err
}
// EnableInterface flushes the interface and sets the ip address of the
// interface
func (m *WgInterfaceManipulatorImpl) EnableInterface(ifName string, ip string) error {
if len(ifName) == 0 {
return errors.New("ifName not provided")
}
err := flushInterface(ifName)
// RemoveInterface implements WgInterfaceManipulator.
func (*WgInterfaceManipulatorImpl) RemoveInterface(ifName string) error {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return err
return fmt.Errorf("failed to create config: %w", err)
}
defer rtnl.Close()
cmd := exec.Command("/usr/bin/ip", "link", "set", "up", "dev", ifName)
if err := cmd.Run(); err != nil {
return err
}
hostIp, _, err := net.ParseCIDR(ip)
if err != nil {
return err
}
cmd = exec.Command("/usr/bin/ip", "addr", "add", hostIp.String()+"/64", "dev", ifName)
if err := cmd.Run(); err != nil {
return err
}
return nil
return rtnl.DeleteLink(ifName)
}
func NewWgInterfaceManipulator(client *wgctrl.Client) WgInterfaceManipulator {

View File

@ -0,0 +1,92 @@
// Package to convert an IPV6 addres into 8 words
package what8words
import (
"bufio"
"bytes"
"fmt"
"net"
"os"
"strings"
)
type What8Words struct {
words []string
}
// Convert implements What8Words.
func (w *What8Words) Convert(ipStr string) (string, error) {
ip, ipNet, err := net.ParseCIDR(ipStr)
if err != nil {
return "", err
}
ip16 := ip.To16()
if ip16 == nil {
return "", fmt.Errorf("cannot convert ip to 16 representation")
}
representation := make([]string, 7)
for i := 2; i <= net.IPv6len-2; i += 2 {
word1 := w.words[ip16[i]]
word2 := w.words[ip16[i+1]]
representation[i/2-1] = fmt.Sprintf("%s-%s", word1, word2)
}
prefixSize, _ := ipNet.Mask.Size()
return strings.Join(representation[:prefixSize/16-1], "."), nil
}
// Convert implements What8Words.
func (w *What8Words) ConvertIdentifier(ipStr string) (string, error) {
ip, err := w.Convert(ipStr)
if err != nil {
return "", err
}
constituents := strings.Split(ip, ".")
return strings.Join(constituents[3:], "."), nil
}
func NewWhat8Words(pathToWords string) (*What8Words, error) {
words, err := ReadWords(pathToWords)
if err != nil {
return nil, err
}
return &What8Words{words: words}, nil
}
// ReadWords reads the what 8 words txt file
func ReadWords(wordFile string) ([]string, error) {
f, err := os.ReadFile(wordFile)
if err != nil {
return nil, err
}
words := make([]string, 257)
reader := bufio.NewScanner(bytes.NewReader(f))
counter := 0
for reader.Scan() && counter <= len(words) {
text := reader.Text()
words[counter] = text
counter++
if reader.Err() != nil {
return nil, reader.Err()
}
}
return words, nil
}

1
smegmesh-web Submodule

Submodule smegmesh-web added at c1128bcd98