1
0
forked from extern/smegmesh

Compare commits

...

96 Commits

Author SHA1 Message Date
8aab4e99d8 59-error-when-peer-not-selected
In the CLI when the peer is not selected
as the type throwing an error stating
either client or peer must be selected
2023-12-22 19:08:20 +00:00
cf4be1ccab Merge pull request #58 from tim-beatham/bugfix-pull-only
Bugfix pull only
2023-12-22 18:49:09 +00:00
6ed32f3a79 bugfix-push-pull
Organised groups as a tree so that there
isn't a limit to dissemination
2023-12-19 00:50:17 +00:00
b6199892f0 bugfix-pull-only
Bugfix with inter-cluster communication pull not working
2023-12-18 22:17:46 +00:00
ad22f04b0d bugfix-pull-only
After certain period of time if no changes have
occurred then pull
2023-12-18 20:45:56 +00:00
092d9a4af5 checking-latency-for-pull-only 2023-12-17 09:44:32 +00:00
19abf712a6 Fixing bug with nodes being removed 2023-12-12 12:45:41 +00:00
b296e1f45a Merge pull request #57 from tim-beatham/55-cli-option-for-peer-type
55-cli-optionifor-peer-type
2023-12-12 12:00:42 +00:00
2dc89d171b 55-cli-optionifor-peer-type
- Ability to specify WireGuard keepalive in the CLI formatter
- Ability to specify publicly routeable endpoint
- Ability to specify whether to advetise routes into the mesh,
and whether to advertise default routes.
2023-12-12 11:58:47 +00:00
13bea10638 main - bugfix
- Nodes not being removed when deleted because when node gossips again
  it is readded.
- Keep track of highest vector clock we have removed and used this as a
  mark for determining if something is stale.
2023-12-11 11:09:02 +00:00
3222d7e388 main - adding WireGuard stats to JSON objects
- Adding WireGuard stats through to IPC calls so that they can be used
by the API
2023-12-11 09:55:25 +00:00
1789d203f6 main - fix default routing being deleted
Default route keeps fluctuating on configuration
update.
2023-12-10 23:35:00 +00:00
a5074a536e main - BUGFIX
- segfault BUGFIX
2023-12-10 22:31:24 +00:00
acb90a5679 main - go.sum should be tracked into the git
- go.sum should be contained in the git history
2023-12-10 22:11:09 +00:00
27ec23f133 Merge pull request #54 from tim-beatham/53-run-commands-pre-up-and-post-down
53-run-commands-pre-up-and-post-down
2023-12-10 19:22:59 +00:00
fe14f63217 53-run-commands-pre-up-and-post-down
- Ability to run a command pre up and post down
- Ability to be a client in one mesh and a peer in the other
- Added dev card to specify different sync rate, keepalive rate per
  mesh.
2023-12-10 19:21:54 +00:00
4a8a39601f Merge pull request #52 from tim-beatham/51-bufix-not-removing-when-withdrawn
51-bugfix-routes-not-removing-when-withdrawn
2023-12-10 15:13:57 +00:00
1e263cc6a8 51-bugfix-routes-not-removing-when-withdrawn
- Routes are not being removed despite being withdrawn from the
configuration.
- Best path routes are not shared across interfaces
- Bug in consistent hashing wrong parameter passed caused by
refactorings.
2023-12-10 15:10:36 +00:00
dae9cd31a1 Merge pull request #50 from tim-beatham/50-give-client-ability-to-bridge-meshes
50-give-client-ability-to-bridge-meshes
2023-12-08 23:58:32 +00:00
f855f53fbf 50-give-client-ability-to-bridge-meshes
Client can act as a route bridging meshes. Cient send keepalives
to all of it's peers in the different meshes act as a bridge between
the meshes
2023-12-08 23:56:07 +00:00
52feb5767b Merge pull request #48 from tim-beatham/47-default-routing
47 default routing
2023-12-08 20:03:45 +00:00
815c4484ee 47-default-routing
Implemented default routing and improved size of gossip. Using 64 bit
hash funciton to identify vector.
2023-12-08 20:02:57 +00:00
0058c9f4c9 47-default-routing
Implementing default routing so that all traffic goes out of an
exit point.
2023-12-08 11:49:24 +00:00
92c0805275 Merge pull request #46 from tim-beatham/45-use-statistical-testing
45 use statistical testing
2023-12-07 18:20:25 +00:00
661fb0d54c 45-use-statistical-testing
Keepalive is based on per mesh and not per node.
Using total ordering mechanism similar to paxos to elect a leader
if leader doesn't update it's timestamp within 3 * keepAlive then
give the leader a gravestone and elect the next leader.
Leader is bassed on lexicographically ordered public key.
2023-12-07 18:18:13 +00:00
64885f1055 45-use-statistical-testing
Using statistical testing to test whether the node has failed.
2023-12-07 01:44:54 +00:00
2169f7796f Merge pull request #44 from tim-beatham/43-gravestones
43-use-gravestones
2023-12-06 22:46:05 +00:00
a3ceff019d 43-use-gravestones
Change of approach from keepalive to a noiseless protocol
2023-12-06 22:45:04 +00:00
b78d96986c Merge pull request #42 from tim-beatham/41-bugfix-fluctuating-ips
41 bugfix fluctuating ips
2023-12-06 14:37:14 +00:00
1b18d89c9f 41-bugfix-fluctuating-ips
Fluctuating ips creating hub and spoke.
2023-12-05 02:00:16 +00:00
245a2c5f58 41-bugfix-fluctuating-ips
If the node is a peer then add the client in the WG
configuration.
2023-12-04 17:40:24 +00:00
c40f7510b8 41-bugfix-fluctuating-ips
IPs of clients fluctuating because there isn't a strict order on
clients. Client's need to be processed before the peers.
2023-12-04 17:32:50 +00:00
78d748770c BUGIX Hash client by public key 2023-12-04 17:13:51 +00:00
0ff2a8eef9 BUGFIX: Allowed IPs fluctuating 2023-12-04 17:11:37 +00:00
fd7bd80485 BUGFIX
Don't get device each time it is an expensive operation.
2023-12-04 16:40:15 +00:00
3ef1b68ba5 BUGFIX: Hashing datastore to work out changes
Changed hashing implementation to work out if there are changes
in the data store
2023-11-30 15:58:26 +00:00
b9ba836ae3 Merge pull request #40 from tim-beatham/39-implement-two-phase-map
39-implement-two-phase-map
2023-11-30 02:03:36 +00:00
650901aba1 39-implement-two-phase-map
Implemented my own two phase map based on vector clocks
2023-11-30 02:02:38 +00:00
a82eab0686 Bugfix
Added replace peers so that deleted nodes are automatically removed
2023-11-28 14:43:55 +00:00
32e7e4c7df main
Bugfix. Fixed issue where consistent hashing was not working.
2023-11-28 14:42:09 +00:00
1fae0a6c2c Merge pull request #37 from tim-beatham/36-add-route-path-into-route-object
36-add-route-path-into-route-object
2023-11-27 21:03:56 +00:00
d8e156f13f 36-add-route-path-into-route-object
Added the route path into the route object so that we can
see what meshes packets are routed across.
2023-11-27 18:55:41 +00:00
3fca49a1c9 Merge pull request #35 from tim-beatham/34-fix-routing
34 fix routing
2023-11-27 16:05:06 +00:00
a2517a1e72 34-fix-routing
- Added mesh-to-mesh routing of hop count > 1
- If there is a tie-breaker with respect to the hop-count use consistent
hashing to determine the route to take based on the public key.
2023-11-27 15:56:30 +00:00
aef8b59f22 32-fix-routing
Flooding routes into other meshes a bit like BGP.
2023-11-25 03:15:58 +00:00
4030d17b41 Fixed routing issue 2023-11-24 17:49:06 +00:00
73db65660b Merge pull request #33 from tim-beatham/32-incorporate-dns
32-incorporate-dns
2023-11-24 15:05:40 +00:00
d1a74a7b95 32-incorporate-dns
Incorporated a DNS server. A DNS server can be run to resolve host
names.
2023-11-24 15:04:07 +00:00
f28ed8260d Merge pull request #30 from tim-beatham/29-only-ping-clients-who-have-updated-their-config
29-only-ping-clients-who-have-updated-their-config
2023-11-24 12:39:14 +00:00
2c406718df 29-only-ping-clients-who-have-updated-their-config
Only consider clients who have updated their config when synchronising
with peers. Consider a dead time where we don't have a handshake and
a prune time when we remove them from the WireGuard configuration.
2023-11-24 12:37:54 +00:00
11b003b549 Merge pull request #28 from tim-beatham/27-remove-client-grpc-endpoint
27-remove-client-grpc-endpoint
2023-11-24 12:08:42 +00:00
7be11dbaa3 27-remove-client-grpc-endpoint
Removed a client's grpc endpoint value. Client's aren't publicly
available so there is no need for a client's gRPC endpoint.
Also changed a node ID's to their public key. A node id's public
address is an issue for mobility of clients as their endpoint
is subject to change
2023-11-24 12:07:03 +00:00
e7ac8c5542 Only updating WireGuard config if node exists 2023-11-22 13:08:02 +00:00
09c64c4628 Fixed container file 2023-11-22 12:45:01 +00:00
2c4f18f52b Merge pull request #26 from tim-beatham/25-modify-code-to-use-public-api
25-modify-code-to-use-public-api
2023-11-22 10:42:48 +00:00
4c54022f63 25-modify-code-to-use-public-api
Modify the code to use a public IP address by default if none is
specified
2023-11-22 10:41:54 +00:00
bf0724f6e5 Merge pull request #24 from tim-beatham/24-keepalive-holepunch
24 keepalive holepunch
2023-11-21 21:28:16 +00:00
624bd6e921 24-keepalive
Persistent keep alive working
2023-11-21 21:26:31 +00:00
7b939e0468 24-keepalive-holepunch
Added the ability to hole punch NAT
2023-11-21 20:42:43 +00:00
6e201ebaf5 24-keepalive-holepunch
Nodes acting as peers and nodes acting as clients
2023-11-21 16:42:49 +00:00
06542da03c main
Fixed problems with timestamp not updating
2023-11-21 13:31:34 +00:00
0d63cd6624 main
Adding words.txt for what words
2023-11-20 18:12:58 +00:00
f13319cfc1 Merge pull request #22 from tim-beatham/21-phonetic-words-ipv6
21 phonetic words ipv6
2023-11-20 18:08:49 +00:00
95f4495b0b 21-phonetic-words-ipv6
Simple what 8 words implementation
2023-11-20 18:07:52 +00:00
330fa74ef4 IPv6 What 8 Words
what 8 words for ipv6 started
2023-11-20 15:22:32 +00:00
3e5b57e41f Merge pull request #20 from tim-beatham/19-hash-wg-interface
Hashing the WireGuard interface
2023-11-20 13:04:19 +00:00
b179cd3cf4 Hashing the WireGuard interface
Hashing the interface and using ephmeral ports so that the admin doesn't
choose an interface and port combination. An administrator can alteranatively
decide to provide port but this isn't critical.
2023-11-20 13:03:42 +00:00
8f211aa116 Merge pull request #18 from tim-beatham/26-performance-testing
Stubbing out WireGuard components
2023-11-20 11:29:37 +00:00
388153e706 Stubbing out WireGuard components
Stubbing our WireGuard components so that I can use docker/podman
network_mode=host. This is much more efficient than the docker/podman
userspace network.
2023-11-20 11:28:12 +00:00
023565d985 Merge pull request #17 from tim-beatham/25-ability-to-aliases
25 ability to aliases
2023-11-17 22:20:57 +00:00
36c264b38e 25-ability-aliases
Fixed unit tests failing
2023-11-17 22:18:53 +00:00
68db795f47 Ability to specify aliases
Ability to specify aliases that automatically append to /etc/hosts
2023-11-17 22:13:51 +00:00
f6160fe138 Adding aliases that automatically gets added 2023-11-17 19:13:20 +00:00
2c5289afb0 Merge pull request #16 from tim-beatham/15-add-rest-api
Developed a rest API
2023-11-15 12:57:05 +00:00
7199d07a76 Added smegmesh submodule 2023-11-13 10:46:52 +00:00
5f176e731f Developed a rest API 2023-11-13 10:44:14 +00:00
44f119b45c Updating examples 2023-11-08 09:19:24 +00:00
5215d5d54d Merge pull request #14 from tim-beatham/13-netlink-api
Removed interface manipulation via os.Exec into
2023-11-07 19:53:39 +00:00
1a864b7c80 Removed interface manipulation via os.Exec into
rtnetlink calls
2023-11-07 19:48:53 +00:00
4c19ebd81f Merge pull request #12 from tim-beatham/11-health-system
11 health system
2023-11-06 13:40:04 +00:00
acbeb689b5 Prune nodes if they exceed their timeout time 2023-11-06 13:37:28 +00:00
bc6cd4fdd5 Modified syncer 2023-11-06 10:05:23 +00:00
c88012cf71 Added health system to count how many times a node
fails to conenct.
2023-11-06 09:54:06 +00:00
4dc85f3861 Merge pull request #10 from tim-beatham/9-add-ci-support
9 add ci support
2023-11-05 18:07:52 +00:00
ef614f5961 Add cert dependencies 2023-11-05 18:06:24 +00:00
9454d62417 Adding stubs and writing tests 2023-11-05 18:03:58 +00:00
bb07d35dcb Unit testing the automerge library and lib functions 2023-11-05 12:13:40 +00:00
76dda2cf6f Update go.mod 2023-11-05 10:54:38 +00:00
1b286dd3c1 Update go.yml 2023-11-05 10:53:57 +00:00
2d45c2d298 Run go mod tidy in workflow 2023-11-05 10:51:24 +00:00
900c67a121 Update go.mod 2023-11-05 10:49:18 +00:00
b2fa08a642 Reverted go version 2023-11-05 10:48:35 +00:00
a4e9a5cd0f Updated go version in workflow 2023-11-05 10:47:10 +00:00
275eb423fb Create GitHub hosted test runner go.yml 2023-11-05 10:45:39 +00:00
d17dce3b1e Added clustering and clean up 2023-11-03 15:26:09 +00:00
e2c6db3a4f Merge pull request #8 from tim-beatham/7-create-rotating-window-of-connections
Implemented clustering betweeen nodes
2023-11-03 15:25:30 +00:00
93 changed files with 7672 additions and 1143 deletions

31
.github/workflows/go.yml vendored Normal file
View File

@ -0,0 +1,31 @@
# This workflow will build a golang project
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go
name: Go
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: '1.21'
- name: Tidy
run: go mod tidy
- name: Build
run: go build -v ./...
- name: Test
run: go test -v ./...

3
.gitmodules vendored Normal file
View File

@ -0,0 +1,3 @@
[submodule "smegmesh-web"]
path = smegmesh-web
url = git@github.com:tim-beatham/smegmesh-web.git

12
Containerfile Normal file
View File

@ -0,0 +1,12 @@
FROM docker.io/library/golang:bookworm
COPY ./ /wgmesh
RUN apt-get update && apt-get install -y \
wireguard \
wireguard-tools \
iproute2 \
iputils-ping \
tmux \
vim
WORKDIR /wgmesh
RUN go mod tidy
RUN go build -o /usr/local/bin ./...

1
Dockerfile Symbolic link
View File

@ -0,0 +1 @@
Containerfile

19
cmd/api/main.go Normal file
View File

@ -0,0 +1,19 @@
package main
import (
"log"
"github.com/tim-beatham/wgmesh/pkg/api"
)
func main() {
apiServer, err := api.NewSmegServer(api.ApiServerConf{
WordsFile: "./cmd/api/words.txt",
})
if err != nil {
log.Fatal(err.Error())
}
apiServer.Run(":8080")
}

257
cmd/api/words.txt Normal file
View File

@ -0,0 +1,257 @@
be
to
of
it
in
we
do
he
on
go
at
if
or
up
by
hi
the
and
you
not
for
but
say
get
she
one
all
can
out
who
now
see
way
how
lot
yes
use
any
day
try
put
let
why
new
off
big
too
ask
man
bit
end
may
own
run
pay
job
old
kid
bad
few
ago
far
buy
set
guy
car
sit
war
win
yet
top
law
cut
low
die
eat
age
hit
air
add
boy
act
tax
oil
eye
son
key
fun
dad
dog
arm
fly
box
gas
lie
hot
gun
per
art
red
fit
bed
fan
mix
mom
sex
bus
fix
bar
lay
ice
bet
bag
due
aid
tie
leg
ban
odd
cup
dry
cry
rid
pop
sir
cat
map
sad
sea
aim
sun
fat
row
egg
tea
god
wed
tip
ear
hat
net
ill
dig
fee
mad
gap
nor
bid
era
toy
sky
bin
owe
wet
tap
pro
ski
cow
pen
van
web
pot
sum
cap
log
pub
pig
joy
raw
rat
via
lip
two
six
ten
lab
ton
mid
bat
hip
gut
sin
non
rub
sub
par
pre
ray
cue
dye
fin
ion
neo
hey
wow
mum
bye
aye
jet
sue
pet
flu
cop
ooh
rip
spy
pie
bug
gum
wan
rap
nut
beg
pin
pit
jam
tag
fax
vet
fry
pad
lad
mud
bay
con
pan
gee
toe
dip
shy
gym
zoo
fox
bow
tin
hop
wee
kit
opt
vow
sew
cab
bee
rob
rig
yep
ego
rib
nod
hug
lap
ash
hum
dam
bum
yen
jar

18
cmd/dns/main.go Normal file
View File

@ -0,0 +1,18 @@
package main
import (
"log"
smegdns "github.com/tim-beatham/wgmesh/pkg/dns"
)
func main() {
server, err := smegdns.NewDns(53)
if err != nil {
log.Fatal(err.Error())
}
defer server.Close()
server.Listen()
}

View File

@ -4,8 +4,6 @@ import (
"fmt"
ipcRpc "net/rpc"
"os"
"strings"
"time"
"github.com/akamensky/argparse"
"github.com/tim-beatham/wgmesh/pkg/ipc"
@ -16,20 +14,19 @@ const SockAddr = "/tmp/wgmesh_ipc.sock"
type CreateMeshParams struct {
Client *ipcRpc.Client
IfName string
WgPort int
Endpoint string
WgArgs ipc.WireGuardArgs
AdvertiseRoutes bool
AdvertiseDefault bool
}
func createMesh(args *CreateMeshParams) string {
func createMesh(params *CreateMeshParams) string {
var reply string
newMeshParams := ipc.NewMeshArgs{
IfName: args.IfName,
WgPort: args.WgPort,
Endpoint: args.Endpoint,
WgArgs: params.WgArgs,
}
err := args.Client.Call("IpcHandler.CreateMesh", &newMeshParams, &reply)
err := params.Client.Call("IpcHandler.CreateMesh", &newMeshParams, &reply)
if err != nil {
return err.Error()
@ -57,9 +54,10 @@ type JoinMeshParams struct {
Client *ipcRpc.Client
MeshId string
IpAddress string
IfName string
WgPort int
Endpoint string
WgArgs ipc.WireGuardArgs
AdvertiseRoutes bool
AdvertiseDefault bool
}
func joinMesh(params *JoinMeshParams) string {
@ -68,8 +66,7 @@ func joinMesh(params *JoinMeshParams) string {
args := ipc.JoinMeshArgs{
MeshId: params.MeshId,
IpAdress: params.IpAddress,
IfName: params.IfName,
Port: params.WgPort,
WgArgs: params.WgArgs,
}
err := params.Client.Call("IpcHandler.JoinMesh", &args, &reply)
@ -81,30 +78,6 @@ func joinMesh(params *JoinMeshParams) string {
return reply
}
func getMesh(client *ipcRpc.Client, meshId string) {
reply := new(ipc.GetMeshReply)
err := client.Call("IpcHandler.GetMesh", &meshId, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
for _, node := range reply.Nodes {
fmt.Println("Public Key: " + node.PublicKey)
fmt.Println("Control Endpoint: " + node.HostEndpoint)
fmt.Println("WireGuard Endpoint: " + node.WgEndpoint)
fmt.Println("Wg IP: " + node.WgHost)
fmt.Println(fmt.Sprintf("Timestamp: %s", time.Unix(node.Timestamp, 0).String()))
advertiseRoutes := strings.Join(node.Routes, ",")
fmt.Printf("Routes: %s\n", advertiseRoutes)
fmt.Println("---")
}
}
func leaveMesh(client *ipcRpc.Client, meshId string) {
var reply string
@ -118,19 +91,6 @@ func leaveMesh(client *ipcRpc.Client, meshId string) {
fmt.Println(reply)
}
func enableInterface(client *ipcRpc.Client, meshId string) {
var reply string
err := client.Call("IpcHandler.EnableInterface", &meshId, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func getGraph(client *ipcRpc.Client, meshId string) {
var reply string
@ -171,42 +131,175 @@ func putDescription(client *ipcRpc.Client, description string) {
fmt.Println(reply)
}
// putAlias: puts an alias for the node
func putAlias(client *ipcRpc.Client, alias string) {
var reply string
err := client.Call("IpcHandler.PutAlias", &alias, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func setService(client *ipcRpc.Client, service, value string) {
var reply string
serviceArgs := &ipc.PutServiceArgs{
Service: service,
Value: value,
}
err := client.Call("IpcHandler.PutService", serviceArgs, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func deleteService(client *ipcRpc.Client, service string) {
var reply string
err := client.Call("IpcHandler.PutService", &service, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func main() {
parser := argparse.NewParser("wg-mesh",
"wg-mesh Manipulate WireGuard meshes")
"wg-mesh Manipulate WireGuard mesh networks")
newMeshCmd := parser.NewCommand("new-mesh", "Create a new mesh")
listMeshCmd := parser.NewCommand("list-meshes", "List meshes the node is connected to")
joinMeshCmd := parser.NewCommand("join-mesh", "Join a mesh network")
// getMeshCmd := parser.NewCommand("get-mesh", "Get a mesh network")
enableInterfaceCmd := parser.NewCommand("enable-interface", "Enable A Specific Mesh Interface")
getGraphCmd := parser.NewCommand("get-graph", "Convert a mesh into DOT format")
leaveMeshCmd := parser.NewCommand("leave-mesh", "Leave a mesh network")
queryMeshCmd := parser.NewCommand("query-mesh", "Query a mesh network using JMESPath")
putDescriptionCmd := parser.NewCommand("put-description", "Place a description for the node")
putAliasCmd := parser.NewCommand("put-alias", "Place an alias for the node")
setServiceCmd := parser.NewCommand("set-service", "Place a service into your advertisements")
deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements")
var newMeshIfName *string = newMeshCmd.String("f", "ifname", &argparse.Options{Required: true})
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{})
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{
Default: 0,
Help: "WireGuard port to use to the interface. A default of 0 uses an unused ephmeral port.",
})
var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true})
var joinMeshIfName *string = joinMeshCmd.String("f", "ifname", &argparse.Options{Required: true})
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{})
var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{
Help: "Publicly routeable endpoint to advertise within the mesh",
})
// var getMeshId *string = getMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var newMeshRole *string = newMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{
Default: "peer",
Help: "Role in the mesh network. A value of peer means that the node is publicly routeable and thus considered" +
" in the gossip protocol. Client means that the node is not publicly routeable and is not a candidate in the gossip" +
" protocol",
})
var newMeshKeepAliveWg *int = newMeshCmd.Int("k", "KeepAliveWg", &argparse.Options{
Default: 0,
Help: "WireGuard KeepAlive value for NAT traversal and firewall holepunching",
})
var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true})
var newMeshAdvertiseRoutes *bool = newMeshCmd.Flag("a", "advertise", &argparse.Options{
Help: "Advertise routes to other mesh network into the mesh",
})
var getGraphMeshId *string = getGraphCmd.String("m", "mesh", &argparse.Options{Required: true})
var newMeshAdvertiseDefaults *bool = newMeshCmd.Flag("d", "defaults", &argparse.Options{
Help: "Advertise ::/0 into the mesh network",
})
var leaveMeshMeshId *string = leaveMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var joinMeshId *string = joinMeshCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
var queryMeshMeshId *string = queryMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var queryMeshQuery *string = queryMeshCmd.String("q", "query", &argparse.Options{Required: true})
var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{
Required: true,
Help: "IP address of the bootstrapping node to join through",
})
var description *string = putDescriptionCmd.String("d", "description", &argparse.Options{Required: true})
var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{
Help: "Publicly routeable endpoint to advertise within the mesh",
})
var joinMeshRole *string = joinMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{
Default: "peer",
Help: "Role in the mesh network. A value of peer means that the node is publicly routeable and thus considered" +
" in the gossip protocol. Client means that the node is not publicly routeable and is not a candidate in the gossip" +
" protocol",
})
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{
Default: 0,
Help: "WireGuard port to use to the interface. A default of 0 uses an unused ephmeral port.",
})
var joinMeshKeepAliveWg *int = joinMeshCmd.Int("k", "KeepAliveWg", &argparse.Options{
Default: 0,
Help: "WireGuard KeepAlive value for NAT traversal and firewall ho;lepunching",
})
var joinMeshAdvertiseRoutes *bool = joinMeshCmd.Flag("a", "advertise", &argparse.Options{
Help: "Advertise routes to other mesh network into the mesh",
})
var joinMeshAdvertiseDefaults *bool = joinMeshCmd.Flag("d", "defaults", &argparse.Options{
Help: "Advertise ::/0 into the mesh network",
})
var getGraphMeshId *string = getGraphCmd.String("m", "mesh", &argparse.Options{
Required: true,
Help: "MeshID of the graph to get",
})
var leaveMeshMeshId *string = leaveMeshCmd.String("m", "mesh", &argparse.Options{
Required: true,
Help: "MeshID of the mesh to leave",
})
var queryMeshMeshId *string = queryMeshCmd.String("m", "mesh", &argparse.Options{
Required: true,
Help: "MeshID of the mesh to query",
})
var queryMeshQuery *string = queryMeshCmd.String("q", "query", &argparse.Options{
Required: true,
Help: "JMESPath Query Of The Mesh Network To Query",
})
var description *string = putDescriptionCmd.String("d", "description", &argparse.Options{
Required: true,
Help: "Description of the node in the mesh",
})
var alias *string = putAliasCmd.String("a", "alias", &argparse.Options{
Required: true,
Help: "Alias of the node to set can be used in DNS to lookup an IP address",
})
var serviceKey *string = setServiceCmd.String("s", "service", &argparse.Options{
Required: true,
Help: "Key of the service to advertise in the mesh network",
})
var serviceValue *string = setServiceCmd.String("v", "value", &argparse.Options{
Required: true,
Help: "Value of the service to advertise in the mesh network",
})
var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{
Required: true,
Help: "Key of the service to remove",
})
err := parser.Parse(os.Args)
@ -224,9 +317,15 @@ func main() {
if newMeshCmd.Happened() {
fmt.Println(createMesh(&CreateMeshParams{
Client: client,
IfName: *newMeshIfName,
WgPort: *newMeshPort,
Endpoint: *newMeshEndpoint,
WgArgs: ipc.WireGuardArgs{
Endpoint: *newMeshEndpoint,
Role: *newMeshRole,
WgPort: *newMeshPort,
KeepAliveWg: *newMeshKeepAliveWg,
AdvertiseDefaultRoute: *newMeshAdvertiseDefaults,
AdvertiseRoutes: *newMeshAdvertiseRoutes,
},
}))
}
@ -237,26 +336,24 @@ func main() {
if joinMeshCmd.Happened() {
fmt.Println(joinMesh(&JoinMeshParams{
Client: client,
IfName: *joinMeshIfName,
WgPort: *joinMeshPort,
IpAddress: *joinMeshIpAddress,
MeshId: *joinMeshId,
Endpoint: *joinMeshEndpoint,
WgArgs: ipc.WireGuardArgs{
Endpoint: *joinMeshEndpoint,
Role: *joinMeshRole,
WgPort: *joinMeshPort,
KeepAliveWg: *joinMeshKeepAliveWg,
AdvertiseDefaultRoute: *joinMeshAdvertiseDefaults,
AdvertiseRoutes: *joinMeshAdvertiseRoutes,
},
}))
}
// if getMeshCmd.Happened() {
// getMesh(client, *getMeshId)
// }
if getGraphCmd.Happened() {
getGraph(client, *getGraphMeshId)
}
if enableInterfaceCmd.Happened() {
enableInterface(client, *enableInterfaceMeshId)
}
if leaveMeshCmd.Happened() {
leaveMesh(client, *leaveMeshMeshId)
}
@ -268,4 +365,16 @@ func main() {
if putDescriptionCmd.Happened() {
putDescription(client, *description)
}
if putAliasCmd.Happened() {
putAlias(client, *alias)
}
if setServiceCmd.Happened() {
setService(client, *serviceKey, *serviceValue)
}
if deleteServiceCmd.Happened() {
deleteService(client, *deleteServiceKey)
}
}

View File

@ -2,6 +2,7 @@ certificatePath: "/wgmesh/cert/cert.pem"
privateKeyPath: "/wgmesh/cert/priv.pem"
caCertificatePath: "/wgmesh/cert/cacert.pem"
skipCertVerification: true
timeout: 5
gRPCPort: "21906"
advertiseRoutes: true
clusterSize: 32
@ -9,4 +10,5 @@ syncRate: 1
interClusterChance: 0.15
branchRate: 3
infectionCount: 3
keepAliveRate: 60
keepAliveTime: 10
pruneTime: 20

View File

@ -1,7 +1,8 @@
package main
import (
"log"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
@ -9,22 +10,22 @@ import (
ctrlserver "github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/middleware"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/robin"
"github.com/tim-beatham/wgmesh/pkg/sync"
"github.com/tim-beatham/wgmesh/pkg/timestamp"
timer "github.com/tim-beatham/wgmesh/pkg/timers"
"golang.zx2c4.com/wireguard/wgctrl"
)
func main() {
if len(os.Args) != 2 {
logging.Log.WriteErrorf("Need to provide configuration.yaml")
logging.Log.WriteErrorf("Did not provide configuration")
return
}
conf, err := conf.ParseConfiguration(os.Args[1])
conf, err := conf.ParseDaemonConfiguration(os.Args[1])
if err != nil {
logging.Log.WriteInfof("Could not parse configuration")
logging.Log.WriteErrorf("Could not parse configuration: %s", err.Error())
return
}
@ -35,24 +36,34 @@ func main() {
return
}
if conf.Profile {
go func() {
http.ListenAndServe("localhost:6060", nil)
}()
}
var robinRpc robin.WgRpc
var robinIpc robin.IpcHandler
var authProvider middleware.AuthRpcProvider
var syncProvider sync.SyncServiceImpl
var syncRequester sync.SyncRequester
var syncer sync.Syncer
ctrlServerParams := ctrlserver.NewCtrlServerParams{
Conf: conf,
AuthProvider: &authProvider,
CtrlProvider: &robinRpc,
SyncProvider: &syncProvider,
Client: client,
OnDelete: func(mp mesh.MeshProvider) {
syncer.SyncMeshes()
},
}
ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams)
syncProvider.Server = ctrlServer
syncRequester := sync.NewSyncRequester(ctrlServer)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester)
timestampScheduler := timestamp.NewTimestampScheduler(ctrlServer)
syncRequester = sync.NewSyncRequester(ctrlServer)
syncer = sync.NewSyncer(ctrlServer.MeshManager, conf, syncRequester)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, syncer)
keepAlive := timer.NewTimestampScheduler(ctrlServer)
robinIpcParams := robin.RobinIpcParams{
CtrlServer: ctrlServer,
@ -66,16 +77,16 @@ func main() {
return
}
log.Println("Running IPC Handler")
logging.Log.WriteInfof("Running IPC Handler")
go ipc.RunIpcHandler(&robinIpc)
go syncScheduler.Run()
go timestampScheduler.Run()
go keepAlive.Run()
closeResources := func() {
logging.Log.WriteInfof("Closing resources")
syncScheduler.Stop()
timestampScheduler.Stop()
keepAlive.Stop()
ctrlServer.Close()
client.Close()
}

View File

@ -0,0 +1,95 @@
version: '3'
networks:
net-1:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.0.0/17
net-2:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.155.0/17
services:
wg-1:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-2:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-3:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-4:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
sysctls:
- net.ipv6.conf.all.forwarding=1
networks:
- net-1
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-5:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-6:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-7:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"

View File

@ -0,0 +1,14 @@
certificatePath: "/wgmesh/cert/cert.pem"
privateKeyPath: "/wgmesh/cert/priv.pem"
caCertificatePath: "/wgmesh/cert/cacert.pem"
skipCertVerification: true
timeout: 5
gRPCPort: "21906"
advertiseRoutes: true
clusterSize: 32
syncRate: 1
interClusterChance: 0.15
branchRate: 3
infectionCount: 3
keepAliveTime: 10
pruneTime: 20

View File

@ -0,0 +1,42 @@
version: '3'
networks:
net-1:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.0.0/17
services:
wg-1:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-2:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-3:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"

View File

@ -0,0 +1,14 @@
certificatePath: "/wgmesh/cert/cert.pem"
privateKeyPath: "/wgmesh/cert/priv.pem"
caCertificatePath: "/wgmesh/cert/cacert.pem"
skipCertVerification: true
timeout: 5
gRPCPort: "21906"
advertiseRoutes: true
clusterSize: 32
syncRate: 1
interClusterChance: 0.15
branchRate: 3
infectionCount: 3
keepAliveTime: 10
pruneTime: 20

36
go.mod
View File

@ -1,31 +1,57 @@
module github.com/tim-beatham/wgmesh
go 1.21.1
go 1.21.3
require (
github.com/akamensky/argparse v1.4.0
github.com/anandvarma/namegen v0.0.0-20230727084436-5197c6ea3255
github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9
github.com/gin-gonic/gin v1.9.1
github.com/go-playground/validator/v10 v10.16.0
github.com/google/uuid v1.3.0
github.com/jmespath/go-jmespath v0.4.0
github.com/jsimonetti/rtnetlink v1.3.5
github.com/miekg/dns v1.1.57
github.com/sirupsen/logrus v1.9.3
golang.org/x/sys v0.14.0
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
gonum.org/v1/gonum v0.14.0
google.golang.org/grpc v1.58.1
google.golang.org/protobuf v1.31.0
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/bytedance/sonic v1.9.1 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.5.0 // indirect
golang.org/x/crypto v0.13.0 // indirect
golang.org/x/net v0.15.0 // indirect
golang.org/x/sync v0.3.0 // indirect
golang.org/x/sys v0.12.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.14.0 // indirect
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 // indirect
golang.org/x/mod v0.12.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sync v0.4.0 // indirect
golang.org/x/text v0.13.0 // indirect
golang.org/x/tools v0.13.0 // indirect
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect
)

143
go.sum Normal file
View File

@ -0,0 +1,143 @@
github.com/akamensky/argparse v1.4.0 h1:YGzvsTqCvbEZhL8zZu2AiA5nq805NZh75JNj4ajn1xc=
github.com/akamensky/argparse v1.4.0/go.mod h1:S5kwC7IuDcEr5VeXtGPRVZ5o/FdhcMlQz4IZQuw64xA=
github.com/anandvarma/namegen v0.0.0-20230727084436-5197c6ea3255 h1:aIAyyj4XPrke9Tc/umbBCzP5SKX/CHf3dKrL/PhH2lo=
github.com/anandvarma/namegen v0.0.0-20230727084436-5197c6ea3255/go.mod h1:MFyILur9tG8PxaCXGZVr/2BOnHtRIgxYejYFZdWLxr0=
github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9 h1:+6JSfuxZgmURoIlGdnYnY/FLRGWGagLyiBjt/VLtwi4=
github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9/go.mod h1:6UxoDE+thWsISXK93pxaOuOfkcAfCvDbg0eAnFmxL5E=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/cilium/ebpf v0.11.0 h1:V8gS/bTCCjX9uUnkUFUpPsksM8n1lXBAvHcpiFk1X2Y=
github.com/cilium/ebpf v0.11.0/go.mod h1:WE7CZAnqOL2RouJ4f1uyNhqr2P4CCvXFIqdRDUgWsVs=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE=
github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/jsimonetti/rtnetlink v1.3.5 h1:hVlNQNRlLDGZz31gBPicsG7Q53rnlsz1l1Ix/9XlpVA=
github.com/jsimonetti/rtnetlink v1.3.5/go.mod h1:0LFedyiTkebnd43tE4YAkWGIq9jQphow4CcwxaT2Y00=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.5.0 h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI=
github.com/mdlayher/socket v0.5.0/go.mod h1:WkcBFfvyG8QENs5+hfQPl1X6Jpd2yeLIYgrGFmJiJxI=
github.com/miekg/dns v1.1.57 h1:Jzi7ApEIzwEPLHWRcafCN9LZSBbqQpxjt/wpgvg7wcM=
github.com/miekg/dns v1.1.57/go.mod h1:uqRjCRUuEAA6qsOiJvDd+CFo/vW+y5WR6SNmHE55hZk=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug=
golang.org/x/exp v0.0.0-20230321023759-10a507213a29/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ=
golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 h1:EY138uSo1JYlDq+97u1FtcOUwPpIU6WL1Lkt7WpYjPA=
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0=
gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 h1:bVf09lpb+OJbByTj913DRJioFFAjf/ZGxEz7MajTp2U=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98/go.mod h1:TUfxEVdsvPg18p6AslUXFoLdpED4oBnGwyqk3dV1XzM=
google.golang.org/grpc v1.58.1 h1:OL+Vz23DTtrrldqHK49FUOPHyY75rvFqJfXC84NYW58=
google.golang.org/grpc v1.58.1/go.mod h1:tgX3ZQDlNJGU96V6yHh1T/JeoBQ2TXdr43YbYSsCJk0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

247
pkg/api/apiserver.go Normal file
View File

@ -0,0 +1,247 @@
package api
import (
"fmt"
"net/http"
ipcRpc "net/rpc"
"github.com/gin-gonic/gin"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/what8words"
)
const SockAddr = "/tmp/wgmesh_ipc.sock"
type ApiServer interface {
GetMeshes(c *gin.Context)
Run(addr string) error
}
type SmegServer struct {
router *gin.Engine
client *ipcRpc.Client
words *what8words.What8Words
}
func (s *SmegServer) routeToApiRoute(meshNode ctrlserver.MeshNode) []Route {
routes := make([]Route, len(meshNode.Routes))
for index, route := range meshNode.Routes {
if route.Path == nil {
route.Path = make([]string, 0)
}
routes[index] = Route{
Prefix: route.Destination,
Path: route.Path,
}
}
return routes
}
func (s *SmegServer) meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNode {
if meshNode.Routes == nil {
meshNode.Routes = make([]ctrlserver.MeshRoute, 0)
}
alias := meshNode.Alias
if alias == "" {
alias, _ = s.words.ConvertIdentifier(meshNode.WgHost)
}
return &SmegNode{
WgHost: meshNode.WgHost,
WgEndpoint: meshNode.WgEndpoint,
Endpoint: meshNode.HostEndpoint,
Timestamp: int(meshNode.Timestamp),
Description: meshNode.Description,
Routes: s.routeToApiRoute(meshNode),
PublicKey: meshNode.PublicKey,
Alias: alias,
Services: meshNode.Services,
Stats: SmegStats{
TotalTransmit: meshNode.Stats.TransmitBytes,
TotalReceived: meshNode.Stats.ReceivedBytes,
KeepAliveInterval: meshNode.Stats.PersistentKeepAliveInterval,
AllowedIps: meshNode.Stats.AllowedIPs,
},
}
}
func (s *SmegServer) meshToAPIMesh(meshId string, nodes []ctrlserver.MeshNode) SmegMesh {
var smegMesh SmegMesh
smegMesh.MeshId = meshId
smegMesh.Nodes = make(map[string]SmegNode)
for _, node := range nodes {
smegMesh.Nodes[node.WgHost] = *s.meshNodeToAPIMeshNode(node)
}
return smegMesh
}
// CreateMesh: creates a mesh network
func (s *SmegServer) CreateMesh(c *gin.Context) {
var createMesh CreateMeshRequest
if err := c.ShouldBindJSON(&createMesh); err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
ipcRequest := ipc.NewMeshArgs{
WgArgs: ipc.WireGuardArgs{
WgPort: createMesh.WgPort,
},
}
var reply string
err := s.client.Call("IpcHandler.CreateMesh", &ipcRequest, &reply)
if err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
c.JSON(http.StatusOK, &gin.H{
"meshid": reply,
})
}
// JoinMesh: joins a mesh network
func (s *SmegServer) JoinMesh(c *gin.Context) {
var joinMesh JoinMeshRequest
if err := c.ShouldBindJSON(&joinMesh); err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
ipcRequest := ipc.JoinMeshArgs{
MeshId: joinMesh.MeshId,
IpAdress: joinMesh.Bootstrap,
WgArgs: ipc.WireGuardArgs{
WgPort: joinMesh.WgPort,
},
}
var reply string
err := s.client.Call("IpcHandler.JoinMesh", &ipcRequest, &reply)
if err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
c.JSON(http.StatusOK, &gin.H{
"status": "success",
})
}
// GetMesh: given a meshId returns the corresponding mesh
// network.
func (s *SmegServer) GetMesh(c *gin.Context) {
meshidParam := c.Param("meshid")
var meshid string = meshidParam
getMeshReply := new(ipc.GetMeshReply)
err := s.client.Call("IpcHandler.GetMesh", &meshid, &getMeshReply)
if err != nil {
c.JSON(http.StatusNotFound,
&gin.H{
"error": fmt.Sprintf("could not find mesh %s", meshidParam),
})
return
}
mesh := s.meshToAPIMesh(meshidParam, getMeshReply.Nodes)
c.JSON(http.StatusOK, mesh)
}
func (s *SmegServer) GetMeshes(c *gin.Context) {
listMeshesReply := new(ipc.ListMeshReply)
err := s.client.Call("IpcHandler.ListMeshes", "", &listMeshesReply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
c.JSON(http.StatusBadRequest, nil)
return
}
meshes := make([]SmegMesh, 0)
for _, mesh := range listMeshesReply.Meshes {
getMeshReply := new(ipc.GetMeshReply)
err := s.client.Call("IpcHandler.GetMesh", &mesh, &getMeshReply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
c.JSON(http.StatusBadRequest, nil)
return
}
meshes = append(meshes, s.meshToAPIMesh(mesh, getMeshReply.Nodes))
}
c.JSON(http.StatusOK, meshes)
}
func (s *SmegServer) Run(addr string) error {
logging.Log.WriteInfof("Running API server")
return s.router.Run(addr)
}
func NewSmegServer(conf ApiServerConf) (ApiServer, error) {
client, err := ipcRpc.DialHTTP("unix", SockAddr)
if err != nil {
return nil, err
}
words, err := what8words.NewWhat8Words(conf.WordsFile)
if err != nil {
return nil, err
}
router := gin.Default()
router.Use(gin.LoggerWithConfig(gin.LoggerConfig{
Output: logging.Log.Writer(),
}))
smegServer := &SmegServer{
router: router,
client: client,
words: words,
}
router.GET("/meshes", smegServer.GetMeshes)
router.GET("/mesh/:meshid", smegServer.GetMesh)
router.POST("/mesh/create", smegServer.CreateMesh)
router.POST("/mesh/join", smegServer.JoinMesh)
return smegServer, nil
}

47
pkg/api/types.go Normal file
View File

@ -0,0 +1,47 @@
package api
import "time"
type Route struct {
Prefix string `json:"prefix"`
Path []string `json:"path"`
}
type SmegStats struct {
TotalTransmit int64 `json:"totalTransmit"`
TotalReceived int64 `json:"totalReceived"`
KeepAliveInterval time.Duration `json:"keepaliveInterval"`
AllowedIps []string `json:"allowedIps"`
}
type SmegNode struct {
Alias string `json:"alias"`
WgHost string `json:"wgHost"`
WgEndpoint string `json:"wgEndpoint"`
Endpoint string `json:"endpoint"`
Timestamp int `json:"timestamp"`
Description string `json:"description"`
PublicKey string `json:"publicKey"`
Routes []Route `json:"routes"`
Services map[string]string `json:"services"`
Stats SmegStats `json:"stats"`
}
type SmegMesh struct {
MeshId string `json:"meshid"`
Nodes map[string]SmegNode `json:"nodes"`
}
type CreateMeshRequest struct {
WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"`
}
type JoinMeshRequest struct {
WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"`
Bootstrap string `json:"bootstrap" binding:"required"`
MeshId string `json:"meshid" binding:"required"`
}
type ApiServerConf struct {
WordsFile string
}

View File

@ -1,9 +1,10 @@
package crdt
package automerge
import (
"errors"
"fmt"
"net"
"slices"
"strings"
"time"
@ -20,11 +21,12 @@ import (
type CrdtMeshManager struct {
MeshId string
IfName string
NodeId string
Client *wgctrl.Client
doc *automerge.Doc
LastHash automerge.ChangeHash
conf *conf.WgMeshConfiguration
conf *conf.WgConfiguration
cache *MeshCrdt
lastCacheHash automerge.ChangeHash
}
func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
@ -34,15 +36,78 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
panic("node must be of type *MeshNodeCrdt")
}
crdt.Routes = make(map[string]Route)
crdt.Services = make(map[string]string)
crdt.Timestamp = time.Now().Unix()
c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt)
nodeVal, _ := c.doc.Path("nodes").Map().Get(crdt.HostEndpoint)
nodeVal.Map().Set("routes", automerge.NewMap())
c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt)
}
func (c *CrdtMeshManager) isPeer(nodeId string) bool {
node, err := c.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return false
}
nodeType, err := node.Map().Get("type")
if err != nil || nodeType.Kind() != automerge.KindStr {
return false
}
return nodeType.Str() == string(conf.PEER_ROLE)
}
// isAlive: checks that the node's configuration has been updated
// since the rquired keep alive time
func (c *CrdtMeshManager) isAlive(nodeId string) bool {
node, err := c.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return false
}
timestamp, err := node.Map().Get("timestamp")
if err != nil || timestamp.Kind() != automerge.KindInt64 {
return false
}
return true
// return (time.Now().Unix() - keepAliveTime) < int64(c.conf.DeadTime)
}
func (c *CrdtMeshManager) GetPeers() []string {
keys, _ := c.doc.Path("nodes").Map().Keys()
keys = lib.Filter(keys, func(publicKey string) bool {
return c.isPeer(publicKey) && c.isAlive(publicKey)
})
return keys
}
// GetMesh(): Converts the document into a struct
func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
return automerge.As[*MeshCrdt](c.doc.Root())
changes, err := c.doc.Changes(c.lastCacheHash)
if err != nil {
return nil, err
}
if c.cache == nil || len(changes) > 0 {
c.lastCacheHash = c.LastHash
cache, err := automerge.As[*MeshCrdt](c.doc.Root())
if err != nil {
return nil, err
}
c.cache = cache
}
return c.cache, nil
}
// GetMeshId returns the meshid of the mesh
@ -70,7 +135,7 @@ type NewCrdtNodeMangerParams struct {
MeshId string
DevName string
Port int
Conf conf.WgMeshConfiguration
Conf *conf.WgConfiguration
Client *wgctrl.Client
}
@ -81,24 +146,24 @@ func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, erro
manager.doc = automerge.New()
manager.IfName = params.DevName
manager.Client = params.Client
manager.conf = &params.Conf
manager.conf = params.Conf
manager.cache = nil
return &manager, nil
}
func (c *CrdtMeshManager) removeNode(endpoint string) error {
err := c.doc.Path("nodes").Map().Delete(endpoint)
if err != nil {
return err
// NodeExists: returns true if the node exists. Returns false
func (m *CrdtMeshManager) NodeExists(key string) bool {
node, err := m.doc.Path("nodes").Map().Get(key)
return node.Kind() == automerge.KindMap && err == nil
}
return nil
}
// GetNode: returns a mesh node crdt.Close releases resources used by a Client.
func (m *CrdtMeshManager) GetNode(endpoint string) (*MeshNodeCrdt, error) {
func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
node, err := m.doc.Path("nodes").Map().Get(endpoint)
if node.Kind() != automerge.KindMap {
return nil, fmt.Errorf("GetNode: something went wrong %s is not a map type")
}
if err != nil {
return nil, err
}
@ -176,7 +241,7 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro
}
if node.Kind() != automerge.KindMap {
return errors.New(fmt.Sprintf("%s does not exist", nodeId))
return fmt.Errorf("%s does not exist", nodeId)
}
err = node.Map().Set("description", description)
@ -188,8 +253,75 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro
return err
}
func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
return err
}
if node.Kind() != automerge.KindMap {
return fmt.Errorf("%s does not exist", nodeId)
}
err = node.Map().Set("alias", alias)
if err == nil {
logging.Log.WriteInfof("Updated Alias for %s to %s", nodeId, alias)
}
return err
}
func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return fmt.Errorf("AddService: node %s does not exist", nodeId)
}
service, err := node.Map().Get("services")
if err != nil {
return err
}
if service.Kind() != automerge.KindMap {
return fmt.Errorf("AddService: services property does not exist in node")
}
err = service.Map().Set(key, value)
return err
}
func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return fmt.Errorf("RemoveService: node %s does not exist", nodeId)
}
service, err := node.Map().Get("services")
if err != nil {
return err
}
if service.Kind() != automerge.KindMap {
return fmt.Errorf("services property does not exist")
}
err = service.Map().Delete(key)
if err != nil {
return fmt.Errorf("service %s does not exist", key)
}
return nil
}
// AddRoutes: adds routes to the specific nodeId
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
logging.Log.WriteInfof("Adding route to %s", nodeId)
@ -197,6 +329,10 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
return err
}
if nodeVal.Kind() != automerge.KindMap {
return fmt.Errorf("node does not exist")
}
routeMap, err := nodeVal.Map().Get("routes")
if err != nil {
@ -204,20 +340,156 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
}
for _, route := range routes {
err = routeMap.Map().Set(route, struct{}{})
prevRoute, err := routeMap.Map().Get(route.GetDestination().String())
if prevRoute.Kind() == automerge.KindVoid && err != nil {
path, err := prevRoute.Map().Get("path")
if err != nil {
return err
}
if path.Kind() != automerge.KindList {
return fmt.Errorf("path is not a list")
}
pathStr, err := automerge.As[[]string](path)
if err != nil {
return err
}
slices.Equal(route.GetPath(), pathStr)
}
err = routeMap.Map().Set(route.GetDestination().String(), Route{
Destination: route.GetDestination().String(),
Path: route.GetPath(),
})
if err != nil {
return err
}
}
return nil
}
func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
return nil, err
}
if nodeVal.Kind() != automerge.KindMap {
return nil, fmt.Errorf("node does not exist")
}
routeMap, err := nodeVal.Map().Get("routes")
if err != nil {
return nil, err
}
if routeMap.Kind() != automerge.KindMap {
return nil, fmt.Errorf("node %s is not a map", nodeId)
}
routes, err := automerge.As[map[string]Route](routeMap)
return lib.MapValues(routes), err
}
func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, error) {
node, err := m.GetNode(targetNode)
if err != nil {
return nil, err
}
routes := make(map[string]mesh.Route)
// Add routes that the node directly has
for _, route := range node.GetRoutes() {
routes[route.GetDestination().String()] = route
}
// Work out the other routes in the mesh
for _, node := range m.GetPeers() {
nodeRoutes, err := m.getRoutes(node)
if err != nil {
return nil, err
}
for _, route := range nodeRoutes {
otherRoute, ok := routes[route.GetDestination().String()]
hopCount := route.GetHopCount()
if node != targetNode {
hopCount += 1
}
if !ok || route.GetHopCount()+1 < otherRoute.GetHopCount() {
routes[route.GetDestination().String()] = &Route{
Destination: route.GetDestination().String(),
Path: append(route.Path, m.GetMeshId()),
}
}
}
}
return routes, nil
}
func (m *CrdtMeshManager) RemoveNode(nodeId string) error {
err := m.doc.Path("nodes").Map().Delete(nodeId)
return err
}
// DeleteRoutes deletes the specified routes
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
return err
}
if nodeVal.Kind() != automerge.KindMap {
return fmt.Errorf("node is not a map")
}
routeMap, err := nodeVal.Map().Get("routes")
if err != nil {
return err
}
for _, route := range routes {
err = routeMap.Map().Delete(route.GetDestination().String())
}
return err
}
// GetConfiguration: gets the configuration for this mesh network
func (m *CrdtMeshManager) GetConfiguration() *conf.WgConfiguration {
return m.conf
}
// Mark: mark the node as locally dead
func (m *CrdtMeshManager) Mark(nodeId string) {
}
func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer {
return NewAutomergeSync(m)
}
func (m *CrdtMeshManager) Prune() error {
return nil
}
func (m1 *MeshNodeCrdt) Compare(m2 *MeshNodeCrdt) int {
return strings.Compare(m1.PublicKey, m2.PublicKey)
}
@ -238,7 +510,6 @@ func (m *MeshNodeCrdt) GetWgHost() *net.IPNet {
_, ipnet, err := net.ParseCIDR(m.WgHost)
if err != nil {
logging.Log.WriteErrorf("Cannot parse WgHost %s", err.Error())
return nil
}
@ -249,8 +520,13 @@ func (m *MeshNodeCrdt) GetTimeStamp() int64 {
return m.Timestamp
}
func (m *MeshNodeCrdt) GetRoutes() []string {
return lib.MapKeys(m.Routes)
func (m *MeshNodeCrdt) GetRoutes() []mesh.Route {
return lib.Map(lib.MapValues(m.Routes), func(r Route) mesh.Route {
return &Route{
Destination: r.Destination,
Path: r.Path,
}
})
}
func (m *MeshNodeCrdt) GetDescription() string {
@ -261,11 +537,30 @@ func (m *MeshNodeCrdt) GetIdentifier() string {
ipv6 := m.WgHost[:len(m.WgHost)-4]
constituents := strings.Split(ipv6, ":")
logging.Log.WriteInfof(ipv6)
constituents = constituents[4:]
return strings.Join(constituents, ":")
}
func (m *MeshNodeCrdt) GetAlias() string {
return m.Alias
}
func (m *MeshNodeCrdt) GetServices() map[string]string {
services := make(map[string]string)
for key, service := range m.Services {
services[key] = service
}
return services
}
// GetType refers to the type of the node. Peer means that the node is globally accessible
// Client means the node is only accessible through another peer
func (n *MeshNodeCrdt) GetType() conf.NodeType {
return conf.NodeType(n.Type)
}
func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
nodes := make(map[string]mesh.MeshNode)
@ -278,8 +573,24 @@ func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
Timestamp: node.Timestamp,
Routes: node.Routes,
Description: node.Description,
Alias: node.Alias,
Services: node.GetServices(),
Type: node.Type,
}
}
return nodes
}
func (r *Route) GetDestination() *net.IPNet {
_, ipnet, _ := net.ParseCIDR(r.Destination)
return ipnet
}
func (r *Route) GetHopCount() int {
return len(r.Path)
}
func (r *Route) GetPath() []string {
return r.Path
}

View File

@ -1,4 +1,4 @@
package crdt
package automerge
import (
"github.com/automerge/automerge-go"

View File

@ -0,0 +1,366 @@
package automerge
import (
"slices"
"strings"
"testing"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type TestParams struct {
manager *CrdtMeshManager
}
func setUpTests() *TestParams {
manager, _ := NewCrdtNodeManager(&NewCrdtNodeMangerParams{
MeshId: "timsmesh123",
DevName: "wg0",
Port: 5000,
Client: nil,
Conf: conf.DaemonConfiguration{},
})
return &TestParams{
manager: manager,
}
}
func getTestNode() mesh.MeshNode {
return &MeshNodeCrdt{
HostEndpoint: "public-endpoint:8080",
WgEndpoint: "public-endpoint:21906",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d218/128",
PublicKey: "AAAAAAAAAAAA",
Timestamp: time.Now().Unix(),
Description: "A node that we are adding",
}
}
func getTestNode2() mesh.MeshNode {
return &MeshNodeCrdt{
HostEndpoint: "public-endpoint:8081",
WgEndpoint: "public-endpoint:21907",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d219/128",
PublicKey: "BBBBBBBBB",
Timestamp: time.Now().Unix(),
Description: "A node that we are adding",
}
}
func TestAddNodeNodeExists(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getTestNode())
node, err := testParams.manager.GetNode("public-endpoint:8080")
if err != nil {
t.Error(err)
}
if node == nil {
t.Fatalf(`node not added to the mesh when it should be`)
}
}
func TestAddNodeAddRoute(t *testing.T) {
testParams := setUpTests()
testNode := getTestNode()
testParams.manager.AddNode(testNode)
testParams.manager.AddRoutes(testNode.GetHostEndpoint(), "fd:1c64:1d00::/48")
updatedNode, err := testParams.manager.GetNode(testNode.GetHostEndpoint())
if err != nil {
t.Error(err)
}
if updatedNode == nil {
t.Fatalf(`Node does not exist in the mesh`)
}
routes := updatedNode.GetRoutes()
if !slices.Contains(routes, "fd:1c64:1d00::/48") {
t.Fatal("Route node not added")
}
if len(routes) != 1 {
t.Fatal(`Route length mismatch`)
}
}
func TestGetMeshIdReturnsTheMeshId(t *testing.T) {
testParams := setUpTests()
if len(testParams.manager.GetMeshId()) == 0 {
t.Fatal(`Meshid is less than 0`)
}
}
// Add 2 nodes to the mesh and then get the mesh.s
// It should return the 2 nodes that have been added to the mesh
func TestAdd2NodesGetMesh(t *testing.T) {
testParams := setUpTests()
node1 := getTestNode()
node2 := getTestNode2()
testParams.manager.AddNode(node1)
testParams.manager.AddNode(node2)
mesh, err := testParams.manager.GetMesh()
if err != nil {
t.Error(err)
}
nodes := mesh.GetNodes()
if len(nodes) != 2 {
t.Fatalf(`Mismatch in node slice`)
}
for _, node := range nodes {
if node.GetHostEndpoint() != node1.GetHostEndpoint() && node.GetHostEndpoint() != node2.GetHostEndpoint() {
t.Fatalf(`Node should not exist`)
}
}
}
func TestSaveMeshReturnsMeshBytes(t *testing.T) {
testParams := setUpTests()
node1 := getTestNode()
testParams.manager.AddNode(node1)
bytes := testParams.manager.Save()
if len(bytes) <= 0 {
t.Fatalf(`bytes in the mesh is less than 0`)
}
}
func TestSaveMeshThenLoad(t *testing.T) {
testParams := setUpTests()
testParams2 := setUpTests()
node1 := getTestNode()
testParams.manager.AddNode(node1)
bytes := testParams.manager.Save()
err := testParams2.manager.Load(bytes)
if err != nil {
t.Error(err)
}
if len(bytes) <= 0 {
t.Fatalf(`bytes in the mesh is less than 0`)
}
mesh2, err := testParams2.manager.GetMesh()
if err != nil {
t.Error(err)
}
nodes := mesh2.GetNodes()
if lib.MapValues(nodes)[0].GetHostEndpoint() != node1.GetHostEndpoint() {
t.Fatalf(`Node should be in the list of nodes`)
}
}
func TestLengthNoNodes(t *testing.T) {
testParams := setUpTests()
if testParams.manager.Length() != 0 {
t.Fatalf(`Number of nodes should be 0`)
}
}
func TestLength1Node(t *testing.T) {
testParams := setUpTests()
node := getTestNode()
testParams.manager.AddNode(node)
if testParams.manager.Length() != 1 {
t.Fatalf(`Number of nodes should be 1`)
}
}
func TestLengthMultipleNodes(t *testing.T) {
testParams := setUpTests()
node := getTestNode()
node1 := getTestNode2()
testParams.manager.AddNode(node)
testParams.manager.AddNode(node1)
if testParams.manager.Length() != 2 {
t.Fatalf(`Number of nodes should be 2`)
}
}
func TestHasChangesNoChanges(t *testing.T) {
testParams := setUpTests()
if testParams.manager.HasChanges() {
t.Fatalf(`Should not have changes just created document`)
}
}
func TestHasChangesChanges(t *testing.T) {
testParams := setUpTests()
node := getTestNode()
testParams.manager.AddNode(node)
if !testParams.manager.HasChanges() {
t.Fatalf(`Should have changes just added node`)
}
}
func TestHasChangesSavedChanges(t *testing.T) {
testParams := setUpTests()
node := getTestNode()
testParams.manager.AddNode(node)
testParams.manager.SaveChanges()
if testParams.manager.HasChanges() {
t.Fatalf(`Should not have changes just saved document`)
}
}
func TestUpdateTimeStampNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.UpdateTimeStamp("AAAAAA")
if err == nil {
t.Fatalf(`Error should have returned`)
}
}
func TestUpdateTimeStampNodeExists(t *testing.T) {
testParams := setUpTests()
node := getTestNode()
testParams.manager.AddNode(node)
err := testParams.manager.UpdateTimeStamp(node.GetHostEndpoint())
if err != nil {
t.Error(err)
}
}
func TestSetDescriptionNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.SetDescription("AAAAA", "Bob 123")
if err == nil {
t.Fatalf(`Error should have returned`)
}
}
func TestSetDescriptionNodeExists(t *testing.T) {
testParams := setUpTests()
node := getTestNode()
err := testParams.manager.SetDescription(node.GetHostEndpoint(), "Bob 123")
if err == nil {
t.Fatalf(`Error should have returned`)
}
}
func TestAddRoutesNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.AddRoutes("AAAAA", "fd:1c64:1d00::/48")
if err == nil {
t.Error(err)
}
}
func TestCompareComparesByPublicKey(t *testing.T) {
node := getTestNode().(*MeshNodeCrdt)
node2 := getTestNode2().(*MeshNodeCrdt)
if node.Compare(node2) != -1 {
t.Fatalf(`node is alphabetically before node2`)
}
if node2.Compare(node) != 1 {
t.Fatalf(`node is alphabetical;y before node2`)
}
if node.Compare(node) != 0 {
t.Fatalf(`node is equal to node`)
}
}
func TestGetHostEndpoint(t *testing.T) {
node := getTestNode()
if (node.(*MeshNodeCrdt)).HostEndpoint != node.GetHostEndpoint() {
t.Fatalf(`get hostendpoint should get the host endpoint`)
}
}
func TestGetPublicKey(t *testing.T) {
key1, _ := wgtypes.GenerateKey()
node := getTestNode()
node.(*MeshNodeCrdt).PublicKey = key1.String()
pubKey, err := node.GetPublicKey()
if err != nil {
t.Error(err)
}
if pubKey.String() != key1.String() {
t.Fatalf(`Expected %s got %s`, key1.String(), pubKey.String())
}
}
func TestGetWgEndpoint(t *testing.T) {
node := getTestNode()
if node.(*MeshNodeCrdt).WgEndpoint != node.GetWgEndpoint() {
t.Fatal(`Did not return the correct wgEndpoint`)
}
}
func TestGetWgHost(t *testing.T) {
node := getTestNode()
ip := node.GetWgHost()
if node.(*MeshNodeCrdt).WgHost != ip.String() {
t.Fatal(`Did not parse WgHost correctly`)
}
}
func TestGetTimeStamp(t *testing.T) {
node := getTestNode()
if node.(*MeshNodeCrdt).Timestamp != node.GetTimeStamp() {
t.Fatal(`Did not return return the correct timestamp`)
}
}
func TestGetIdentifierDoesNotContainPrefix(t *testing.T) {
node := getTestNode()
if strings.Contains(node.GetIdentifier(), "/128") {
t.Fatal(`Identifier should not contain prefix`)
}
}

View File

@ -1,4 +1,4 @@
package crdt
package automerge
import (
"fmt"
@ -14,13 +14,13 @@ func (f *CrdtProviderFactory) CreateMesh(params *mesh.MeshProviderFactoryParams)
return NewCrdtNodeManager(&NewCrdtNodeMangerParams{
MeshId: params.MeshId,
DevName: params.DevName,
Conf: *params.Conf,
Conf: params.Conf,
Client: params.Client,
})
}
type MeshNodeFactory struct {
Config conf.WgMeshConfiguration
Config conf.DaemonConfiguration
}
// Build builds the mesh node that represents the host machine to add
@ -28,14 +28,23 @@ type MeshNodeFactory struct {
func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode {
hostName := f.getAddress(params)
grpcEndpoint := fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort)
if *params.MeshConfig.Role == conf.CLIENT_ROLE {
grpcEndpoint = "-"
}
return &MeshNodeCrdt{
HostEndpoint: fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort),
HostEndpoint: grpcEndpoint,
PublicKey: params.PublicKey.String(),
WgEndpoint: fmt.Sprintf("%s:%d", hostName, params.WgPort),
WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()),
// Always set the routes as empty.
// Routes handled by external component
Routes: map[string]interface{}{},
Routes: make(map[string]Route),
Description: "",
Alias: "",
Type: string(*params.MeshConfig.Role),
}
}
@ -45,10 +54,22 @@ func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string
if params.Endpoint != "" {
hostName = params.Endpoint
} else if len(f.Config.Endpoint) != 0 {
hostName = f.Config.Endpoint
} else if len(*params.MeshConfig.Endpoint) != 0 {
hostName = *params.MeshConfig.Endpoint
} else {
hostName = lib.GetOutboundIP().String()
ipFunc := lib.GetPublicIP
if *params.MeshConfig.IPDiscovery == conf.DNS_IP_DISCOVERY {
ipFunc = lib.GetOutboundIP
}
ip, err := ipFunc()
if err != nil {
return ""
}
hostName = ip.String()
}
return hostName

View File

@ -1,4 +1,10 @@
package crdt
package automerge
// Route: Represents a CRDT of the given route
type Route struct {
Destination string `automerge:"destination"`
Path []string `automerge:"path"`
}
// MeshNodeCrdt: Represents a CRDT for a mesh nodes
type MeshNodeCrdt struct {
@ -7,8 +13,11 @@ type MeshNodeCrdt struct {
PublicKey string `automerge:"publicKey"`
WgHost string `automerge:"wgHost"`
Timestamp int64 `automerge:"timestamp"`
Routes map[string]interface{} `automerge:"routes"`
Routes map[string]Route `automerge:"routes"`
Alias string `automerge:"alias"`
Description string `automerge:"description"`
Services map[string]string `automerge:"services"`
Type string `automerge:"type"`
}
// MeshCrdt: Represents the mesh network as a whole

33
pkg/cmd/cmd.go Normal file
View File

@ -0,0 +1,33 @@
// cmd is a package for running commands in the different operating systems implementations
package cmd
import (
"os/exec"
"strings"
)
type CmdRunner interface {
RunCommands(commands ...string) error
}
type UnixCmdRunner struct{}
// RunCommand: runs the unix command. It splits the command into fields
// and then runs the command accordingly
func RunCommand(cmd string) error {
args := strings.Fields(cmd)
c := exec.Command(args[0], args[1:]...)
return c.Run()
}
func (l *UnixCmdRunner) RunCommands(commands ...string) error {
for _, cmd := range commands {
err := RunCommand(cmd)
if err != nil {
return err
}
}
return nil
}

View File

@ -4,52 +4,213 @@ package conf
import (
"os"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/go-playground/validator/v10"
"gopkg.in/yaml.v3"
)
type WgMeshConfiguration struct {
type WgMeshConfigurationError struct {
msg string
}
func (m *WgMeshConfigurationError) Error() string {
return m.msg
}
type NodeType string
const (
PEER_ROLE NodeType = "peer"
CLIENT_ROLE NodeType = "client"
)
type IPDiscovery string
const (
PUBLIC_IP_DISCOVERY = "public"
DNS_IP_DISCOVERY = "dns"
)
// WgConfiguration contains per-mesh WireGuard configuration. Contains poitner types only so we can
// tell if the attribute is set
type WgConfiguration struct {
// IPDIscovery: how to discover your IP if not specified. Use your outgoing IP or use a public
// service for IPDiscoverability
IPDiscovery *IPDiscovery `yaml:"ipDiscovery" validate:"required,eq=public|eq=dns"`
// AdvertiseRoutes: specifies whether the node can act as a router routing packets between meshes
AdvertiseRoutes *bool `yaml:"advertiseRoute" validate:"required"`
// AdvertiseDefaultRoute: specifies whether or not this route should advertise a default route
// for all nodes to route their packets to
AdvertiseDefaultRoute *bool `yaml:"advertiseDefaults" validate:"required"`
// Endpoint contains what value should be set as the public endpoint of this node
Endpoint *string `yaml:"publicEndpoint"`
// Role specifies whether or not the user is globally accessible.
// If the user is globaly accessible they specify themselves as a client.
Role *NodeType `yaml:"role" validate:"required,eq=client|eq=peer"`
// KeepAliveWg configures the implementation so that we send keep alive packets to peers.
// KeepAlive can only be set if role is type client
KeepAliveWg *int `yaml:"keepAliveWg" validate:"omitempty,gte=0"`
// PreUp are WireGuard commands to run before adding the WG interface
PreUp []string `yaml:"preUp"`
// PostUp are WireGuard commands to run after adding the WG interface
PostUp []string `yaml:"postUp"`
// PreDown are WireGuard commands to run prior to removing the WG interface
PreDown []string `yaml:"preDown"`
// PostDown are WireGuard command to run after removing the WG interface
PostDown []string `yaml:"postDown"`
}
type DaemonConfiguration struct {
// CertificatePath is the path to the certificate to use in mTLS
CertificatePath string `yaml:"certificatePath"`
CertificatePath string `yaml:"certificatePath" validate:"required,file"`
// PrivateKeypath is the path to the clients private key in mTLS
PrivateKeyPath string `yaml:"privateKeyPath"`
PrivateKeyPath string `yaml:"privateKeyPath" validate:"required,file"`
// CaCeritifcatePath path to the certificate of the trust certificate authority
CaCertificatePath string `yaml:"caCertificatePath"`
CaCertificatePath string `yaml:"caCertificatePath" validate:"required,file"`
// SkipCertVerification specify to skip certificate verification. Should only be used
// in test environments
SkipCertVerification bool `yaml:"skipCertVerification"`
// Port to run the GrpcServer on
GrpcPort string `yaml:"gRPCPort"`
// AdvertiseRoutes advertises other meshes if the node is in multiple meshes
AdvertiseRoutes bool `yaml:"advertiseRoutes"`
// Endpoint is the IP in which this computer is publicly reachable.
// usecase is when the node has multiple IP addresses
Endpoint string `yaml:"publicEndpoint"`
ClusterSize int `yaml:"clusterSize"`
SyncRate float64 `yaml:"syncRate"`
InterClusterChance float64 `yaml:"interClusterChance"`
BranchRate int `yaml:"branchRate"`
InfectionCount int `yaml:"infectionCount"`
KeepAliveRate int `yaml:"keepAliveRate"`
GrpcPort int `yaml:"gRPCPort" validate:"required"`
// Timeout number of seconds without response that a node is considered unreachable by gRPC
Timeout int `yaml:"timeout" validate:"required,gte=1"`
// Profile whether or not to include a http server that profiles the code
Profile bool `yaml:"profile"`
// StubWg whether or not to stub the WireGuard types
StubWg bool `yaml:"stubWg"`
// SyncRate specifies how long the minimum time should be between synchronisation
SyncRate int `yaml:"syncRate" validate:"required,gte=1"`
// KeepAliveTime: number of seconds before the leader of the mesh sends an update to
// send to every member in the mesh
KeepAliveTime int `yaml:"keepAliveTime" validate:"required,gte=1"`
// ClusterSize specifies how many neighbours you should synchronise with per round
ClusterSize int `yaml:"clusterSize" valdiate:"required,gt=0"`
// InterClusterChance specifies the probabilityof inter-cluster communication in a sync round
InterClusterChance float64 `yaml:"interClusterChance" valdiate:"required,gt=0"`
// BranchRate specifies the number of nodes to synchronise with when a node has
// new changes to send to the mesh
BranchRate int `yaml:"branchRate" validate:"required,gte=1"`
// InfectionCount: number of time to sync before an update can no longer be 'caught'
InfectionCount int `yaml:"infectionCount" validate:"required,gte=1"`
// BaseConfiguration base WireGuard configuration to use, this is used when none is provided
BaseConfiguration WgConfiguration `yaml:"baseConfiguration" validate:"required"`
}
// ParseConfiguration parses the mesh configuration
func ParseConfiguration(filePath string) (*WgMeshConfiguration, error) {
var conf WgMeshConfiguration
// ValdiateMeshConfiguration: validates the mesh configuration
func ValidateMeshConfiguration(conf *WgConfiguration) error {
validate := validator.New(validator.WithRequiredStructEnabled())
err := validate.Struct(conf)
if conf.PostDown == nil {
conf.PostDown = make([]string, 0)
}
if conf.PostUp == nil {
conf.PostUp = make([]string, 0)
}
if conf.PreDown == nil {
conf.PreDown = make([]string, 0)
}
if conf.PreUp == nil {
conf.PreUp = make([]string, 0)
}
return err
}
// ValidateDaemonConfiguration: validates the dameon configuration that is used.
func ValidateDaemonConfiguration(c *DaemonConfiguration) error {
validate := validator.New(validator.WithRequiredStructEnabled())
err := validate.Struct(c)
return err
}
// ParseMeshConfiguration: parses the mesh network configuration. Parses parameters such as
// keepalive time, role and so forth.
func ParseMeshConfiguration(filePath string) (*WgConfiguration, error) {
var conf WgConfiguration
yamlBytes, err := os.ReadFile(filePath)
if err != nil {
logging.Log.WriteErrorf("Read file error: %s\n", err.Error())
return nil, err
}
err = yaml.Unmarshal(yamlBytes, &conf)
if err != nil {
logging.Log.WriteErrorf("Unmarshal error: %s\n", err.Error())
return nil, err
}
return &conf, nil
return &conf, ValidateMeshConfiguration(&conf)
}
// ParseDaemonConfiguration parses the mesh configuration and validates the configuration
func ParseDaemonConfiguration(filePath string) (*DaemonConfiguration, error) {
var conf DaemonConfiguration
yamlBytes, err := os.ReadFile(filePath)
if err != nil {
return nil, err
}
err = yaml.Unmarshal(yamlBytes, &conf)
if err != nil {
return nil, err
}
return &conf, ValidateDaemonConfiguration(&conf)
}
// MergemeshConfiguration: merges the configuration in precedence where the last
// element in the list takes the most and the first takes the least
func MergeMeshConfiguration(cfgs ...WgConfiguration) (WgConfiguration, error) {
var result WgConfiguration
for _, cfg := range cfgs {
if cfg.AdvertiseDefaultRoute != nil {
result.AdvertiseDefaultRoute = cfg.AdvertiseDefaultRoute
}
if cfg.AdvertiseRoutes != nil {
result.AdvertiseRoutes = cfg.AdvertiseRoutes
}
if cfg.Endpoint != nil {
result.Endpoint = cfg.Endpoint
}
if cfg.IPDiscovery != nil {
result.IPDiscovery = cfg.IPDiscovery
}
if cfg.KeepAliveWg != nil {
result.KeepAliveWg = cfg.KeepAliveWg
}
if cfg.PostDown != nil {
result.PostDown = cfg.PostDown
}
if cfg.PostUp != nil {
result.PostUp = cfg.PostUp
}
if cfg.PreDown != nil {
result.PreDown = cfg.PreDown
}
if cfg.PreUp != nil {
result.PreUp = cfg.PreUp
}
if cfg.Role != nil {
result.Role = cfg.Role
}
}
return result, ValidateMeshConfiguration(&result)
}

66
pkg/conf/conf_test.go Normal file
View File

@ -0,0 +1,66 @@
package conf
import "testing"
func getExampleConfiguration() *DaemonConfiguration {
return &DaemonConfiguration{
CertificatePath: "./cert/cert.pem",
PrivateKeyPath: "./cert/key.pem",
CaCertificatePath: "./cert/ca.pems",
SkipCertVerification: true,
}
}
func TestConfigurationCertificatePathEmpty(t *testing.T) {
conf := getExampleConfiguration()
conf.CertificatePath = ""
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestConfigurationPrivateKeyPathEmpty(t *testing.T) {
conf := getExampleConfiguration()
conf.PrivateKeyPath = ""
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestConfigurationCaCertificatePathEmpty(t *testing.T) {
conf := getExampleConfiguration()
conf.CaCertificatePath = ""
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestConfigurationGrpcPortEmpty(t *testing.T) {
conf := getExampleConfiguration()
conf.GrpcPort = 0
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestValidConfiguration(t *testing.T) {
conf := getExampleConfiguration()
err := ValidateDaemonConfiguration(conf)
if err != nil {
t.Error(err)
}
}

77
pkg/conn/cluster.go Normal file
View File

@ -0,0 +1,77 @@
package conn
import (
"errors"
"math"
"math/rand"
"slices"
)
// ConnCluster splits nodes into clusters where nodes in a cluster communicate
// frequently and nodes outside of a cluster communicate infrequently
type ConnCluster interface {
GetNeighbours(global []string, selfId string) []string
GetInterCluster(global []string, selfId string) string
}
type ConnClusterImpl struct {
clusterSize int
}
func binarySearch(global []string, selfId string, groupSize int) (int, int) {
slices.Sort(global)
lower := 0
higher := len(global) - 1
mid := (lower + higher) / 2
for (higher+1)-lower > groupSize {
if global[mid] < selfId {
lower = mid + 1
} else if global[mid] > selfId {
higher = mid - 1
} else {
break
}
mid = (lower + higher) / 2
}
return lower, int(math.Min(float64(lower+groupSize), float64(len(global))))
}
// GetNeighbours return the neighbours 'nearest' to you. In this implementation the
// neighbours aren't actually the ones nearest to you but just the ones nearest
// to you alphabetically. Perform binary search to get the total group
func (i *ConnClusterImpl) GetNeighbours(global []string, selfId string) []string {
slices.Sort(global)
lower, higher := binarySearch(global, selfId, i.clusterSize)
// slice the list to get the neighbours
return global[lower:higher]
}
// GetInterCluster get nodes not in your cluster. Every round there is a given chance
// you will communicate with a random node that is not in your cluster.
func (i *ConnClusterImpl) GetInterCluster(global []string, selfId string) string {
// Doesn't matter if not in it. Get index of where the node 'should' be
slices.Sort(global)
index, _ := binarySearch(global, selfId, 1)
randomCluster := rand.Intn(2) + 1
// cluster is considered a heap
neighbourIndex := (2*index + (randomCluster * i.clusterSize)) % len(global)
return global[neighbourIndex]
}
func NewConnCluster(clusterSize int) (ConnCluster, error) {
log2Cluster := math.Log2(float64(clusterSize))
if float64((log2Cluster))-log2Cluster != 0 {
return nil, errors.New("cluster must be a power of 2")
}
return &ConnClusterImpl{clusterSize: clusterSize}, nil
}

116
pkg/conn/cluster_test.go Normal file
View File

@ -0,0 +1,116 @@
package conn
import (
"math/rand"
"slices"
"testing"
)
func TestGetNeighboursClusterSizeTwo(t *testing.T) {
cluster := &ConnClusterImpl{
clusterSize: 2,
}
neighbours := []string{
"a",
"b",
"c",
"d",
}
result := cluster.GetNeighbours(neighbours, "b")
if len(result) != 2 {
t.Fatalf(`neighbour length should be 2`)
}
if result[0] != "a" && result[1] != "b" {
t.Fatalf(`Expected value b`)
}
}
func TestGetNeighboursGlobalListLessThanClusterSize(t *testing.T) {
cluster := &ConnClusterImpl{
clusterSize: 4,
}
neighbours := []string{
"a",
"b",
"c",
}
result := cluster.GetNeighbours(neighbours, "a")
if len(result) != 3 {
t.Fatalf(`neighbour length should be 3`)
}
slices.Sort(result)
if !slices.Equal(result, neighbours) {
t.Fatalf(`Cluster and neighbours should be equal`)
}
}
func TestGetNeighboursClusterSize4(t *testing.T) {
cluster := &ConnClusterImpl{
clusterSize: 4,
}
neighbours := []string{
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
"k", "l", "m", "n", "o",
}
result := cluster.GetNeighbours(neighbours, "k")
if len(result) != 4 {
t.Fatalf(`cluster size must be 4`)
}
slices.Sort(result)
if !slices.Equal(neighbours[8:12], result) {
t.Fatalf(`Cluster should be i, j, k, l`)
}
}
func TestGetNeighboursClusterSize4OneReturned(t *testing.T) {
cluster := &ConnClusterImpl{
clusterSize: 4,
}
neighbours := []string{
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
"k", "l", "m", "n", "o",
}
result := cluster.GetNeighbours(neighbours, "o")
if len(result) != 3 {
t.Fatalf(`Cluster should be of length 3`)
}
if !slices.Equal(neighbours[12:15], result) {
t.Fatalf(`Cluster should be m, n, o`)
}
}
func TestInterClusterNotInCluster(t *testing.T) {
rand.Seed(1)
cluster := &ConnClusterImpl{
clusterSize: 4,
}
global := []string{
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
"k", "l", "m", "n", "o",
}
neighbours := cluster.GetNeighbours(global, "c")
interCluster := cluster.GetInterCluster(global, "c")
if slices.Contains(neighbours, interCluster) {
t.Fatalf(`intercluster cannot be in your cluster`)
}
}

View File

@ -18,6 +18,8 @@ type PeerConnection interface {
GetClient() (*grpc.ClientConn, error)
}
type PeerConnectionFactory = func(clientConfig *tls.Config, server string) (PeerConnection, error)
// WgCtrlConnection implements PeerConnection.
type WgCtrlConnection struct {
clientConfig *tls.Config
@ -26,12 +28,12 @@ type WgCtrlConnection struct {
}
// NewWgCtrlConnection creates a new instance of a WireGuard control connection
func NewWgCtrlConnection(clientConfig *tls.Config, server string) (*WgCtrlConnection, error) {
func NewWgCtrlConnection(clientConfig *tls.Config, server string) (PeerConnection, error) {
var conn WgCtrlConnection
conn.clientConfig = clientConfig
conn.endpoint = server
if err := conn.createGrpcConn(); err != nil {
if err := conn.CreateGrpcConnection(); err != nil {
return nil, err
}
@ -39,7 +41,7 @@ func NewWgCtrlConnection(clientConfig *tls.Config, server string) (*WgCtrlConnec
}
// ConnectWithToken: Connects to a new gRPC peer given the address of the other server.
func (c *WgCtrlConnection) createGrpcConn() error {
func (c *WgCtrlConnection) CreateGrpcConnection() error {
conn, err := grpc.Dial(c.endpoint,
grpc.WithTransportCredentials(credentials.NewTLS(c.clientConfig)))
@ -62,7 +64,7 @@ func (c *WgCtrlConnection) GetClient() (*grpc.ClientConn, error) {
var err error = nil
if c.conn == nil {
err = errors.New("The client's config does not exist")
err = errors.New("the client's config does not exist")
}
return c.conn, err

View File

@ -34,10 +34,11 @@ type ConnectionManagerImpl struct {
clientConnections map[string]PeerConnection
serverConfig *tls.Config
clientConfig *tls.Config
connFactory PeerConnectionFactory
}
// Create a new instance of a connection manager.
type NewConnectionManageParams struct {
type NewConnectionManagerParams struct {
// The path to the certificate
CertificatePath string
// The private key of the node
@ -45,11 +46,12 @@ type NewConnectionManageParams struct {
// Whether or not to skip certificate verification
SkipCertVerification bool
CaCert string
ConnFactory PeerConnectionFactory
}
// NewConnectionManager: Creates a new instance of a ConnectionManager or an error
// if something went wrong.
func NewConnectionManager(params *NewConnectionManageParams) (ConnectionManager, error) {
func NewConnectionManager(params *NewConnectionManagerParams) (ConnectionManager, error) {
cert, err := tls.LoadX509KeyPair(params.CertificatePath, params.PrivateKey)
if err != nil {
@ -94,10 +96,12 @@ func NewConnectionManager(params *NewConnectionManageParams) (ConnectionManager,
}
connections := make(map[string]PeerConnection)
connMgr := ConnectionManagerImpl{sync.RWMutex{},
connMgr := ConnectionManagerImpl{
sync.RWMutex{},
connections,
serverConfig,
clientConfig,
params.ConnFactory,
}
return &connMgr, nil
@ -127,7 +131,7 @@ func (m *ConnectionManagerImpl) AddConnection(endPoint string) (PeerConnection,
return conn, nil
}
connections, err := NewWgCtrlConnection(m.clientConfig, endPoint)
connections, err := m.connFactory(m.clientConfig, endPoint)
if err != nil {
return nil, err

View File

@ -0,0 +1,145 @@
package conn
import (
"crypto/tls"
"errors"
"log"
"testing"
)
func getConnectionManagerParams() *NewConnectionManagerParams {
return &NewConnectionManagerParams{
CertificatePath: "./test/cert.pem",
PrivateKey: "./test/priv.pem",
CaCert: "./test/cacert.pem",
SkipCertVerification: false,
ConnFactory: MockFactory,
}
}
func TestNewConnectionManagerCertificatePathDoesNotExist(t *testing.T) {
params := getConnectionManagerParams()
params.CertificatePath = "./cert/sdfjdskjdsjkd.pem"
_, err := NewConnectionManager(params)
if err == nil {
t.Fatalf(`Expected error as certificate does not exist`)
}
}
func TestNewConnectionManagerPrivateKeyDoesNotExist(t *testing.T) {
params := getConnectionManagerParams()
params.PrivateKey = "./cert/sdjdjdks.pem"
_, err := NewConnectionManager(params)
if err == nil {
t.Fatalf(`Expected error as private key does not exist`)
}
}
func TestNewConnectionManagerCACertDoesNotExistAndVerify(t *testing.T) {
params := getConnectionManagerParams()
params.CaCert = "./cert/sdjdsjdksjdks.pem"
params.SkipCertVerification = false
_, err := NewConnectionManager(params)
if err == nil {
t.Fatal(`Expected error as ca cert does not exist and skip is false`)
}
}
func TestNewConnectionManagerCACertDoesNotExistAndNotVerify(t *testing.T) {
params := getConnectionManagerParams()
params.CaCert = ""
params.SkipCertVerification = true
_, err := NewConnectionManager(params)
if err != nil {
t.Fatal(`an error should not be thrown`)
}
}
func TestGetConnectionConnectionDoesNotExistAddsConnection(t *testing.T) {
params := getConnectionManagerParams()
m, _ := NewConnectionManager(params)
conn, err := m.GetConnection("abc-123.com")
if err != nil {
t.Error(err)
}
if conn == nil {
t.Fatal(`the connection should not be nil`)
}
conn2, _ := m.GetConnection("abc-123.com")
if conn != conn2 {
log.Fatalf(`should return the same connection instance`)
}
}
func TestAddConnectionThrowsAnErrorIfFactoryThrowsError(t *testing.T) {
params := getConnectionManagerParams()
params.ConnFactory = func(clientConfig *tls.Config, server string) (PeerConnection, error) {
return nil, errors.New("this is an error")
}
m, _ := NewConnectionManager(params)
_, err := m.AddConnection("abc-123.com")
if err == nil || err.Error() != "this is an error" {
t.Error(err)
}
}
func TestAddConnectionConnectionDoesNotExist(t *testing.T) {
params := getConnectionManagerParams()
m, _ := NewConnectionManager(params)
conn, err := m.AddConnection("abc-123.com")
if err != nil {
t.Error(err)
}
if conn == nil {
t.Fatal(`connection should not be nil`)
}
conn1, _ := m.GetConnection("abc-123.com")
if conn != conn1 {
t.Fatal(`underlying connections should be the same`)
}
}
func TestHasConnectionConnectionDoesNotExist(t *testing.T) {
params := getConnectionManagerParams()
m, _ := NewConnectionManager(params)
if m.HasConnection("abc-123.com") {
t.Fatal(`should return that the connection does not exist`)
}
}
func TestHasConnectionConnectionExists(t *testing.T) {
params := getConnectionManagerParams()
m, _ := NewConnectionManager(params)
m.AddConnection("abc-123.com")
if !m.HasConnection("abc-123.com") {
t.Fatal(`should return that the connection exists`)
}
}

View File

@ -2,6 +2,7 @@ package conn
import (
"crypto/tls"
"fmt"
"net"
"github.com/tim-beatham/wgmesh/pkg/conf"
@ -16,21 +17,18 @@ type ConnectionServer struct {
// tlsConfiguration of the server
serverConfig *tls.Config
// server an instance of the grpc server
server *grpc.Server
// the authentication service to authenticate nodes
authProvider rpc.AuthenticationServer
server *grpc.Server // the authentication service to authenticate nodes
// the ctrl service to manage node
ctrlProvider rpc.MeshCtrlServerServer
// the sync service to synchronise nodes
syncProvider rpc.SyncServiceServer
Conf *conf.WgMeshConfiguration
Conf *conf.DaemonConfiguration
listener net.Listener
}
// NewConnectionServerParams contains params for creating a new connection server
type NewConnectionServerParams struct {
Conf *conf.WgMeshConfiguration
AuthProvider rpc.AuthenticationServer
Conf *conf.DaemonConfiguration
CtrlProvider rpc.MeshCtrlServerServer
SyncProvider rpc.SyncServiceServer
}
@ -59,14 +57,12 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
grpc.Creds(credentials.NewTLS(serverConfig)),
)
authProvider := params.AuthProvider
ctrlProvider := params.CtrlProvider
syncProvider := params.SyncProvider
connServer := ConnectionServer{
serverConfig: serverConfig,
server: server,
authProvider: authProvider,
ctrlProvider: ctrlProvider,
syncProvider: syncProvider,
Conf: params.Conf,
@ -78,14 +74,13 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
// Listen for incoming requests. Returns an error if something went wrong.
func (s *ConnectionServer) Listen() error {
rpc.RegisterMeshCtrlServerServer(s.server, s.ctrlProvider)
rpc.RegisterAuthenticationServer(s.server, s.authProvider)
rpc.RegisterSyncServiceServer(s.server, s.syncProvider)
lis, err := net.Listen("tcp", ":"+s.Conf.GrpcPort)
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", s.Conf.GrpcPort))
s.listener = lis
logging.Log.WriteInfof("GRPC listening on %s\n", s.Conf.GrpcPort)
logging.Log.WriteInfof("GRPC listening on %d\n", s.Conf.GrpcPort)
if err != nil {
logging.Log.WriteErrorf(err.Error())

51
pkg/conn/stub.go Normal file
View File

@ -0,0 +1,51 @@
package conn
import (
"crypto/tls"
"google.golang.org/grpc"
)
type ConnectionManagerStub struct {
Endpoints map[string]PeerConnection
}
func (s *ConnectionManagerStub) AddConnection(endPoint string) (PeerConnection, error) {
mock := &PeerConnectionMock{}
s.Endpoints[endPoint] = mock
return mock, nil
}
func (s *ConnectionManagerStub) GetConnection(endPoint string) (PeerConnection, error) {
endpoint, ok := s.Endpoints[endPoint]
if !ok {
return s.AddConnection(endPoint)
}
return endpoint, nil
}
func (s *ConnectionManagerStub) HasConnection(endPoint string) bool {
_, ok := s.Endpoints[endPoint]
return ok
}
func (s *ConnectionManagerStub) Close() error {
return nil
}
type PeerConnectionMock struct {
}
func (c *PeerConnectionMock) Close() error {
return nil
}
func (c *PeerConnectionMock) GetClient() (*grpc.ClientConn, error) {
return &grpc.ClientConn{}, nil
}
var MockFactory PeerConnectionFactory = func(clientConfig *tls.Config, server string) (PeerConnection, error) {
return &PeerConnectionMock{}, nil
}

21
pkg/conn/test/cacert.pem Normal file
View File

@ -0,0 +1,21 @@
-----BEGIN CERTIFICATE-----
MIIDazCCAlOgAwIBAgIUDRIRI8UnHU2a4znsun0gxFwlrFQwDQYJKoZIhvcNAQEL
BQAwRTELMAkGA1UEBhMCR0IxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMzEwMjcxNTIzMDZaFw0yNDEw
MjYxNTIzMDZaMEUxCzAJBgNVBAYTAkdCMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw
HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB
AQUAA4IBDwAwggEKAoIBAQDJ5hOmzilimA/zM5hYP7CQf4iRmICtSbVLgt6/rTDP
p3JsGGQWZ4pZNofzGnGa7aEMoXS2Ztl7GzZbr1p4+rd6MBbVt8XZ/hP+X4zasCXi
/YubG0TYyBuAt+JrcYb0cbsTBkMXXnFcNIXDfeYFsNq+pfyJwq2ElMUUZ6SQmVhH
ovn1Wk9Fv4t2GJMhmUcObrSIoYdgo4Vf9CfQnn0PCaRf+RjspY/Kz33oyqDI6xJx
I0rfJR7f9B6ZKosfAkt4oTTfT9P8w/d1I95oBENhDkalgkdJCuNJ/AwKGxZrYf/P
aefcc91HheauObjBYPFrSn6bUj3LMJEfj4IeBK+fOZCfAgMBAAGjUzBRMB0GA1Ud
DgQWBBSpcF7jtpd9n73VM3xhPmI1GMEkFjAfBgNVHSMEGDAWgBSpcF7jtpd9n73V
M3xhPmI1GMEkFjAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQCK
GplAveP9nVo9zmg+/mkDpyVoo5rp64oJh4DFtm6X+EI31FmH6Cb71Kn2ZzXhQvSq
qrP7+VoGeBDxk4guJtAs/fhnuDupJG2SjsctjiFnDbSrJjWJjGhC0kuL0wcjLU5G
qUpCEJu13GkDlYHKKw0z+oLUOw+OHmvE5/sD23sKl2KxBWKItx0hwSCkGtm0RQld
8mfjOsHqJ2V/FOcHK6X2DSV1728PAhu4l/PRSB0drBA+7kdeCuWIRZw5RA/OyxvU
CuC5dfUh75MrK7KL6sZsXklsoXo8BZp4rRRUt/v1D3r/SMBJPULSGXh6QDjXQX1D
km71c3DEDyKznHTpGxPt
-----END CERTIFICATE-----

View File

@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDJ5hOmzilimA/z
M5hYP7CQf4iRmICtSbVLgt6/rTDPp3JsGGQWZ4pZNofzGnGa7aEMoXS2Ztl7GzZb
r1p4+rd6MBbVt8XZ/hP+X4zasCXi/YubG0TYyBuAt+JrcYb0cbsTBkMXXnFcNIXD
feYFsNq+pfyJwq2ElMUUZ6SQmVhHovn1Wk9Fv4t2GJMhmUcObrSIoYdgo4Vf9CfQ
nn0PCaRf+RjspY/Kz33oyqDI6xJxI0rfJR7f9B6ZKosfAkt4oTTfT9P8w/d1I95o
BENhDkalgkdJCuNJ/AwKGxZrYf/Paefcc91HheauObjBYPFrSn6bUj3LMJEfj4Ie
BK+fOZCfAgMBAAECggEADqAjoUxC9Dj2wtPkf9QRSs5qSr3E6Iiz4OX4k+MMa6aC
I/F6YqMagw7vtz0dqK75ISybA1GdBI16mRaxU5056FiOdunqo7mDokQytG7ZN8HN
OK23hYqtb1wiw0zEjXWlqyGjf5BgXuERJZG7tYLTvcbRbftTzYxnYGyHn8/z9LBp
GsTJ5X8XMLM5+bTvg1Ovv5s0q31FCeqAuw+auHH4pBNP+ylV6dF5XOWq4HO3TJ2b
grHxWB94JZChZnDC/K+HxQ6aHJfbZ5XCoXfIaIVkoXfnyPzgjvgK+/IpHEF8f/3I
uT/NBiArTpRl29pX5flEO4R121VaW93eM1tuzL32VQKBgQD6Trctx9SYuhzgfiO7
kdefvR43Kl9SFyEw3hN3HW1cxSNGCCFotjmdem+QdtMBtUd27UJ9tuiKJC0lcCER
t3WRz4kVd/cb0eC1DPzpGHA81o1rUUR3nMr1o7aBfvQ06VAxFUrFAOPpF8nD7tI4
0CiOh7/sL1ElThA3bOPUpXkYHQKBgQDOfYbP8dppIkC8pRTnHWe0qUY0G4YXxg7r
UtTo4GYOLJeKH/MKoK8MjBDS5VN5n5TAHJ8yUVzhpWXZIPIGzNEhIRDMa56sRPgI
9mLJNs5z/ZIxd/7ZQbDHrD4T3PKeTjzVUtjXrhLowokPlPB/RMQL6ZT+qMao+3bS
fDITSfLG6wKBgBpbcZSDh1JxvpqxDagxqkfqzSS39IObZeZUbC5NzfdH1vgH4SS6
k4SOoPLQYFW8tgLC5w5/1Sq+tnZLwV+xNtMczG2TTVUDm6rU7EjLRv5RBWE4lIIX
45NMIuqt6J8ttkEE4fOurVEdLSTRoBdVa//eMYp4TQ4lkzWS5Ma+ierNAoGAYO3z
1rFFQYzerq8ffM4E3H2JgvRYodhLMJQVdavAvG6aRDBzOk3rXgxx6U3VPYZ3oSbO
ZCRlYVbu1FnuwtpqYQ7Qf+UU+vD1Ld/ax3F+wFwLwET/0KRRg6mLCm/xQ/ad/9WA
DN6d6b1H8ZSMwHFbRexEELbRaomAYZYDO6K+4DkCgYEAv5De85hPnWtAvKhPzwQi
9mtyWo/cfQgtwL8IKNu6hBHl5RXDpPgX/+pNbXLJfBPwVR3H62x1CMYJDkWVuE6/
ZjtF7FSucZMz/mR6r1GhSOXy3YLwQ6JLPjjKzvnEjahGlKwALJNL0O2ZucjsZxHE
PM4rmhRZT9opiapiltEhRm0=
-----END PRIVATE KEY-----

19
pkg/conn/test/cert.pem Normal file
View File

@ -0,0 +1,19 @@
-----BEGIN CERTIFICATE-----
MIIDCjCCAfICFB/Vd2eOXWdNdrakThJhFIRtZmhUMA0GCSqGSIb3DQEBCwUAMEUx
CzAJBgNVBAYTAkdCMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRl
cm5ldCBXaWRnaXRzIFB0eSBMdGQwHhcNMjMxMDI3MTUzNDM1WhcNMjMxMTI2MTUz
NDM1WjA+MQswCQYDVQQGEwJHQjENMAsGA1UECAwERmlmZTENMAsGA1UEBwwEY2l0
eTERMA8GA1UECgwITWVzaCBMdGQwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEK
AoIBAQDVgcLtNU5AYfPML/mE5PyC7YYKvZn2mt6vEiJ7M/6EzYeTXFeYexD5ZqHg
ewGEd1fwiQWQsATsWd+EM4OnCAXAaNiOH6gGY7FR8CThfT+k8yIGPrl1BovzHHYS
Orekna17UFeIyFMHDPIjl4d2WiJPvmNn5PhLEppPHPBWPhl3J3sMrSbqyRuYbtta
oFIzN8mFcikixLg0SnBPtwlLC72ah9G+MF5CwEcU/E0bYbLQZXv+WhG5aw5JEzes
K2GLxVNgM0xXB7hSyLoX1wBc8DdQyLCMkOp55Hl04UKTxtVE82MiuAOVqMUuKFjR
u2a1C+/Gbk/PS5SHgenGjdZ8sZGpAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAHMc
jIFG5Rn9KaVmo7E+/UAq+3ld/3y2yMHg5wq7oG8b7/z0mlSGErHdFMzo75AFLN4r
kOuiF5ItF6dRLNrG8IUFSNMGVH3b3ukw1EI8E89L8ak3CM+wpLT6GVP3BfV8ah+X
4RRix40Tmx4C81l+Lf5W10rHIdlXBCanJy/Fa0ae+S+oXFc9jeXHlK9qlgszrECT
Pa3VCR95LAIc6o9pDL2Z8tpEkSbyzvIWhp53fnC80PyXpSsFMfIw657shagBc/Ov
e7/aPpPf3V3CafJlEIraQp24MDI5ZM59lT5vhRq2AC50gelL6UPV16mVVUlGVhWE
vYyejod5i5ZbuLFOy2g=
-----END CERTIFICATE-----

28
pkg/conn/test/priv.pem Normal file
View File

@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDVgcLtNU5AYfPM
L/mE5PyC7YYKvZn2mt6vEiJ7M/6EzYeTXFeYexD5ZqHgewGEd1fwiQWQsATsWd+E
M4OnCAXAaNiOH6gGY7FR8CThfT+k8yIGPrl1BovzHHYSOrekna17UFeIyFMHDPIj
l4d2WiJPvmNn5PhLEppPHPBWPhl3J3sMrSbqyRuYbttaoFIzN8mFcikixLg0SnBP
twlLC72ah9G+MF5CwEcU/E0bYbLQZXv+WhG5aw5JEzesK2GLxVNgM0xXB7hSyLoX
1wBc8DdQyLCMkOp55Hl04UKTxtVE82MiuAOVqMUuKFjRu2a1C+/Gbk/PS5SHgenG
jdZ8sZGpAgMBAAECggEARJNAggLYhtpPPVp9WJ9ZsU3L+0AppujYND/tXkf1bD89
V+nVYq7IZWp+/MRVWPAiCSphZLb8ZdN59JK9KtVrT4D9aSymwaKcjfZFSj15xyem
Wn4j///hzGxsSe+dE1znnw9PhindbQrN7Pua8TsDATzj3bdPvoETmexwDysz765i
u4zXvxP+xAessz1OYa5IUaDXdlWOf0e1zNXWwanjRggzCeWR3lTofG49GX087oVC
Sb9ASy+AScnOlwpTdQ8sKy1r9gXmE5ey4AULVb0nJ8LDvrCoBBhKBtVE5mJHepE6
bdC9l6poL6roGvHfMAo3SmiUUT5XceqUxBtHcyHX3wKBgQD1uh+Dv0PrH3CTW9cF
bwHL1rmQNJrbDzDAaounGBe9mcot1RrBhyQAoGw1no4c+QWDAwYRuBP2+Rp6JLU/
XnEXSyN85rJN6LajlrLEr+BNmKw6ghNsnAFUZBLaJ7epRi6OjACUwmtvH6hRIef8
aMg4WiOyDT+Z4Xe81pdXb91HXwKBgQDebs3idgVEau3LCKGYnqvmUhzv8iiQiJmD
R29o2G5Xrahf3r1O5gJdGLO1DaCBtdrI7J4xUOlM935KaEYFe5B7RVGXg23tNWgb
2M+YQqu5qz61bDxhg7dGkegHrdvKNcSkV6GUSm5w9rdxJlY8+l45p/7QpSkatcbd
IRiVzMNr9wKBgQC/+Z5fbpFgYxqvdaPicdxkZShqOj71f8OlwFfEvrTlgv4KmqAh
rDP7bVm89leu2PpuZXFbbIXkgK8n1//mNyGBgkmCbjXFWlc+LSETOxixZuK/fxov
0x3S0bBM0ZTSYatD4KsfjVkj4wa8BBJbB33NUNbsZx9WWGkUlk58mD+3XwKBgQDV
mgR+n6WJQUIfwqckH+Ol517AkYSg33zEE9qKDaVQ74QMpKKY3MqSSkFw8agcR93V
K1zysOeJsPYHUEFFzJY/up6S6HSs4aebbkZUylmMkEVFBa6qWkmrLDxs+2lgsuem
hjy1YhDSzCn3L8CLCEdqCMjr5l8ltkBFZB3u5NcZmwKBgHE9ODedQm783JfvDNBb
lB/IoUjMhMR0J2vHC3zxgTU4nIK+MR0vXvA7fmZebpaQNwYrHY9gvrL0/QevOrmG
PtXlkQ9GITMxTlqfHWV5jXZuRBIGTqh1QW3tKbVAhUhNlM0XDNBmBvjKIFjxUIo3
zMRw/o4R4cIaazyVxguZbsa2
-----END PRIVATE KEY-----

84
pkg/conn/window.go Normal file
View File

@ -0,0 +1,84 @@
package conn
import (
"errors"
"slices"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
// ConnectionWindow maintains a sliding window of connections between users
type ConnectionWindow interface {
// GetWindow is a list of connections to choose from
GetWindow() []string
// SlideConnection removes a node from the window and adds a random node
// not already in the window. connList represents the list of possible
// connections to choose from
SlideConnection(connList []string) error
// PushConneciton is used when connection list less than window size.
PutConnection(conn []string) error
// IsFull returns true if the window is full. In which case we must slide the window
IsFull() bool
}
type ConnectionWindowImpl struct {
window []string
windowSize int
}
// GetWindow gets the current list of active connections in
// the window
func (c *ConnectionWindowImpl) GetWindow() []string {
return c.window
}
// SlideConnection slides the connection window by one shuffling items
// in the windows
func (c *ConnectionWindowImpl) SlideConnection(connList []string) error {
// If the number of peer connections is less than the length of the window
// then exit early. Can't slide the window it should contain all nodes!
if len(c.window) < c.windowSize {
return nil
}
filter := func(node string) bool {
return !slices.Contains(c.window, node)
}
pool := lib.Filter(connList, filter)
newNode := lib.RandomSubsetOfLength(pool, 1)
if len(newNode) == 0 {
return errors.New("could not slide window")
}
for i := len(c.window) - 1; i >= 1; i-- {
c.window[i] = c.window[i-1]
}
c.window[0] = newNode[0]
return nil
}
// PutConnection put random connections in the connection
func (c *ConnectionWindowImpl) PutConnection(connList []string) error {
if len(c.window) >= c.windowSize {
return errors.New("cannot place connection. Window full need to slide")
}
c.window = lib.RandomSubsetOfLength(connList, c.windowSize)
return nil
}
func (c *ConnectionWindowImpl) IsFull() bool {
return len(c.window) >= c.windowSize
}
func NewConnectionWindow(windowLength int) ConnectionWindow {
window := &ConnectionWindowImpl{
window: make([]string, 0),
windowSize: windowLength,
}
return window
}

514
pkg/crdt/datastore.go Normal file
View File

@ -0,0 +1,514 @@
package crdt
import (
"bytes"
"encoding/gob"
"fmt"
"net"
"slices"
"strings"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type Route struct {
Destination string
Path []string
}
// GetDestination implements mesh.Route.
func (r *Route) GetDestination() *net.IPNet {
_, ipnet, _ := net.ParseCIDR(r.Destination)
return ipnet
}
// GetHopCount implements mesh.Route.
func (r *Route) GetHopCount() int {
return len(r.Path)
}
// GetPath implements mesh.Route.
func (r *Route) GetPath() []string {
return r.Path
}
type MeshNode struct {
HostEndpoint string
WgEndpoint string
PublicKey string
WgHost string
Timestamp int64
Routes map[string]Route
Alias string
Description string
Services map[string]string
Type string
Tombstone bool
}
// Mark: marks the node is unreachable. This is not broadcast on
// syncrhonisation
func (m *TwoPhaseStoreMeshManager) Mark(nodeId string) {
m.store.Mark(nodeId)
}
// GetHostEndpoint: gets the gRPC endpoint of the node
func (n *MeshNode) GetHostEndpoint() string {
return n.HostEndpoint
}
// GetPublicKey: gets the public key of the node
func (n *MeshNode) GetPublicKey() (wgtypes.Key, error) {
return wgtypes.ParseKey(n.PublicKey)
}
// GetWgEndpoint(): get IP and port of the wireguard endpoint
func (n *MeshNode) GetWgEndpoint() string {
return n.WgEndpoint
}
// GetWgHost: get the IP address of the WireGuard node
func (n *MeshNode) GetWgHost() *net.IPNet {
_, ipnet, _ := net.ParseCIDR(n.WgHost)
return ipnet
}
// GetTimestamp: get the UNIX time stamp of the ndoe
func (n *MeshNode) GetTimeStamp() int64 {
return n.Timestamp
}
// GetRoutes: returns the routes that the nodes provides
func (n *MeshNode) GetRoutes() []mesh.Route {
routes := make([]mesh.Route, len(n.Routes))
for index, route := range lib.MapValues(n.Routes) {
routes[index] = &Route{
Destination: route.Destination,
Path: route.Path,
}
}
return routes
}
// GetIdentifier: returns the identifier of the node
func (m *MeshNode) GetIdentifier() string {
ipv6 := m.WgHost[:len(m.WgHost)-4]
constituents := strings.Split(ipv6, ":")
constituents = constituents[4:]
return strings.Join(constituents, ":")
}
// GetDescription: returns the description for this node
func (n *MeshNode) GetDescription() string {
return n.Description
}
// GetAlias: associates the node with an alias. Potentially used
// for DNS and so forth.
func (n *MeshNode) GetAlias() string {
return n.Alias
}
// GetServices: returns a list of services offered by the node
func (n *MeshNode) GetServices() map[string]string {
return n.Services
}
func (n *MeshNode) GetType() conf.NodeType {
return conf.NodeType(n.Type)
}
type MeshSnapshot struct {
Nodes map[string]MeshNode
}
// GetNodes() returns the nodes in the mesh
func (m *MeshSnapshot) GetNodes() map[string]mesh.MeshNode {
newMap := make(map[string]mesh.MeshNode)
for key, value := range m.Nodes {
newMap[key] = &MeshNode{
HostEndpoint: value.HostEndpoint,
PublicKey: value.PublicKey,
WgHost: value.WgHost,
WgEndpoint: value.WgEndpoint,
Timestamp: value.Timestamp,
Routes: value.Routes,
Alias: value.Alias,
Description: value.Description,
Services: value.Services,
Type: value.Type,
}
}
return newMap
}
type TwoPhaseStoreMeshManager struct {
MeshId string
IfName string
Client *wgctrl.Client
LastClock uint64
conf *conf.WgConfiguration
daemonConf *conf.DaemonConfiguration
store *TwoPhaseMap[string, MeshNode]
}
// AddNode() adds a node to the mesh
func (m *TwoPhaseStoreMeshManager) AddNode(node mesh.MeshNode) {
crdt, ok := node.(*MeshNode)
if !ok {
panic("node must be of type mesh node")
}
crdt.Routes = make(map[string]Route)
crdt.Services = make(map[string]string)
crdt.Timestamp = time.Now().Unix()
m.store.Put(crdt.PublicKey, *crdt)
}
// GetMesh() returns a snapshot of the mesh provided by the mesh provider.
func (m *TwoPhaseStoreMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
nodes := m.store.AsList()
snapshot := make(map[string]MeshNode)
for _, node := range nodes {
snapshot[node.PublicKey] = node
}
return &MeshSnapshot{
Nodes: snapshot,
}, nil
}
// GetMeshId() returns the ID of the mesh network
func (m *TwoPhaseStoreMeshManager) GetMeshId() string {
return m.MeshId
}
// Save() saves the mesh network
func (m *TwoPhaseStoreMeshManager) Save() []byte {
snapshot := m.store.Snapshot()
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
err := enc.Encode(*snapshot)
if err != nil {
logging.Log.WriteInfof(err.Error())
}
return buf.Bytes()
}
// Load() loads a mesh network
func (m *TwoPhaseStoreMeshManager) Load(bs []byte) error {
buf := bytes.NewBuffer(bs)
dec := gob.NewDecoder(buf)
var snapshot TwoPhaseMapSnapshot[string, MeshNode]
err := dec.Decode(&snapshot)
m.store.Merge(snapshot)
return err
}
// GetDevice() get the device corresponding with the mesh
func (m *TwoPhaseStoreMeshManager) GetDevice() (*wgtypes.Device, error) {
dev, err := m.Client.Device(m.IfName)
if err != nil {
return nil, err
}
return dev, nil
}
// HasChanges returns true if we have changes since last time we synced
func (m *TwoPhaseStoreMeshManager) HasChanges() bool {
clockValue := m.store.GetHash()
return clockValue != m.LastClock
}
// Record that we have changes and save the corresponding changes
func (m *TwoPhaseStoreMeshManager) SaveChanges() {
clockValue := m.store.GetHash()
m.LastClock = clockValue
}
// UpdateTimeStamp: update the timestamp of the given node
func (m *TwoPhaseStoreMeshManager) UpdateTimeStamp(nodeId string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
// Sort nodes by their public key
peers := m.GetPeers()
slices.Sort(peers)
if len(peers) == 0 {
return nil
}
peerToUpdate := peers[0]
if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.daemonConf.KeepAliveTime) {
m.store.Mark(peerToUpdate)
if len(peers) < 2 {
return nil
}
peerToUpdate = peers[1]
}
if peerToUpdate != nodeId {
return nil
}
// Refresh causing node to update it's time stamp
node := m.store.Get(nodeId)
node.Timestamp = time.Now().Unix()
m.store.Put(nodeId, node)
return nil
}
// AddRoutes: adds routes to the given node
func (m *TwoPhaseStoreMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
if len(routes) == 0 {
return nil
}
node := m.store.Get(nodeId)
changes := false
for _, route := range routes {
prevRoute, ok := node.Routes[route.GetDestination().String()]
if !ok || route.GetHopCount() < prevRoute.GetHopCount() {
changes = true
node.Routes[route.GetDestination().String()] = Route{
Destination: route.GetDestination().String(),
Path: route.GetPath(),
}
}
}
if changes {
m.store.Put(nodeId, node)
}
return nil
}
// DeleteRoutes: deletes the routes from the node
func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
if len(routes) == 0 {
return nil
}
node := m.store.Get(nodeId)
changes := false
for _, route := range routes {
changes = true
delete(node.Routes, route.GetDestination().String())
}
if changes {
m.store.Put(nodeId, node)
}
return nil
}
// GetSyncer: returns the automerge syncer for sync
func (m *TwoPhaseStoreMeshManager) GetSyncer() mesh.MeshSyncer {
return NewTwoPhaseSyncer(m)
}
// GetNode get a particular not within the mesh
func (m *TwoPhaseStoreMeshManager) GetNode(nodeId string) (mesh.MeshNode, error) {
if !m.store.Contains(nodeId) {
return nil, fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
node := m.store.Get(nodeId)
return &node, nil
}
// NodeExists: returns true if a particular node exists false otherwise
func (m *TwoPhaseStoreMeshManager) NodeExists(nodeId string) bool {
return m.store.Contains(nodeId)
}
// SetDescription: sets the description of this automerge data type
func (m *TwoPhaseStoreMeshManager) SetDescription(nodeId string, description string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
node := m.store.Get(nodeId)
node.Description = description
m.store.Put(nodeId, node)
return nil
}
// SetAlias: set the alias of the nodeId
func (m *TwoPhaseStoreMeshManager) SetAlias(nodeId string, alias string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
node := m.store.Get(nodeId)
node.Alias = alias
m.store.Put(nodeId, node)
return nil
}
// AddService: adds the service to the given node
func (m *TwoPhaseStoreMeshManager) AddService(nodeId string, key string, value string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
node := m.store.Get(nodeId)
node.Services[key] = value
m.store.Put(nodeId, node)
return nil
}
// RemoveService: removes the service form the node. throws an error if the service does not exist
func (m *TwoPhaseStoreMeshManager) RemoveService(nodeId string, key string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
node := m.store.Get(nodeId)
delete(node.Services, key)
m.store.Put(nodeId, node)
return nil
}
// Prune: prunes all nodes that have not updated their timestamp in
func (m *TwoPhaseStoreMeshManager) Prune() error {
m.store.Prune()
return nil
}
// GetPeers: get a list of contactable peers
func (m *TwoPhaseStoreMeshManager) GetPeers() []string {
nodes := m.store.AsList()
nodes = lib.Filter(nodes, func(mn MeshNode) bool {
if mn.Type != string(conf.PEER_ROLE) {
return false
}
// If the node is marked as unreachable don't consider it a peer.
// this help to optimize convergence time for unreachable nodes.
// However advertising it to other nodes could result in flapping.
if m.store.IsMarked(mn.PublicKey) {
return false
}
return true
})
return lib.Map(nodes, func(mn MeshNode) string {
return mn.PublicKey
})
}
func (m *TwoPhaseStoreMeshManager) getRoutes(targetNode string) (map[string]Route, error) {
if !m.store.Contains(targetNode) {
return nil, fmt.Errorf("getRoute: cannot get route %s does not exist", targetNode)
}
node := m.store.Get(targetNode)
return node.Routes, nil
}
// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen
func (m *TwoPhaseStoreMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, error) {
node, err := m.GetNode(targetNode)
if err != nil {
return nil, err
}
routes := make(map[string]mesh.Route)
// Add routes that the node directly has
for _, route := range node.GetRoutes() {
routes[route.GetDestination().String()] = route
}
// Work out the other routes in the mesh
for _, node := range m.GetPeers() {
nodeRoutes, err := m.getRoutes(node)
if err != nil {
return nil, err
}
for _, route := range nodeRoutes {
otherRoute, ok := routes[route.GetDestination().String()]
hopCount := route.GetHopCount()
if node != targetNode {
hopCount += 1
}
if !ok || route.GetHopCount()+1 < otherRoute.GetHopCount() {
routes[route.GetDestination().String()] = &Route{
Destination: route.GetDestination().String(),
Path: append(route.GetPath(), m.GetMeshId()),
}
}
}
}
return routes, nil
}
// RemoveNode(): remove the node from the mesh
func (m *TwoPhaseStoreMeshManager) RemoveNode(nodeId string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
m.store.Remove(nodeId)
return nil
}
// GetConfiguration implements mesh.MeshProvider.
func (m *TwoPhaseStoreMeshManager) GetConfiguration() *conf.WgConfiguration {
return m.conf
}

83
pkg/crdt/factory.go Normal file
View File

@ -0,0 +1,83 @@
package crdt
import (
"fmt"
"hash/fnv"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
)
type TwoPhaseMapFactory struct {
Config *conf.DaemonConfiguration
}
func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) {
return &TwoPhaseStoreMeshManager{
MeshId: params.MeshId,
IfName: params.DevName,
Client: params.Client,
conf: params.Conf,
daemonConf: params.DaemonConf,
store: NewTwoPhaseMap[string, MeshNode](params.NodeID, func(s string) uint64 {
h := fnv.New64a()
h.Write([]byte(s))
return h.Sum64()
}, uint64(3*f.Config.KeepAliveTime)),
}, nil
}
type MeshNodeFactory struct {
Config conf.DaemonConfiguration
}
func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode {
hostName := f.getAddress(params)
grpcEndpoint := fmt.Sprintf("%s:%d", hostName, f.Config.GrpcPort)
wgEndpoint := fmt.Sprintf("%s:%d", hostName, params.WgPort)
if *params.MeshConfig.Role == conf.CLIENT_ROLE {
grpcEndpoint = "-"
wgEndpoint = "-"
}
return &MeshNode{
HostEndpoint: grpcEndpoint,
PublicKey: params.PublicKey.String(),
WgEndpoint: wgEndpoint,
WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()),
Routes: make(map[string]Route),
Description: "",
Alias: "",
Type: string(*params.MeshConfig.Role),
}
}
// getAddress returns the routable address of the machine.
func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string {
var hostName string = ""
if params.Endpoint != "" {
hostName = params.Endpoint
} else if params.MeshConfig.Endpoint != nil && len(*params.MeshConfig.Endpoint) != 0 {
hostName = *params.MeshConfig.Endpoint
} else {
ipFunc := lib.GetPublicIP
if *params.MeshConfig.IPDiscovery == conf.DNS_IP_DISCOVERY {
ipFunc = lib.GetOutboundIP
}
ip, err := ipFunc()
if err != nil {
return ""
}
hostName = ip.String()
}
return hostName
}

176
pkg/crdt/g_map.go Normal file
View File

@ -0,0 +1,176 @@
// crdt is a golang implementation of a crdt
package crdt
import (
"cmp"
"sync"
)
type Bucket[D any] struct {
Vector uint64
Contents D
Gravestone bool
}
// GMap is a set that can only grow in size
type GMap[K cmp.Ordered, D any] struct {
lock sync.RWMutex
contents map[uint64]Bucket[D]
clock *VectorClock[K]
}
func (g *GMap[K, D]) Put(key K, value D) {
g.lock.Lock()
clock := g.clock.IncrementClock()
g.contents[g.clock.hashFunc(key)] = Bucket[D]{
Vector: clock,
Contents: value,
}
g.lock.Unlock()
}
func (g *GMap[K, D]) Contains(key K) bool {
return g.contains(g.clock.hashFunc(key))
}
func (g *GMap[K, D]) contains(key uint64) bool {
g.lock.RLock()
_, ok := g.contents[key]
g.lock.RUnlock()
return ok
}
func (g *GMap[K, D]) put(key uint64, b Bucket[D]) {
g.lock.Lock()
if g.contents[key].Vector < b.Vector {
g.contents[key] = b
}
g.lock.Unlock()
}
func (g *GMap[K, D]) get(key uint64) Bucket[D] {
g.lock.RLock()
bucket := g.contents[key]
g.lock.RUnlock()
return bucket
}
func (g *GMap[K, D]) Get(key K) D {
return g.get(g.clock.hashFunc(key)).Contents
}
func (g *GMap[K, D]) Mark(key K) {
g.lock.Lock()
bucket := g.contents[g.clock.hashFunc(key)]
bucket.Gravestone = true
g.contents[g.clock.hashFunc(key)] = bucket
g.lock.Unlock()
}
// IsMarked: returns true if the node is marked
func (g *GMap[K, D]) IsMarked(key K) bool {
marked := false
g.lock.RLock()
bucket, ok := g.contents[g.clock.hashFunc(key)]
if ok {
marked = bucket.Gravestone
}
g.lock.RUnlock()
return marked
}
func (g *GMap[K, D]) Keys() []uint64 {
g.lock.RLock()
contents := make([]uint64, len(g.contents))
index := 0
for key := range g.contents {
contents[index] = key
index++
}
g.lock.RUnlock()
return contents
}
func (g *GMap[K, D]) Save() map[uint64]Bucket[D] {
buckets := make(map[uint64]Bucket[D])
g.lock.RLock()
for key, value := range g.contents {
buckets[key] = value
}
g.lock.RUnlock()
return buckets
}
func (g *GMap[K, D]) SaveWithKeys(keys []uint64) map[uint64]Bucket[D] {
buckets := make(map[uint64]Bucket[D])
g.lock.RLock()
for _, key := range keys {
buckets[key] = g.contents[key]
}
g.lock.RUnlock()
return buckets
}
func (g *GMap[K, D]) GetClock() map[uint64]uint64 {
clock := make(map[uint64]uint64)
g.lock.RLock()
for key, bucket := range g.contents {
clock[key] = bucket.Vector
}
g.lock.RUnlock()
return clock
}
func (g *GMap[K, D]) GetHash() uint64 {
hash := uint64(0)
g.lock.RLock()
for _, value := range g.contents {
hash += value.Vector
}
g.lock.RUnlock()
return hash
}
func (g *GMap[K, D]) Prune() {
stale := g.clock.getStale()
g.lock.Lock()
for _, outlier := range stale {
delete(g.contents, outlier)
}
g.lock.Unlock()
}
func NewGMap[K cmp.Ordered, D any](clock *VectorClock[K]) *GMap[K, D] {
return &GMap[K, D]{
contents: make(map[uint64]Bucket[D]),
clock: clock,
}
}

211
pkg/crdt/two_phase_map.go Normal file
View File

@ -0,0 +1,211 @@
package crdt
import (
"cmp"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
type TwoPhaseMap[K cmp.Ordered, D any] struct {
addMap *GMap[K, D]
removeMap *GMap[K, bool]
Clock *VectorClock[K]
processId K
}
type TwoPhaseMapSnapshot[K cmp.Ordered, D any] struct {
Add map[uint64]Bucket[D]
Remove map[uint64]Bucket[bool]
}
// Contains checks whether the value exists in the map
func (m *TwoPhaseMap[K, D]) Contains(key K) bool {
return m.contains(m.Clock.hashFunc(key))
}
// Contains checks whether the value exists in the map
func (m *TwoPhaseMap[K, D]) contains(key uint64) bool {
if !m.addMap.contains(key) {
return false
}
addValue := m.addMap.get(key)
if !m.removeMap.contains(key) {
return true
}
removeValue := m.removeMap.get(key)
return addValue.Vector >= removeValue.Vector
}
func (m *TwoPhaseMap[K, D]) Get(key K) D {
var result D
if !m.Contains(key) {
return result
}
return m.addMap.Get(key)
}
func (m *TwoPhaseMap[K, D]) get(key uint64) D {
var result D
if !m.contains(key) {
return result
}
return m.addMap.get(key).Contents
}
// Put places the key K in the map
func (m *TwoPhaseMap[K, D]) Put(key K, data D) {
msgSequence := m.Clock.IncrementClock()
m.Clock.Put(key, msgSequence)
m.addMap.Put(key, data)
}
func (m *TwoPhaseMap[K, D]) Mark(key K) {
m.addMap.Mark(key)
}
// Remove removes the value from the map
func (m *TwoPhaseMap[K, D]) Remove(key K) {
m.removeMap.Put(key, true)
}
func (m *TwoPhaseMap[K, D]) keys() []uint64 {
keys := make([]uint64, 0)
addKeys := m.addMap.Keys()
for _, key := range addKeys {
if !m.contains(key) {
continue
}
keys = append(keys, key)
}
return keys
}
func (m *TwoPhaseMap[K, D]) AsList() []D {
theList := make([]D, 0)
keys := m.keys()
for _, key := range keys {
theList = append(theList, m.get(key))
}
return theList
}
func (m *TwoPhaseMap[K, D]) Snapshot() *TwoPhaseMapSnapshot[K, D] {
return &TwoPhaseMapSnapshot[K, D]{
Add: m.addMap.Save(),
Remove: m.removeMap.Save(),
}
}
func (m *TwoPhaseMap[K, D]) SnapShotFromState(state *TwoPhaseMapState[K]) *TwoPhaseMapSnapshot[K, D] {
addKeys := lib.MapKeys(state.AddContents)
removeKeys := lib.MapKeys(state.RemoveContents)
return &TwoPhaseMapSnapshot[K, D]{
Add: m.addMap.SaveWithKeys(addKeys),
Remove: m.removeMap.SaveWithKeys(removeKeys),
}
}
type TwoPhaseMapState[K cmp.Ordered] struct {
Vectors map[uint64]uint64
AddContents map[uint64]uint64
RemoveContents map[uint64]uint64
}
func (m *TwoPhaseMap[K, D]) IsMarked(key K) bool {
return m.addMap.IsMarked(key)
}
// GetHash: Get the hash of the current state of the map
// Sums the current values of the vectors. Provides good approximation
// of increasing numbers
func (m *TwoPhaseMap[K, D]) GetHash() uint64 {
return (m.addMap.GetHash() + 1) * (m.removeMap.GetHash() + 1)
}
// GetState: get the current vector clock of the add and remove
// map
func (m *TwoPhaseMap[K, D]) GenerateMessage() *TwoPhaseMapState[K] {
addContents := m.addMap.GetClock()
removeContents := m.removeMap.GetClock()
return &TwoPhaseMapState[K]{
Vectors: m.Clock.GetClock(),
AddContents: addContents,
RemoveContents: removeContents,
}
}
func (m *TwoPhaseMapState[K]) Difference(highestStale uint64, state *TwoPhaseMapState[K]) *TwoPhaseMapState[K] {
mapState := &TwoPhaseMapState[K]{
AddContents: make(map[uint64]uint64),
RemoveContents: make(map[uint64]uint64),
}
for key, value := range state.AddContents {
otherValue, ok := m.AddContents[key]
if value > highestStale && (!ok || otherValue < value) {
mapState.AddContents[key] = value
}
}
for key, value := range state.RemoveContents {
otherValue, ok := m.RemoveContents[key]
if value > highestStale && (!ok || otherValue < value) {
mapState.RemoveContents[key] = value
}
}
return mapState
}
func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) {
for key, value := range snapshot.Add {
// Gravestone is local only to that node.
// Discover ourselves if the node is alive
m.addMap.put(key, value)
m.Clock.put(key, value.Vector)
}
for key, value := range snapshot.Remove {
m.removeMap.put(key, value)
m.Clock.put(key, value.Vector)
}
}
func (m *TwoPhaseMap[K, D]) Prune() {
m.addMap.Prune()
m.removeMap.Prune()
m.Clock.Prune()
}
// NewTwoPhaseMap: create a new two phase map. Consists of two maps
// a grow map and a remove map. If both timestamps equal then favour keeping
// it in the map
func NewTwoPhaseMap[K cmp.Ordered, D any](processId K, hashKey func(K) uint64, staleTime uint64) *TwoPhaseMap[K, D] {
m := TwoPhaseMap[K, D]{
processId: processId,
Clock: NewVectorClock(processId, hashKey, staleTime),
}
m.addMap = NewGMap[K, D](m.Clock)
m.removeMap = NewGMap[K, bool](m.Clock)
return &m
}

View File

@ -0,0 +1,190 @@
package crdt
import (
"bytes"
"encoding/gob"
logging "github.com/tim-beatham/wgmesh/pkg/log"
)
type SyncState int
const (
HASH SyncState = iota
PREPARE
PRESENT
EXCHANGE
MERGE
FINISHED
)
// TwoPhaseSyncer is a type to sync a TwoPhase data store
type TwoPhaseSyncer struct {
manager *TwoPhaseStoreMeshManager
generateMessageFSM SyncFSM
state SyncState
mapState *TwoPhaseMapState[string]
peerMsg []byte
}
type TwoPhaseHash struct {
Hash uint64
}
type SyncFSM map[SyncState]func(*TwoPhaseSyncer) ([]byte, bool)
func hash(syncer *TwoPhaseSyncer) ([]byte, bool) {
hash := TwoPhaseHash{
Hash: syncer.manager.store.Clock.GetHash(),
}
var buffer bytes.Buffer
enc := gob.NewEncoder(&buffer)
err := enc.Encode(hash)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
syncer.IncrementState()
return buffer.Bytes(), true
}
func prepare(syncer *TwoPhaseSyncer) ([]byte, bool) {
var recvBuffer = bytes.NewBuffer(syncer.peerMsg)
dec := gob.NewDecoder(recvBuffer)
var hash TwoPhaseHash
err := dec.Decode(&hash)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
// If vector clocks are equal then no need to merge state
// Helps to reduce bandwidth by detecting early
if hash.Hash == syncer.manager.store.Clock.GetHash() {
return nil, false
}
// Increment the clock here so the clock gets
// distributed to everyone else in the mesh
syncer.manager.store.Clock.IncrementClock()
var buffer bytes.Buffer
enc := gob.NewEncoder(&buffer)
mapState := syncer.manager.store.GenerateMessage()
syncer.mapState = mapState
err = enc.Encode(*syncer.mapState)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
syncer.IncrementState()
return buffer.Bytes(), true
}
func present(syncer *TwoPhaseSyncer) ([]byte, bool) {
if syncer.peerMsg == nil {
panic("peer msg is nil")
}
var recvBuffer = bytes.NewBuffer(syncer.peerMsg)
dec := gob.NewDecoder(recvBuffer)
var mapState TwoPhaseMapState[string]
err := dec.Decode(&mapState)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
difference := syncer.mapState.Difference(syncer.manager.store.Clock.GetStaleCount(), &mapState)
syncer.manager.store.Clock.Merge(mapState.Vectors)
var sendBuffer bytes.Buffer
enc := gob.NewEncoder(&sendBuffer)
enc.Encode(*difference)
syncer.IncrementState()
return sendBuffer.Bytes(), true
}
func exchange(syncer *TwoPhaseSyncer) ([]byte, bool) {
if syncer.peerMsg == nil {
panic("peer msg is nil")
}
var recvBuffer = bytes.NewBuffer(syncer.peerMsg)
dec := gob.NewDecoder(recvBuffer)
var mapState TwoPhaseMapState[string]
dec.Decode(&mapState)
snapshot := syncer.manager.store.SnapShotFromState(&mapState)
var sendBuffer bytes.Buffer
enc := gob.NewEncoder(&sendBuffer)
enc.Encode(*snapshot)
syncer.IncrementState()
return sendBuffer.Bytes(), true
}
func merge(syncer *TwoPhaseSyncer) ([]byte, bool) {
if syncer.peerMsg == nil {
panic("peer msg is nil")
}
var recvBuffer = bytes.NewBuffer(syncer.peerMsg)
dec := gob.NewDecoder(recvBuffer)
var snapshot TwoPhaseMapSnapshot[string, MeshNode]
dec.Decode(&snapshot)
syncer.manager.store.Merge(snapshot)
return nil, false
}
func (t *TwoPhaseSyncer) IncrementState() {
t.state = min(t.state+1, FINISHED)
}
func (t *TwoPhaseSyncer) GenerateMessage() ([]byte, bool) {
fsmFunc, ok := t.generateMessageFSM[t.state]
if !ok {
panic("state not handled")
}
return fsmFunc(t)
}
func (t *TwoPhaseSyncer) RecvMessage(msg []byte) error {
t.peerMsg = msg
return nil
}
func (t *TwoPhaseSyncer) Complete() {
logging.Log.WriteInfof("SYNC COMPLETED")
}
func NewTwoPhaseSyncer(manager *TwoPhaseStoreMeshManager) *TwoPhaseSyncer {
var generateMessageFsm SyncFSM = SyncFSM{
HASH: hash,
PREPARE: prepare,
PRESENT: present,
EXCHANGE: exchange,
MERGE: merge,
}
return &TwoPhaseSyncer{
manager: manager,
state: HASH,
generateMessageFSM: generateMessageFsm,
}
}

168
pkg/crdt/vector_clock.go Normal file
View File

@ -0,0 +1,168 @@
package crdt
import (
"cmp"
"sync"
"time"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
type VectorBucket struct {
// clock current value of the node's clock
clock uint64
// lastUpdate we've seen
lastUpdate uint64
}
// Vector clock defines an abstract data type
// for a vector clock implementation
type VectorClock[K cmp.Ordered] struct {
vectors map[uint64]*VectorBucket
lock sync.RWMutex
processID K
staleTime uint64
hashFunc func(K) uint64
// highest update that's been garbage collected
highestStale uint64
}
// IncrementClock: increments the node's value in the vector clock
func (m *VectorClock[K]) IncrementClock() uint64 {
maxClock := uint64(0)
m.lock.Lock()
for _, value := range m.vectors {
maxClock = max(maxClock, value.clock)
}
newBucket := VectorBucket{
clock: maxClock + 1,
lastUpdate: uint64(time.Now().Unix()),
}
m.vectors[m.hashFunc(m.processID)] = &newBucket
m.lock.Unlock()
return maxClock
}
// GetHash: gets the hash of the vector clock used to determine if there
// are any changes
func (m *VectorClock[K]) GetHash() uint64 {
m.lock.RLock()
hash := uint64(0)
for key, bucket := range m.vectors {
hash += key * (bucket.clock + 1)
}
m.lock.RUnlock()
return hash
}
func (m *VectorClock[K]) Merge(vectors map[uint64]uint64) {
for key, value := range vectors {
m.put(key, value)
}
}
// getStale: get all entries that are stale within the mesh
func (m *VectorClock[K]) getStale() []uint64 {
m.lock.RLock()
maxTimeStamp := lib.Reduce(0, lib.MapValues(m.vectors), func(i uint64, vb *VectorBucket) uint64 {
return max(i, vb.lastUpdate)
})
toRemove := make([]uint64, 0)
for key, bucket := range m.vectors {
if maxTimeStamp-bucket.lastUpdate > m.staleTime {
toRemove = append(toRemove, key)
m.highestStale = max(bucket.clock, m.highestStale)
}
}
m.lock.RUnlock()
return toRemove
}
// GetStaleCount: returns a vector clock which is considered to be stale.
// all updates must be greater than this
func (m *VectorClock[K]) GetStaleCount() uint64 {
m.lock.RLock()
staleCount := m.highestStale
m.lock.RUnlock()
return staleCount
}
func (m *VectorClock[K]) Prune() {
stale := m.getStale()
m.lock.Lock()
for _, key := range stale {
delete(m.vectors, key)
}
m.lock.Unlock()
}
func (m *VectorClock[K]) GetTimestamp(processId K) uint64 {
m.lock.RLock()
lastUpdate := m.vectors[m.hashFunc(m.processID)].lastUpdate
m.lock.RUnlock()
return lastUpdate
}
func (m *VectorClock[K]) Put(key K, value uint64) {
m.put(m.hashFunc(key), value)
}
func (m *VectorClock[K]) put(key uint64, value uint64) {
clockValue := uint64(0)
m.lock.Lock()
bucket, ok := m.vectors[key]
if ok {
clockValue = bucket.clock
}
// Make sure that entries that were garbage collected don't get
// addded back
if value > clockValue && value > m.highestStale {
newBucket := VectorBucket{
clock: value,
lastUpdate: uint64(time.Now().Unix()),
}
m.vectors[key] = &newBucket
}
m.lock.Unlock()
}
func (m *VectorClock[K]) GetClock() map[uint64]uint64 {
clock := make(map[uint64]uint64)
m.lock.RLock()
for key, value := range m.vectors {
clock[key] = value.clock
}
m.lock.RUnlock()
return clock
}
func NewVectorClock[K cmp.Ordered](processID K, hashFunc func(K) uint64, staleTime uint64) *VectorClock[K] {
return &VectorClock[K]{
vectors: make(map[uint64]*VectorBucket),
processID: processID,
staleTime: staleTime,
hashFunc: hashFunc,
}
}

View File

@ -1,9 +1,9 @@
package ctrlserver
import (
crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/crdt"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
@ -16,44 +16,52 @@ import (
// NewCtrlServerParams are the params requried to create a new ctrl server
type NewCtrlServerParams struct {
Conf *conf.WgMeshConfiguration
Conf *conf.DaemonConfiguration
Client *wgctrl.Client
AuthProvider rpc.AuthenticationServer
CtrlProvider rpc.MeshCtrlServerServer
SyncProvider rpc.SyncServiceServer
Querier query.Querier
OnDelete func(mesh.MeshProvider)
}
// Create a new instance of the MeshCtrlServer or error if the
// operation failed
func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
ctrlServer := new(MeshCtrlServer)
meshFactory := crdt.CrdtProviderFactory{}
nodeFactory := crdt.MeshNodeFactory{
meshFactory := &crdt.TwoPhaseMapFactory{
Config: params.Conf,
}
nodeFactory := &crdt.MeshNodeFactory{
Config: *params.Conf,
}
idGenerator := &lib.UUIDGenerator{}
idGenerator := &lib.IDNameGenerator{}
ipAllocator := &ip.ULABuilder{}
interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client)
configApplyer := mesh.NewWgMeshConfigApplyer()
meshManagerParams := &mesh.NewMeshManagerParams{
Conf: *params.Conf,
Client: params.Client,
MeshProvider: &meshFactory,
NodeFactory: &nodeFactory,
MeshProvider: meshFactory,
NodeFactory: nodeFactory,
IdGenerator: idGenerator,
IPAllocator: ipAllocator,
InterfaceManipulator: interfaceManipulator,
ConfigApplyer: configApplyer,
OnDelete: params.OnDelete,
}
ctrlServer.MeshManager = mesh.NewMeshManager(meshManagerParams)
configApplyer.SetMeshManager(ctrlServer.MeshManager)
ctrlServer.Conf = params.Conf
connManagerParams := conn.NewConnectionManageParams{
connManagerParams := conn.NewConnectionManagerParams{
CertificatePath: params.Conf.CertificatePath,
PrivateKey: params.Conf.PrivateKeyPath,
SkipCertVerification: params.Conf.SkipCertVerification,
CaCert: params.Conf.CaCertificatePath,
ConnFactory: conn.NewWgCtrlConnection,
}
connMgr, err := conn.NewConnectionManager(&connManagerParams)
@ -65,7 +73,6 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
ctrlServer.ConnectionManager = connMgr
connServerParams := conn.NewConnectionServerParams{
Conf: params.Conf,
AuthProvider: params.AuthProvider,
CtrlProvider: params.CtrlProvider,
SyncProvider: params.SyncProvider,
}
@ -82,12 +89,37 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
return ctrlServer, nil
}
func (s *MeshCtrlServer) GetConfiguration() *conf.DaemonConfiguration {
return s.Conf
}
func (s *MeshCtrlServer) GetClient() *wgctrl.Client {
return s.Client
}
func (s *MeshCtrlServer) GetQuerier() query.Querier {
return s.Querier
}
func (s *MeshCtrlServer) GetMeshManager() mesh.MeshManager {
return s.MeshManager
}
func (s *MeshCtrlServer) GetConnectionManager() conn.ConnectionManager {
return s.ConnectionManager
}
// Close closes the ctrl server tearing down any connections that exist
func (s *MeshCtrlServer) Close() error {
if err := s.ConnectionManager.Close(); err != nil {
logging.Log.WriteErrorf(err.Error())
}
if err := s.MeshManager.Close(); err != nil {
logging.Log.WriteErrorf(err.Error())
}
if err := s.ConnectionServer.Close(); err != nil {
logging.Log.WriteErrorf(err.Error())
}

View File

@ -1,14 +1,31 @@
package ctrlserver
import (
"net"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type MeshRoute struct {
Destination string
Path []string
}
// Represents the WireGuard configuration attached to the node
type WireGuardStats struct {
AllowedIPs []string
TransmitBytes int64
ReceivedBytes int64
PersistentKeepAliveInterval time.Duration
}
// Represents a WireGuard MeshNode
type MeshNode struct {
HostEndpoint string
@ -16,7 +33,11 @@ type MeshNode struct {
PublicKey string
WgHost string
Timestamp int64
Routes []string
Routes []MeshRoute
Description string
Alias string
Services map[string]string
Stats WireGuardStats
}
// Represents a WireGuard Mesh
@ -25,12 +46,71 @@ type Mesh struct {
Nodes map[string]MeshNode
}
type CtrlServer interface {
GetConfiguration() *conf.DaemonConfiguration
GetClient() *wgctrl.Client
GetQuerier() query.Querier
GetMeshManager() mesh.MeshManager
Close() error
GetConnectionManager() conn.ConnectionManager
}
// Represents a ctrlserver to be used in WireGuard
type MeshCtrlServer struct {
Client *wgctrl.Client
MeshManager *mesh.MeshManager
MeshManager mesh.MeshManager
ConnectionManager conn.ConnectionManager
ConnectionServer *conn.ConnectionServer
Conf *conf.WgMeshConfiguration
Conf *conf.DaemonConfiguration
Querier query.Querier
}
// NewCtrlNode create an instance of a ctrl node to send over an
// IPC call
func NewCtrlNode(provider mesh.MeshProvider, node mesh.MeshNode) *MeshNode {
pubKey, _ := node.GetPublicKey()
ctrlNode := MeshNode{
HostEndpoint: node.GetHostEndpoint(),
WgEndpoint: node.GetWgEndpoint(),
PublicKey: pubKey.String(),
WgHost: node.GetWgHost().String(),
Timestamp: node.GetTimeStamp(),
Routes: lib.Map(node.GetRoutes(), func(r mesh.Route) MeshRoute {
return MeshRoute{
Destination: r.GetDestination().String(),
Path: r.GetPath(),
}
}),
Description: node.GetDescription(),
Alias: node.GetAlias(),
Services: node.GetServices(),
}
device, err := provider.GetDevice()
if err != nil {
return &ctrlNode
}
peers := lib.Filter(device.Peers, func(p wgtypes.Peer) bool {
return p.PublicKey.String() == pubKey.String()
})
if len(peers) > 0 {
peer := peers[0]
stats := WireGuardStats{
AllowedIPs: lib.Map(peer.AllowedIPs, func(i net.IPNet) string {
return i.String()
}),
TransmitBytes: peer.TransmitBytes,
ReceivedBytes: peer.ReceiveBytes,
PersistentKeepAliveInterval: peer.PersistentKeepaliveInterval,
}
ctrlNode.Stats = stats
}
return &ctrlNode
}

51
pkg/ctrlserver/stub.go Normal file
View File

@ -0,0 +1,51 @@
package ctrlserver
import (
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query"
"golang.zx2c4.com/wireguard/wgctrl"
)
type CtrlServerStub struct {
manager mesh.MeshManager
querier query.Querier
connectionManager conn.ConnectionManager
}
func NewCtrlServerStub() *CtrlServerStub {
var manager mesh.MeshManager = mesh.NewMeshManagerStub()
return &CtrlServerStub{
manager: manager,
querier: query.NewJmesQuerier(manager),
connectionManager: &conn.ConnectionManagerStub{},
}
}
func (c *CtrlServerStub) GetConfiguration() *conf.DaemonConfiguration {
return &conf.DaemonConfiguration{
GrpcPort: 8080,
BaseConfiguration: conf.WgConfiguration{},
}
}
func (c *CtrlServerStub) GetClient() *wgctrl.Client {
return &wgctrl.Client{}
}
func (c *CtrlServerStub) GetQuerier() query.Querier {
return c.querier
}
func (c *CtrlServerStub) GetMeshManager() mesh.MeshManager {
return c.manager
}
func (c *CtrlServerStub) Close() error {
return nil
}
func (c *CtrlServerStub) GetConnectionManager() conn.ConnectionManager {
return c.connectionManager
}

114
pkg/dns/dns.go Normal file
View File

@ -0,0 +1,114 @@
package smegdns
import (
"encoding/json"
"fmt"
"net"
"net/rpc"
"github.com/miekg/dns"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/query"
)
const SockAddr = "/tmp/wgmesh_ipc.sock"
const MeshRegularExpression = `(?P<meshId>.+)\.(?P<alias>.+)\.smeg\.`
type DNSHandler struct {
client *rpc.Client
server *dns.Server
}
// queryMesh: queries the mesh network for the given meshId and node
// with alias
func (d *DNSHandler) queryMesh(meshId, alias string) net.IP {
var reply string
err := d.client.Call("IpcHandler.Query", &ipc.QueryMesh{
MeshId: meshId,
Query: fmt.Sprintf("[?alias == '%s'] | [0]", alias),
}, &reply)
if err != nil {
return nil
}
var node *query.QueryNode
err = json.Unmarshal([]byte(reply), &node)
if err != nil || node == nil {
return nil
}
ip, _, _ := net.ParseCIDR(node.WgHost)
return ip
}
func (d *DNSHandler) handleQuery(m *dns.Msg) {
for _, q := range m.Question {
switch q.Qtype {
case dns.TypeAAAA:
logging.Log.WriteInfof("Query for %s", q.Name)
groups := lib.MatchCaptureGroup(MeshRegularExpression, q.Name)
if len(groups) == 0 {
continue
}
ip := d.queryMesh(groups["meshId"], groups["alias"])
rr, err := dns.NewRR(fmt.Sprintf("%s AAAA %s", q.Name, ip))
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
if err == nil {
m.Answer = append(m.Answer, rr)
}
}
}
}
func (h *DNSHandler) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
msg := new(dns.Msg)
msg.SetReply(r)
msg.Authoritative = true
switch r.Opcode {
case dns.OpcodeQuery:
h.handleQuery(msg)
}
w.WriteMsg(msg)
}
func (h *DNSHandler) Listen() error {
return h.server.ListenAndServe()
}
func (h *DNSHandler) Close() error {
return h.server.Shutdown()
}
func NewDns(udpPort int) (*DNSHandler, error) {
client, err := rpc.DialHTTP("unix", SockAddr)
if err != nil {
return nil, err
}
dnsHander := DNSHandler{
client: client,
}
dns.HandleFunc("smeg.", dnsHander.handleDnsRequest)
dnsHander.server = &dns.Server{Addr: fmt.Sprintf(":%d", udpPort), Net: "udp"}
return &dnsHander, nil
}

View File

@ -4,13 +4,13 @@ package rpctypes;
option go_package = "pkg/rpc";
service MeshCtrlServer {
rpc JoinMesh(JoinMeshRequest) returns (JoinMeshReply) {}
rpc GetMesh(GetMeshRequest) returns (GetMeshReply) {}
}
message JoinMeshRequest {
string meshId = 2;
message GetMeshRequest {
string meshId = 1;
}
message JoinMeshReply {
bool success = 1;
message GetMeshReply {
bytes mesh = 1;
}

132
pkg/hosts/hosts.go Normal file
View File

@ -0,0 +1,132 @@
// hosts: utility for modifying the /etc/hosts file
package hosts
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"os"
"strings"
)
// HOSTS_FILE is the hosts file location
const HOSTS_FILE = "/etc/hosts"
const DOMAIN_HEADER = "#WG AUTO GENERATED HOSTS"
const DOMAIN_TRAILER = "#WG AUTO GENERATED HOSTS END"
type HostsEntry struct {
Alias string
Ip net.IP
}
// Generic interface to manipulate /etc/hosts file
type HostsManipulator interface {
// AddrAddr associates an aliasd with a given IP address
AddAddr(hosts ...HostsEntry)
// Remove deletes the entry from /etc/hosts
Remove(hosts ...HostsEntry)
// Writes the changes to /etc/hosts file
Write() error
}
type HostsManipulatorImpl struct {
hosts map[string]HostsEntry
}
// AddAddr implements HostsManipulator.
func (m *HostsManipulatorImpl) AddAddr(hosts ...HostsEntry) {
changed := false
for _, host := range hosts {
prev, ok := m.hosts[host.Ip.String()]
if !ok || prev.Alias != host.Alias {
changed = true
}
m.hosts[host.Ip.String()] = host
}
if changed {
m.Write()
}
}
// Remove implements HostsManipulator.
func (m *HostsManipulatorImpl) Remove(hosts ...HostsEntry) {
lenBefore := len(m.hosts)
for _, host := range hosts {
delete(m.hosts, host.Alias)
}
if lenBefore != len(m.hosts) {
m.Write()
}
}
func (m *HostsManipulatorImpl) removeHosts() string {
hostsFile, err := os.ReadFile(HOSTS_FILE)
if err != nil {
return ""
}
var contents strings.Builder
scanner := bufio.NewScanner(bytes.NewReader(hostsFile))
hostsSection := false
for scanner.Scan() {
line := scanner.Text()
if err == io.EOF {
break
} else if err != nil {
return ""
}
if !hostsSection && strings.Contains(line, DOMAIN_HEADER) {
hostsSection = true
}
if !hostsSection {
contents.WriteString(line + "\n")
}
if hostsSection && strings.Contains(line, DOMAIN_TRAILER) {
hostsSection = false
}
}
if scanner.Err() != nil && scanner.Err() != io.EOF {
return ""
}
return contents.String()
}
// Write implements HostsManipulator
func (m *HostsManipulatorImpl) Write() error {
contents := m.removeHosts()
var nextHosts strings.Builder
nextHosts.WriteString(contents)
nextHosts.WriteString(DOMAIN_HEADER + "\n")
for _, host := range m.hosts {
nextHosts.WriteString(fmt.Sprintf("%s\t%s\n", host.Ip.String(), host.Alias))
}
nextHosts.WriteString(DOMAIN_TRAILER + "\n")
return os.WriteFile(HOSTS_FILE, []byte(nextHosts.String()), 0644)
}
func NewHostsManipulator() HostsManipulator {
return &HostsManipulatorImpl{hosts: make(map[string]HostsEntry)}
}

View File

@ -10,14 +10,29 @@ import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
)
type NewMeshArgs struct {
// IfName is the interface that the mesh instance will run on
IfName string
// WireGuardArgs are provided args specific to WireGuard
type WireGuardArgs struct {
// WgPort is the WireGuard port to expose
WgPort int
// KeepAliveWg is the number of seconds to keep alive
// for WireGuard NAT/firewall traversal
KeepAliveWg int
// AdvertiseRoutes whether or not to advertise routes to and from the
// mesh network
AdvertiseRoutes bool
// AdvertiseDefaultRoute whether or not to advertise the default route
// into the mesh network
AdvertiseDefaultRoute bool
// Endpoint is the routable alias of the machine. Can be an IP
// or DNS entry
Endpoint string
// Role is the role of the individual in the mesh
Role string
}
type NewMeshArgs struct {
// WgArgs are specific WireGuard args to use
WgArgs WireGuardArgs
}
type JoinMeshArgs struct {
@ -25,13 +40,13 @@ type JoinMeshArgs struct {
MeshId string
// IpAddress is a routable IP in another mesh
IpAdress string
// IfName is the interface name of the mesh
IfName string
// Port is the WireGuard port to expose
Port int
// Endpoint is the routable address of this machine. If not provided
// defaults to the default address
Endpoint string
// WgArgs is the WireGuard parameters to use.
WgArgs WireGuardArgs
}
type PutServiceArgs struct {
Service string
Value string
}
type GetMeshReply struct {
@ -53,17 +68,19 @@ type MeshIpc interface {
JoinMesh(args JoinMeshArgs, reply *string) error
LeaveMesh(meshId string, reply *string) error
GetMesh(meshId string, reply *GetMeshReply) error
EnableInterface(meshId string, reply *string) error
GetDOT(meshId string, reply *string) error
Query(query QueryMesh, reply *string) error
PutDescription(description string, reply *string) error
PutAlias(alias string, reply *string) error
PutService(args PutServiceArgs, reply *string) error
DeleteService(service string, reply *string) error
}
const SockAddr = "/tmp/wgmesh_ipc.sock"
func RunIpcHandler(server MeshIpc) error {
if err := os.RemoveAll(SockAddr); err != nil {
return errors.New("Could not find to address")
return errors.New("could not find to address")
}
rpc.Register(server)

View File

@ -1,11 +1,34 @@
package lib
import "cmp"
// MapToSlice converts a map to a slice in go
func MapValues[K comparable, V any](m map[K]V) []V {
func MapValues[K cmp.Ordered, V any](m map[K]V) []V {
return MapValuesWithExclude(m, map[K]struct{}{})
}
func MapValuesWithExclude[K comparable, V any](m map[K]V, exclude map[K]struct{}) []V {
type MapItemsEntry[K cmp.Ordered, V any] struct {
Key K
Value V
}
func MapItems[K cmp.Ordered, V any](m map[K]V) []MapItemsEntry[K, V] {
keys := MapKeys(m)
values := MapValues(m)
vs := make([]MapItemsEntry[K, V], len(keys))
for index, _ := range keys {
vs[index] = MapItemsEntry[K, V]{
Key: keys[index],
Value: values[index],
}
}
return vs
}
func MapValuesWithExclude[K cmp.Ordered, V any](m map[K]V, exclude map[K]struct{}) []V {
values := make([]V, len(m)-len(exclude))
i := 0
@ -26,11 +49,11 @@ func MapValuesWithExclude[K comparable, V any](m map[K]V, exclude map[K]struct{}
return values
}
func MapKeys[K comparable, V any](m map[K]V) []K {
func MapKeys[K cmp.Ordered, V any](m map[K]V) []K {
values := make([]K, len(m))
i := 0
for k, _ := range m {
for k := range m {
values[i] = k
i++
}
@ -58,7 +81,7 @@ type filterFunc[V any] func(V) bool
func Filter[V any](list []V, f filterFunc[V]) []V {
newList := make([]V, 0)
for _, elem := range newList {
for _, elem := range list {
if f(elem) {
newList = append(newList, elem)
}
@ -66,3 +89,23 @@ func Filter[V any](list []V, f filterFunc[V]) []V {
return newList
}
func Contains[V any](list []V, proposition func(V) bool) bool {
for _, elem := range list {
if proposition(elem) {
return true
}
}
return false
}
func Reduce[A any, V any](start A, values []V, reduce func(A, V) A) A {
accum := start
for _, elem := range values {
accum = reduce(accum, elem)
}
return accum
}

144
pkg/lib/conv_test.go Normal file
View File

@ -0,0 +1,144 @@
package lib
import (
"slices"
"testing"
)
func stringToInt(input string) int {
return len(input)
}
func intDiv(input int) int {
return input / 2
}
func TestMapValuesMapsValues(t *testing.T) {
values := []int{1, 4, 11, 92}
var theMap map[string]int = map[string]int{
"mynameisjeff": values[0],
"tim": values[1],
"bob": values[2],
"derek": values[3],
}
mapValues := MapValues(theMap)
for _, elem := range mapValues {
if !slices.Contains(values, elem) {
t.Fatalf(`%d is not an expected value`, elem)
}
}
if len(mapValues) != len(theMap) {
t.Fatalf(`Expected length %d got %d`, len(theMap), len(mapValues))
}
}
func TestMapValuesWithExcludeExcludesValues(t *testing.T) {
values := []int{1, 9, 22}
var theMap map[string]int = map[string]int{
"mynameisbob": values[0],
"tim": values[1],
"bob": values[2],
}
exclude := map[string]struct{}{
"tim": {},
}
mapValues := MapValuesWithExclude(theMap, exclude)
if slices.Contains(mapValues, values[1]) {
t.Fatalf(`Failed to exclude expected value`)
}
if len(mapValues) != 2 {
t.Fatalf(`Incorrect expected length`)
}
for _, value := range theMap {
if !slices.Contains(values, value) {
t.Fatalf(`Element does not exist in the list of
expected values`)
}
}
}
func TestMapKeys(t *testing.T) {
keys := []string{"1", "2", "3"}
theMap := map[string]int{
keys[0]: 1,
keys[1]: 2,
keys[2]: 3,
}
mapKeys := MapKeys(theMap)
for _, elem := range mapKeys {
if !slices.Contains(keys, elem) {
t.Fatalf(`%s elem is not an expected key`, elem)
}
}
if len(mapKeys) != len(theMap) {
t.Fatalf(`Missing expected values`)
}
}
func TestMapValues(t *testing.T) {
array := []string{"mynameisjeff", "tim", "bob", "derek"}
intArray := Map(array, stringToInt)
for index, elem := range intArray {
if len(array[index]) != elem {
t.Fatalf(`Have %d want %d`, elem, len(array[index]))
}
}
}
func TestFilterFilterAll(t *testing.T) {
values := []int{1, 2, 3, 4, 5}
filterFunc := func(n int) bool {
return false
}
newValues := Filter(values, filterFunc)
if len(newValues) != 0 {
t.Fatalf(`Expected value was 0`)
}
}
func TestFilterFilterNone(t *testing.T) {
values := []int{1, 2, 3, 4, 5}
filterFunc := func(n int) bool {
return true
}
newValues := Filter(values, filterFunc)
if !slices.Equal(values, newValues) {
t.Fatalf(`Expected lists to be the same`)
}
}
func TestFilterFilterSome(t *testing.T) {
values := []int{1, 2, 3, 4, 5}
filterFunc := func(n int) bool {
return n < 3
}
expected := []int{1, 2}
actual := Filter(values, filterFunc)
if !slices.Equal(expected, actual) {
t.Fatalf(`Expected expected and actual to be the same`)
}
}

46
pkg/lib/hashing.go Normal file
View File

@ -0,0 +1,46 @@
package lib
import (
"hash/fnv"
"sort"
)
type consistentHashRecord[V any] struct {
record V
value int
}
func HashString(value string) int {
f := fnv.New32a()
f.Write([]byte(value))
return int(f.Sum32())
}
// ConsistentHash implementation. Traverse the values until we find a key
// less than ours.
func ConsistentHash[V any, K any](values []V, client K, bucketFunc func(V) int, keyFunc func(K) int) V {
if len(values) == 0 {
panic("values is empty")
}
vs := Map(values, func(v V) consistentHashRecord[V] {
return consistentHashRecord[V]{
v,
bucketFunc(v),
}
})
sort.SliceStable(vs, func(i, j int) bool {
return vs[i].value < vs[j].value
})
ourKey := keyFunc(client)
for _, record := range vs {
if ourKey < record.value {
return record.record
}
}
return vs[0].record
}

View File

@ -1,6 +1,9 @@
package lib
import "github.com/google/uuid"
import (
"github.com/anandvarma/namegen"
"github.com/google/uuid"
)
// IdGenerator generates unique ids
type IdGenerator interface {
@ -15,3 +18,11 @@ func (g *UUIDGenerator) GetId() (string, error) {
id := uuid.New()
return id.String(), nil
}
type IDNameGenerator struct {
}
func (i *IDNameGenerator) GetId() (string, error) {
name_schema := namegen.New()
return name_schema.Get(), nil
}

View File

@ -1,17 +1,61 @@
package lib
import (
"encoding/json"
"io"
"log"
"net"
"net/http"
)
// GetOutboundIP: gets the oubound IP of this packet
func GetOutboundIP() net.IP {
func GetOutboundIP() (net.IP, error) {
conn, err := net.Dial("udp", "8.8.8.8:80")
if err != nil {
log.Fatal(err)
}
defer conn.Close()
localAddr := conn.LocalAddr().(*net.UDPAddr)
return localAddr.IP
return localAddr.IP, nil
}
const IP_SERVICE = "https://api.ipify.org?format=json"
type IpResponse struct {
Ip string `json:"ip"`
}
func (i *IpResponse) GetIP() net.IP {
return net.ParseIP(i.Ip)
}
// GetPublicIP: get the nodes public IP address. For when a node is behind NAT
func GetPublicIP() (net.IP, error) {
req, err := http.NewRequest(http.MethodGet, IP_SERVICE, nil)
if err != nil {
return nil, err
}
res, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
resBody, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
var jsonResponse IpResponse
err = json.Unmarshal([]byte(resBody), &jsonResponse)
if err != nil {
return nil, err
}
return jsonResponse.GetIP(), nil
}

46
pkg/lib/random_test.go Normal file
View File

@ -0,0 +1,46 @@
package lib
import (
"slices"
"testing"
)
// Test that a random subset of length 0 produces a zero length
// list
func TestRandomSubsetOfLength0(t *testing.T) {
values := []int{1, 2, 3, 4, 5, 6, 7, 8}
randomValues := RandomSubsetOfLength(values, 0)
if len(randomValues) != 0 {
t.Fatalf(`Expected length to be 0`)
}
}
func TestRandomSubsetOfLength1(t *testing.T) {
values := []int{1, 2, 3, 4, 5}
randomValues := RandomSubsetOfLength(values, 1)
if len(randomValues) != 1 {
t.Fatalf(`Expected length to be 1`)
}
if !slices.Contains(values, randomValues[0]) {
t.Fatalf(`Expected length to be 1`)
}
}
func TestRandomSubsetEntireList(t *testing.T) {
values := []int{1, 2, 3, 4, 5}
randomValues := RandomSubsetOfLength(values, len(values))
if len(randomValues) != len(values) {
t.Fatalf(`Expected length to be %d was %d`, len(values), len(randomValues))
}
slices.Sort(randomValues)
if !slices.Equal(values, randomValues) {
t.Fatalf(`Expected slices to be equal`)
}
}

19
pkg/lib/regex.go Normal file
View File

@ -0,0 +1,19 @@
package lib
import "regexp"
func MatchCaptureGroup(pattern, payload string) map[string]string {
patterns := make(map[string]string)
expr := regexp.MustCompile(pattern)
match := expr.FindStringSubmatch(payload)
for i, name := range expr.SubexpNames() {
if i != 0 && name != "" {
patterns[name] = match[i]
}
}
return patterns
}

309
pkg/lib/rtnetlink.go Normal file
View File

@ -0,0 +1,309 @@
package lib
import (
"encoding/binary"
"fmt"
"net"
"github.com/jsimonetti/rtnetlink"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"golang.org/x/sys/unix"
)
type RtNetlinkConfig struct {
conn *rtnetlink.Conn
}
func NewRtNetlinkConfig() (*RtNetlinkConfig, error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return nil, err
}
return &RtNetlinkConfig{conn: conn}, nil
}
const WIREGUARD_MTU = 1420
// Create a netlink interface if it does not exist. ifName is the name of the netlink interface
func (c *RtNetlinkConfig) CreateLink(ifName string) error {
_, err := net.InterfaceByName(ifName)
if err == nil {
return fmt.Errorf("interface %s already exists", ifName)
}
err = c.conn.Link.New(&rtnetlink.LinkMessage{
Family: unix.AF_UNSPEC,
Flags: unix.IFF_UP,
Attributes: &rtnetlink.LinkAttributes{
Name: ifName,
Info: &rtnetlink.LinkInfo{Kind: "wireguard"},
MTU: uint32(WIREGUARD_MTU),
},
})
if err != nil {
return fmt.Errorf("failed to create wireguard interface: %w", err)
}
return nil
}
// Delete link delete the specified interface
func (c *RtNetlinkConfig) DeleteLink(ifName string) error {
iface, err := net.InterfaceByName(ifName)
if err != nil {
return fmt.Errorf("failed to get interface %s %w", ifName, err)
}
err = c.conn.Link.Delete(uint32(iface.Index))
if err != nil {
return fmt.Errorf("failed to delete wg interface %w", err)
}
return nil
}
// AddAddress adds an address to the given interface.
func (c *RtNetlinkConfig) AddAddress(ifName string, address string) error {
iface, err := net.InterfaceByName(ifName)
if err != nil {
return fmt.Errorf("failed to get interface %s error: %w", ifName, err)
}
addr, cidr, err := net.ParseCIDR(address)
if err != nil {
return fmt.Errorf("failed to parse CIDR %s error: %w", addr, err)
}
family := unix.AF_INET6
ipv4 := cidr.IP.To4()
if ipv4 != nil {
family = unix.AF_INET
}
// Calculate the prefix length
ones, _ := cidr.Mask.Size()
// Calculate the broadcast IP
// Only used when family is AF_INET
var brd net.IP
if ipv4 != nil {
brd = make(net.IP, len(ipv4))
binary.BigEndian.PutUint32(brd, binary.BigEndian.Uint32(ipv4)|^binary.BigEndian.Uint32(net.IP(cidr.Mask).To4()))
}
err = c.conn.Address.New(&rtnetlink.AddressMessage{
Family: uint8(family),
PrefixLength: uint8(ones),
Scope: unix.RT_SCOPE_UNIVERSE,
Index: uint32(iface.Index),
Attributes: &rtnetlink.AddressAttributes{
Address: addr,
Local: addr,
Broadcast: brd,
},
})
if err != nil {
err = fmt.Errorf("failed to add address to link %w", err)
}
return err
}
// AddRoute: adds a route to the routing table.
// ifName is the intrface to add the route to
// gateway is the IP of the gateway device to hop to
// dst is the network prefix of the advertised destination
func (c *RtNetlinkConfig) AddRoute(ifName string, route Route) error {
iface, err := net.InterfaceByName(ifName)
if err != nil {
return fmt.Errorf("failed accessing interface %s error %w", ifName, err)
}
gw := route.Gateway
dst := route.Destination
var family uint8 = unix.AF_INET6
if dst.IP.To4() != nil {
family = unix.AF_INET
}
routes, err := c.listRoutes(ifName, family)
if err != nil {
return err
}
// If it already exists no need to add the route
if !Contains(routes, func(prevRoute rtnetlink.RouteMessage) bool {
return prevRoute.Attributes.Dst.Equal(route.Destination.IP) &&
prevRoute.Attributes.Gateway.Equal(route.Gateway)
}) {
attr := rtnetlink.RouteAttributes{
Dst: dst.IP,
OutIface: uint32(iface.Index),
Gateway: gw,
}
ones, _ := dst.Mask.Size()
err = c.conn.Route.Replace(&rtnetlink.RouteMessage{
Family: family,
Table: unix.RT_TABLE_MAIN,
Protocol: unix.RTPROT_BOOT,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
DstLength: uint8(ones),
Attributes: attr,
})
if err != nil {
return fmt.Errorf("failed to add route %w", err)
}
}
return nil
}
// DeleteRoute deletes routes with the gateway and destination
func (c *RtNetlinkConfig) DeleteRoute(ifName string, route Route) error {
iface, err := net.InterfaceByName(ifName)
if err != nil {
return fmt.Errorf("failed accessing interface %s error %w", ifName, err)
}
gw := route.Gateway
dst := route.Destination
var family uint8 = unix.AF_INET6
if dst.IP.To4() != nil {
family = unix.AF_INET
}
attr := rtnetlink.RouteAttributes{
Dst: dst.IP,
OutIface: uint32(iface.Index),
Gateway: gw,
}
ones, _ := dst.Mask.Size()
err = c.conn.Route.Delete(&rtnetlink.RouteMessage{
Family: family,
Table: unix.RT_TABLE_MAIN,
Protocol: unix.RTPROT_BOOT,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
DstLength: uint8(ones),
Attributes: attr,
})
if err != nil {
return fmt.Errorf("failed to delete route %s", dst.IP.String())
}
return nil
}
type Route struct {
Gateway net.IP
Destination net.IPNet
}
func (r1 Route) equal(r2 Route) bool {
mask1Ones, _ := r1.Destination.Mask.Size()
mask2Ones, _ := r2.Destination.Mask.Size()
return r1.Gateway.String() == r2.Gateway.String() &&
(mask1Ones == 0 && mask2Ones == 0 || r1.Destination.IP.Equal(r2.Destination.IP))
}
// DeleteRoutes deletes all routes not in exclude
func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...Route) error {
routes, err := c.listRoutes(ifName, family)
if err != nil {
return err
}
ifRoutes := make([]Route, 0)
for _, rtRoute := range routes {
maskSize := 128
if family == unix.AF_INET {
maskSize = 32
}
cidr := net.CIDRMask(int(rtRoute.DstLength), maskSize)
route := Route{
Gateway: rtRoute.Attributes.Gateway,
Destination: net.IPNet{IP: rtRoute.Attributes.Dst, Mask: cidr},
}
ifRoutes = append(ifRoutes, route)
}
shouldExclude := func(r Route) bool {
for _, route := range exclude {
if r.equal(route) {
return false
}
}
return true
}
toDelete := Filter(ifRoutes, shouldExclude)
for _, route := range toDelete {
logging.Log.WriteInfof("Deleting route: %s", route.Destination.String())
err := c.DeleteRoute(ifName, route)
if err != nil {
return err
}
}
return nil
}
// listRoutes lists all routes on the interface
func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8) ([]rtnetlink.RouteMessage, error) {
iface, err := net.InterfaceByName(ifName)
if err != nil {
return nil, fmt.Errorf("failed accessing interface %s error %w", ifName, err)
}
routes, err := c.conn.Route.List()
if err != nil {
return nil, fmt.Errorf("failed to get route %w", err)
}
filterFunc := func(r rtnetlink.RouteMessage) bool {
return r.Attributes.Gateway != nil && r.Attributes.OutIface == uint32(iface.Index)
}
routes = Filter(routes, filterFunc)
return routes, nil
}
func (c *RtNetlinkConfig) Close() error {
return c.conn.Close()
}

40
pkg/lib/stats.go Normal file
View File

@ -0,0 +1,40 @@
// lib contains helper functions for the implementation
package lib
import (
"cmp"
"math"
"gonum.org/v1/gonum/stat"
"gonum.org/v1/gonum/stat/distuv"
)
// Modelling the distribution using a normal distribution get the count
// of the outliers
func GetOutliers[K cmp.Ordered](counts map[K]uint64, alpha float64) []K {
n := float64(len(counts))
keys := MapKeys(counts)
values := make([]float64, len(keys))
for index, key := range keys {
values[index] = float64(counts[key])
}
mean := stat.Mean(values, nil)
stdDev := stat.StdDev(values, nil)
moe := distuv.Normal{Mu: 0, Sigma: 1}.Quantile(1-alpha/2) * (stdDev / math.Sqrt(n))
lowerBound := mean - moe
var outliers []K
for i, count := range values {
if count < lowerBound {
outliers = append(outliers, keys[i])
}
}
return outliers
}

42
pkg/lib/timer.go Normal file
View File

@ -0,0 +1,42 @@
package lib
import "time"
type TimerFunc = func() error
type Timer struct {
f TimerFunc
quit chan struct{}
updateRate int
}
func (t *Timer) Run() error {
ticker := time.NewTicker(time.Duration(t.updateRate) * time.Second)
t.quit = make(chan struct{})
for {
select {
case <-ticker.C:
err := t.f()
if err != nil {
return err
}
case <-t.quit:
break
}
}
}
func (t *Timer) Stop() error {
close(t.quit)
return nil
}
func NewTimer(f TimerFunc, updateRate int) *Timer {
return &Timer{
f: f,
updateRate: updateRate,
}
}

View File

@ -2,6 +2,7 @@
package logging
import (
"io"
"os"
"github.com/sirupsen/logrus"
@ -15,6 +16,7 @@ type Logger interface {
WriteInfof(msg string, args ...interface{})
WriteErrorf(msg string, args ...interface{})
WriteWarnf(msg string, args ...interface{})
Writer() io.Writer
}
type LogrusLogger struct {
@ -33,6 +35,10 @@ func (l *LogrusLogger) WriteWarnf(msg string, args ...interface{}) {
l.logger.Warnf(msg, args...)
}
func (l *LogrusLogger) Writer() io.Writer {
return l.logger.Writer()
}
func NewLogrusLogger() *LogrusLogger {
logger := logrus.New()
logger.SetFormatter(&logrus.TextFormatter{FullTimestamp: true})

46
pkg/mesh/alias.go Normal file
View File

@ -0,0 +1,46 @@
package mesh
import (
"fmt"
"github.com/tim-beatham/wgmesh/pkg/hosts"
)
type MeshAliasManager interface {
AddAliases(nodes []MeshNode)
RemoveAliases(node []MeshNode)
}
type AliasManager struct {
hosts hosts.HostsManipulator
}
// AddAliases: on node update or change add aliases to the hosts file
func (a *AliasManager) AddAliases(nodes []MeshNode) {
for _, node := range nodes {
if node.GetAlias() != "" {
a.hosts.AddAddr(hosts.HostsEntry{
Alias: fmt.Sprintf("%s.smeg", node.GetAlias()),
Ip: node.GetWgHost().IP,
})
}
}
}
// RemoveAliases: on node remove remove aliases from the hosts file
func (a *AliasManager) RemoveAliases(nodes []MeshNode) {
for _, node := range nodes {
if node.GetAlias() != "" {
a.hosts.Remove(hosts.HostsEntry{
Alias: fmt.Sprintf("%s.smeg", node.GetAlias()),
Ip: node.GetWgHost().IP,
})
}
}
}
func NewAliasManager() MeshAliasManager {
return &AliasManager{
hosts: hosts.NewHostsManipulator(),
}
}

View File

@ -1,10 +1,16 @@
package mesh
import (
"errors"
"fmt"
"net"
"slices"
"strings"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/route"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@ -12,82 +18,446 @@ import (
type MeshConfigApplyer interface {
ApplyConfig() error
RemovePeers(meshId string) error
SetMeshManager(manager MeshManager)
}
// WgMeshConfigApplyer applies WireGuard configuration
type WgMeshConfigApplyer struct {
meshManager *MeshManager
meshManager MeshManager
routeInstaller route.RouteInstaller
hashFunc func(MeshNode) int
}
func convertMeshNode(node MeshNode) (*wgtypes.PeerConfig, error) {
endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint())
if err != nil {
return nil, err
type routeNode struct {
gateway string
route Route
}
pubKey, err := node.GetPublicKey()
type convertMeshNodeParams struct {
node MeshNode
self MeshNode
mesh MeshProvider
device *wgtypes.Device
peerToClients map[string][]net.IPNet
routes map[string][]routeNode
}
func (m *WgMeshConfigApplyer) convertMeshNode(params convertMeshNodeParams) (*wgtypes.PeerConfig, error) {
pubKey, err := params.node.GetPublicKey()
if err != nil {
return nil, err
}
allowedips := make([]net.IPNet, 1)
allowedips[0] = *node.GetWgHost()
allowedips[0] = *params.node.GetWgHost()
for _, route := range node.GetRoutes() {
_, ipnet, _ := net.ParseCIDR(route)
allowedips = append(allowedips, *ipnet)
clients, ok := params.peerToClients[pubKey.String()]
if ok {
allowedips = append(allowedips, clients...)
}
for _, route := range params.node.GetRoutes() {
bestRoutes := params.routes[route.GetDestination().String()]
var pickedRoute routeNode
if len(bestRoutes) == 1 {
pickedRoute = bestRoutes[0]
} else if len(bestRoutes) > 1 {
bucketFunc := func(rn routeNode) int {
return lib.HashString(rn.gateway)
}
// Else there is more than one candidate so consistently hash
pickedRoute = lib.ConsistentHash(bestRoutes, params.self, bucketFunc, m.hashFunc)
}
if pickedRoute.gateway == pubKey.String() {
allowedips = append(allowedips, *pickedRoute.route.GetDestination())
}
}
config := params.mesh.GetConfiguration()
var keepAlive time.Duration = time.Duration(0)
if config.KeepAliveWg != nil {
keepAlive = time.Duration(*config.KeepAliveWg) * time.Second
}
existing := slices.IndexFunc(params.device.Peers, func(p wgtypes.Peer) bool {
pubKey, _ := params.node.GetPublicKey()
return p.PublicKey.String() == pubKey.String()
})
endpoint, err := net.ResolveUDPAddr("udp", params.node.GetWgEndpoint())
if err != nil {
return nil, err
}
// Don't override the existing IP in case it already exists
if existing != -1 {
endpoint = params.device.Peers[existing].Endpoint
}
peerConfig := wgtypes.PeerConfig{
PublicKey: pubKey,
Endpoint: endpoint,
AllowedIPs: allowedips,
PersistentKeepaliveInterval: &keepAlive,
ReplaceAllowedIPs: true,
}
return &peerConfig, nil
}
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
snap, err := mesh.GetMesh()
// getRoutes: finds the routes with the least hop distance. If more than one route exists
// consistently hash to evenly spread the distribution of traffic
func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]routeNode {
mesh, _ := meshProvider.GetMesh()
routes := make(map[string][]routeNode)
if err != nil {
return err
peers := lib.Filter(lib.MapValues(mesh.GetNodes()), func(p MeshNode) bool {
return p.GetType() == conf.PEER_ROLE
})
meshPrefixes := lib.Map(lib.MapValues(m.meshManager.GetMeshes()), func(mesh MeshProvider) *net.IPNet {
ula := &ip.ULABuilder{}
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
return ipNet
})
for _, node := range mesh.GetNodes() {
pubKey, _ := node.GetPublicKey()
for _, route := range node.GetRoutes() {
if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool {
if prefix.IP.Equal(net.IPv6zero) && *meshProvider.GetConfiguration().AdvertiseDefaultRoute {
return true
}
nodes := snap.GetNodes()
peerConfigs := make([]wgtypes.PeerConfig, len(nodes))
var count int = 0
for _, n := range nodes {
peer, err := convertMeshNode(n)
if err != nil {
return err
return prefix.Contains(route.GetDestination().IP)
}) {
continue
}
peerConfigs[count] = *peer
count++
destination := route.GetDestination().String()
otherRoute, ok := routes[destination]
rn := routeNode{
gateway: pubKey.String(),
route: route,
}
// Client's only acessible by another peer
if node.GetType() == conf.CLIENT_ROLE {
peer := m.getCorrespondingPeer(peers, node)
self, _ := m.meshManager.GetSelf(meshProvider.GetMeshId())
// If the node isn't the self use that peer as the gateway
if !NodeEquals(peer, self) {
peerPub, _ := peer.GetPublicKey()
rn.gateway = peerPub.String()
rn.route = &RouteStub{
Destination: rn.route.GetDestination(),
HopCount: rn.route.GetHopCount() + 1,
// Append the path to this peer
Path: append(rn.route.GetPath(), peer.GetWgHost().IP.String()),
}
}
}
if !ok {
otherRoute = make([]routeNode, 1)
otherRoute[0] = rn
routes[destination] = otherRoute
} else if route.GetHopCount() < otherRoute[0].route.GetHopCount() {
otherRoute[0] = rn
} else if otherRoute[0].route.GetHopCount() == route.GetHopCount() {
routes[destination] = append(otherRoute, rn)
}
}
}
return routes
}
// getCorrespondignPeer: gets the peer corresponding to the client
func (m *WgMeshConfigApplyer) getCorrespondingPeer(peers []MeshNode, client MeshNode) MeshNode {
peer := lib.ConsistentHash(peers, client, m.hashFunc, m.hashFunc)
return peer
}
func (m *WgMeshConfigApplyer) getPeerCfgsToRemove(dev *wgtypes.Device, newPeers []wgtypes.PeerConfig) []wgtypes.PeerConfig {
peers := dev.Peers
peers = lib.Filter(peers, func(p1 wgtypes.Peer) bool {
return !lib.Contains(newPeers, func(p2 wgtypes.PeerConfig) bool {
return p1.PublicKey.String() == p2.PublicKey.String()
})
})
return lib.Map(peers, func(p wgtypes.Peer) wgtypes.PeerConfig {
return wgtypes.PeerConfig{
PublicKey: p.PublicKey,
Remove: true,
}
})
}
type GetConfigParams struct {
mesh MeshProvider
peers []MeshNode
clients []MeshNode
dev *wgtypes.Device
routes map[string][]routeNode
}
func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes.Config, error) {
self, err := m.meshManager.GetSelf(params.mesh.GetMeshId())
ula := &ip.ULABuilder{}
meshNet, _ := ula.GetIPNet(params.mesh.GetMeshId())
routesForMesh := lib.Map(lib.MapValues(params.routes), func(rns []routeNode) []routeNode {
return lib.Filter(rns, func(rn routeNode) bool {
ip, _, _ := net.ParseCIDR(rn.gateway)
return meshNet.Contains(ip)
})
})
routes := lib.Map(routesForMesh, func(rs []routeNode) net.IPNet {
return *rs[0].route.GetDestination()
})
routes = append(routes, *meshNet)
if err != nil {
return nil, err
}
peer := m.getCorrespondingPeer(params.peers, self)
pubKey, _ := peer.GetPublicKey()
config := params.mesh.GetConfiguration()
keepAlive := time.Duration(*config.KeepAliveWg) * time.Second
endpoint, err := net.ResolveUDPAddr("udp", peer.GetWgEndpoint())
if err != nil {
return nil, err
}
peerCfgs := make([]wgtypes.PeerConfig, 1)
peerCfgs[0] = wgtypes.PeerConfig{
PublicKey: pubKey,
Endpoint: endpoint,
PersistentKeepaliveInterval: &keepAlive,
AllowedIPs: routes,
ReplaceAllowedIPs: true,
}
installedRoutes := make([]lib.Route, 0)
for _, route := range peerCfgs[0].AllowedIPs {
installedRoutes = append(installedRoutes, lib.Route{
Gateway: peer.GetWgHost().IP,
Destination: route,
})
}
cfg := wgtypes.Config{
Peers: peerCfgs,
}
m.routeInstaller.InstallRoutes(params.dev.Name, installedRoutes...)
return &cfg, err
}
func (m *WgMeshConfigApplyer) getRoutesToInstall(wgNode *wgtypes.PeerConfig, mesh MeshProvider, node MeshNode) []lib.Route {
routes := make([]lib.Route, 0)
for _, route := range wgNode.AllowedIPs {
ula := &ip.ULABuilder{}
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
_, defaultRoute, _ := net.ParseCIDR("::/0")
if !ipNet.Contains(route.IP) && !ipNet.IP.Equal(defaultRoute.IP) {
routes = append(routes, lib.Route{
Gateway: node.GetWgHost().IP,
Destination: route,
})
}
}
return routes
}
func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.Config, error) {
peerToClients := make(map[string][]net.IPNet)
installedRoutes := make([]lib.Route, 0)
peerConfigs := make([]wgtypes.PeerConfig, 0)
self, err := m.meshManager.GetSelf(params.mesh.GetMeshId())
if err != nil {
return nil, err
}
for _, n := range params.clients {
if len(params.peers) > 0 {
peer := m.getCorrespondingPeer(params.peers, n)
pubKey, _ := peer.GetPublicKey()
clients, ok := peerToClients[pubKey.String()]
if !ok {
clients = make([]net.IPNet, 0)
peerToClients[pubKey.String()] = clients
}
peerToClients[pubKey.String()] = append(clients, *n.GetWgHost())
if NodeEquals(self, peer) {
cfg, err := m.convertMeshNode(convertMeshNodeParams{
node: n,
self: self,
mesh: params.mesh,
device: params.dev,
peerToClients: peerToClients,
routes: params.routes,
})
if err != nil {
return nil, err
}
installedRoutes = append(installedRoutes, m.getRoutesToInstall(cfg, params.mesh, n)...)
peerConfigs = append(peerConfigs, *cfg)
}
}
}
for _, n := range params.peers {
if NodeEquals(n, self) {
continue
}
peer, err := m.convertMeshNode(convertMeshNodeParams{
node: n,
self: self,
mesh: params.mesh,
peerToClients: peerToClients,
routes: params.routes,
device: params.dev,
})
if err != nil {
return nil, err
}
installedRoutes = append(installedRoutes, m.getRoutesToInstall(peer, params.mesh, n)...)
peerConfigs = append(peerConfigs, *peer)
}
cfg := wgtypes.Config{
Peers: peerConfigs,
}
dev, err := mesh.GetDevice()
err = m.routeInstaller.InstallRoutes(params.dev.Name, installedRoutes...)
return &cfg, err
}
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string][]routeNode) error {
snap, err := mesh.GetMesh()
if err != nil {
return err
}
return m.meshManager.Client.ConfigureDevice(dev.Name, cfg)
nodes := lib.MapValues(snap.GetNodes())
dev, _ := mesh.GetDevice()
slices.SortFunc(nodes, func(a, b MeshNode) int {
return strings.Compare(string(a.GetType()), string(b.GetType()))
})
peers := lib.Filter(nodes, func(mn MeshNode) bool {
return mn.GetType() == conf.PEER_ROLE
})
clients := lib.Filter(nodes, func(mn MeshNode) bool {
return mn.GetType() == conf.CLIENT_ROLE
})
self, err := m.meshManager.GetSelf(mesh.GetMeshId())
if err != nil {
return err
}
var cfg *wgtypes.Config = nil
configParams := &GetConfigParams{
mesh: mesh,
peers: peers,
clients: clients,
dev: dev,
routes: routes,
}
switch self.GetType() {
case conf.PEER_ROLE:
cfg, err = m.getPeerConfig(configParams)
case conf.CLIENT_ROLE:
cfg, err = m.getClientConfig(configParams)
}
if err != nil {
return err
}
toRemove := m.getPeerCfgsToRemove(dev, cfg.Peers)
cfg.Peers = append(cfg.Peers, toRemove...)
err = m.meshManager.GetClient().ConfigureDevice(dev.Name, *cfg)
if err != nil {
return err
}
return nil
}
func (m *WgMeshConfigApplyer) getAllRoutes() map[string][]routeNode {
allRoutes := make(map[string][]routeNode)
for _, mesh := range m.meshManager.GetMeshes() {
routes := m.getRoutes(mesh)
for destination, route := range routes {
_, ok := allRoutes[destination]
if !ok {
allRoutes[destination] = route
continue
}
if allRoutes[destination][0].route.GetHopCount() == route[0].route.GetHopCount() {
allRoutes[destination] = append(allRoutes[destination], route...)
} else if route[0].route.GetHopCount() < allRoutes[destination][0].route.GetHopCount() {
allRoutes[destination] = route
}
}
}
return allRoutes
}
func (m *WgMeshConfigApplyer) ApplyConfig() error {
for _, mesh := range m.meshManager.Meshes {
err := m.updateWgConf(mesh)
allRoutes := m.getAllRoutes()
for _, mesh := range m.meshManager.GetMeshes() {
err := m.updateWgConf(mesh, allRoutes)
if err != nil {
return err
@ -101,7 +471,7 @@ func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error {
mesh := m.meshManager.GetMesh(meshId)
if mesh == nil {
return errors.New(fmt.Sprintf("mesh %s does not exist", meshId))
return fmt.Errorf("mesh %s does not exist", meshId)
}
dev, err := mesh.GetDevice()
@ -110,14 +480,24 @@ func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error {
return err
}
m.meshManager.Client.ConfigureDevice(dev.Name, wgtypes.Config{
m.meshManager.GetClient().ConfigureDevice(dev.Name, wgtypes.Config{
Peers: make([]wgtypes.PeerConfig, 0),
ReplacePeers: true,
Peers: make([]wgtypes.PeerConfig, 1),
})
return nil
}
func NewWgMeshConfigApplyer(manager *MeshManager) MeshConfigApplyer {
return &WgMeshConfigApplyer{meshManager: manager}
func (m *WgMeshConfigApplyer) SetMeshManager(manager MeshManager) {
m.meshManager = manager
}
func NewWgMeshConfigApplyer() MeshConfigApplyer {
return &WgMeshConfigApplyer{
routeInstaller: route.NewRouteInstaller(),
hashFunc: func(mn MeshNode) int {
pubKey, _ := mn.GetPublicKey()
return lib.HashString(pubKey.String())
},
}
}

View File

@ -15,7 +15,7 @@ type MeshGraphConverter interface {
}
type MeshDOTConverter struct {
manager *MeshManager
manager MeshManager
}
func (c *MeshDOTConverter) Generate(meshId string) (string, error) {
@ -34,7 +34,7 @@ func (c *MeshDOTConverter) Generate(meshId string) (string, error) {
}
for _, node := range snapshot.GetNodes() {
c.graphNode(g, node)
c.graphNode(g, node, meshId)
}
nodes := lib.MapValues(snapshot.GetNodes())
@ -55,11 +55,13 @@ func (c *MeshDOTConverter) Generate(meshId string) (string, error) {
}
// graphNode: graphs a node within the mesh
func (c *MeshDOTConverter) graphNode(g *graph.Graph, node MeshNode) {
func (c *MeshDOTConverter) graphNode(g *graph.Graph, node MeshNode, meshId string) {
nodeId := fmt.Sprintf("\"%s\"", node.GetIdentifier())
g.PutNode(nodeId, graph.CIRCLE)
if node.GetHostEndpoint() == c.manager.HostParameters.HostEndpoint {
self, _ := c.manager.GetSelf(meshId)
if NodeEquals(self, node) {
return
}
@ -70,6 +72,6 @@ func (c *MeshDOTConverter) graphNode(g *graph.Graph, node MeshNode) {
}
}
func NewMeshDotConverter(m *MeshManager) MeshGraphConverter {
func NewMeshDotConverter(m MeshManager) MeshGraphConverter {
return &MeshDOTConverter{manager: m}
}

View File

@ -3,88 +3,225 @@ package mesh
import (
"errors"
"fmt"
"sync"
"github.com/tim-beatham/wgmesh/pkg/cmd"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/wg"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type MeshManager struct {
type MeshManager interface {
CreateMesh(params *CreateMeshParams) (string, error)
AddMesh(params *AddMeshParams) error
HasChanges(meshid string) bool
GetMesh(meshId string) MeshProvider
GetPublicKey() *wgtypes.Key
AddSelf(params *AddSelfParams) error
LeaveMesh(meshId string) error
GetSelf(meshId string) (MeshNode, error)
ApplyConfig() error
SetDescription(description string) error
SetAlias(alias string) error
SetService(service string, value string) error
RemoveService(service string) error
UpdateTimeStamp() error
GetClient() *wgctrl.Client
GetMeshes() map[string]MeshProvider
Close() error
GetMonitor() MeshMonitor
GetNode(string, string) MeshNode
GetRouteManager() RouteManager
}
type MeshManagerImpl struct {
lock sync.RWMutex
Meshes map[string]MeshProvider
RouteManager RouteManager
Client *wgctrl.Client
// HostParameters contains information that uniquely locates
// the node in the mesh network.
HostParameters *HostParameters
conf *conf.WgMeshConfiguration
conf *conf.DaemonConfiguration
meshProviderFactory MeshProviderFactory
nodeFactory MeshNodeFactory
configApplyer MeshConfigApplyer
idGenerator lib.IdGenerator
ipAllocator ip.IPAllocator
interfaceManipulator wg.WgInterfaceManipulator
Monitor MeshMonitor
cmdRunner cmd.CmdRunner
OnDelete func(MeshProvider)
}
// GetRouteManager implements MeshManager.
func (m *MeshManagerImpl) GetRouteManager() RouteManager {
return m.RouteManager
}
// RemoveService implements MeshManager.
func (m *MeshManagerImpl) RemoveService(service string) error {
for _, mesh := range m.Meshes {
err := mesh.RemoveService(m.HostParameters.GetPublicKey(), service)
if err != nil {
return err
}
}
return nil
}
// SetService implements MeshManager.
func (m *MeshManagerImpl) SetService(service string, value string) error {
for _, mesh := range m.Meshes {
err := mesh.AddService(m.HostParameters.GetPublicKey(), service, value)
if err != nil {
return err
}
}
return nil
}
func (m *MeshManagerImpl) GetNode(meshid, nodeId string) MeshNode {
mesh, ok := m.Meshes[meshid]
if !ok {
return nil
}
node, err := mesh.GetNode(nodeId)
if err != nil {
return nil
}
return node
}
// GetMonitor implements MeshManager.
func (m *MeshManagerImpl) GetMonitor() MeshMonitor {
return m.Monitor
}
// CreateMeshParams contains the parameters required to create a mesh
type CreateMeshParams struct {
Port int
Conf *conf.WgConfiguration
}
// getConf: gets the new configuration with the base configuration overriden
// from the recent
func (m *MeshManagerImpl) getConf(override *conf.WgConfiguration) (*conf.WgConfiguration, error) {
meshConfiguration := m.conf.BaseConfiguration
if override != nil {
newConf, err := conf.MergeMeshConfiguration(meshConfiguration, *override)
if err != nil {
return nil, err
}
meshConfiguration = newConf
}
return &meshConfiguration, nil
}
// CreateMesh: Creates a new mesh, stores it and returns the mesh id
func (m *MeshManager) CreateMesh(devName string, port int) (string, error) {
meshId, err := m.idGenerator.GetId()
func (m *MeshManagerImpl) CreateMesh(args *CreateMeshParams) (string, error) {
meshConfiguration, err := m.getConf(args.Conf)
if err != nil {
return "", err
}
meshId, err := m.idGenerator.GetId()
var ifName string = ""
if err != nil {
return "", err
}
m.cmdRunner.RunCommands(m.conf.BaseConfiguration.PreUp...)
if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(args.Port, m.HostParameters.PrivateKey)
if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err)
}
}
nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: devName,
Port: port,
Conf: m.conf,
DevName: ifName,
Port: args.Port,
Conf: meshConfiguration,
Client: m.Client,
MeshId: meshId,
DaemonConf: m.conf,
NodeID: m.HostParameters.GetPublicKey(),
})
if err != nil {
return "", err
}
err = m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: devName,
Port: port,
})
if err != nil {
return "", nil
return "", fmt.Errorf("error creating mesh: %w", err)
}
m.lock.Lock()
m.Meshes[meshId] = nodeManager
err = m.configApplyer.RemovePeers(meshId)
m.lock.Unlock()
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
m.cmdRunner.RunCommands(m.conf.BaseConfiguration.PostUp...)
return meshId, nil
}
type AddMeshParams struct {
MeshId string
DevName string
WgPort int
MeshBytes []byte
Conf *conf.WgConfiguration
}
// AddMesh: Add the mesh to the list of meshes
func (m *MeshManager) AddMesh(params *AddMeshParams) error {
func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
var ifName string
var err error
meshConfiguration, err := m.getConf(params.Conf)
if err != nil {
return err
}
m.cmdRunner.RunCommands(meshConfiguration.PreUp...)
if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(params.WgPort, m.HostParameters.PrivateKey)
if err != nil {
return err
}
}
meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: params.DevName,
DevName: ifName,
Port: params.WgPort,
Conf: m.conf,
Conf: meshConfiguration,
Client: m.Client,
MeshId: params.MeshId,
DaemonConf: m.conf,
NodeID: m.HostParameters.GetPublicKey(),
})
m.cmdRunner.RunCommands(meshConfiguration.PostUp...)
if err != nil {
return err
}
@ -95,69 +232,27 @@ func (m *MeshManager) AddMesh(params *AddMeshParams) error {
return err
}
m.lock.Lock()
m.Meshes[params.MeshId] = meshProvider
return m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: params.DevName,
Port: params.WgPort,
})
m.lock.Unlock()
return nil
}
// HasChanges returns true if the mesh has changes
func (m *MeshManager) HasChanges(meshId string) bool {
func (m *MeshManagerImpl) HasChanges(meshId string) bool {
return m.Meshes[meshId].HasChanges()
}
// GetMesh returns the mesh with the given meshid
func (m *MeshManager) GetMesh(meshId string) MeshProvider {
theMesh, _ := m.Meshes[meshId]
func (m *MeshManagerImpl) GetMesh(meshId string) MeshProvider {
theMesh := m.Meshes[meshId]
return theMesh
}
// EnableInterface: Enables the given WireGuard interface.
func (s *MeshManager) EnableInterface(meshId string) error {
err := s.configApplyer.ApplyConfig()
if err != nil {
return err
}
meshNode, err := s.GetSelf(meshId)
if err != nil {
return err
}
mesh := s.GetMesh(meshId)
if err != nil {
return err
}
dev, err := mesh.GetDevice()
if err != nil {
return err
}
return s.interfaceManipulator.EnableInterface(dev.Name, meshNode.GetWgHost().String())
}
// GetPublicKey: Gets the public key of the WireGuard mesh
func (s *MeshManager) GetPublicKey(meshId string) (*wgtypes.Key, error) {
mesh, ok := s.Meshes[meshId]
if !ok {
return nil, errors.New("mesh does not exist")
}
dev, err := mesh.GetDevice()
if err != nil {
return nil, err
}
return &dev.PublicKey, nil
func (s *MeshManagerImpl) GetPublicKey() *wgtypes.Key {
key := s.HostParameters.PrivateKey.PublicKey()
return &key
}
type AddSelfParams struct {
@ -166,134 +261,243 @@ type AddSelfParams struct {
// WgPort is the WireGuard port to advertise
WgPort int
// Endpoint is the alias of the machine to send routable packets
// to
Endpoint string
}
// AddSelf adds this host to the mesh
func (s *MeshManager) AddSelf(params *AddSelfParams) error {
pubKey, err := s.GetPublicKey(params.MeshId)
func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
mesh := s.GetMesh(params.MeshId)
if mesh == nil {
return fmt.Errorf("addself: mesh %s does not exist", params.MeshId)
}
if params.WgPort == 0 && !s.conf.StubWg {
device, err := mesh.GetDevice()
if err != nil {
return err
}
nodeIP, err := s.ipAllocator.GetIP(*pubKey, params.MeshId)
params.WgPort = device.ListenPort
}
pubKey := s.HostParameters.PrivateKey.PublicKey()
nodeIP, err := s.ipAllocator.GetIP(pubKey, params.MeshId)
if err != nil {
return err
}
node := s.nodeFactory.Build(&MeshNodeFactoryParams{
PublicKey: pubKey,
PublicKey: &pubKey,
NodeIP: nodeIP,
WgPort: params.WgPort,
Endpoint: params.Endpoint,
MeshConfig: mesh.GetConfiguration(),
})
if !s.conf.StubWg {
device, err := mesh.GetDevice()
if err != nil {
return fmt.Errorf("failed to get device %w", err)
}
err = s.interfaceManipulator.AddAddress(device.Name, fmt.Sprintf("%s/64", nodeIP))
if err != nil {
return fmt.Errorf("addSelf: failed to add address to dev %w", err)
}
}
s.Meshes[params.MeshId].AddNode(node)
return s.RouteManager.UpdateRoutes()
}
// LeaveMesh leaves the mesh network
func (s *MeshManager) LeaveMesh(meshId string) error {
_, exists := s.Meshes[meshId]
if !exists {
return errors.New(fmt.Sprintf("mesh %s does not exist", meshId))
}
// For now just delete the mesh with the ID.
delete(s.Meshes, meshId)
return nil
}
func (s *MeshManager) GetSelf(meshId string) (MeshNode, error) {
// LeaveMesh leaves the mesh network
func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
mesh := s.GetMesh(meshId)
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
err := mesh.RemoveNode(s.HostParameters.GetPublicKey())
if err != nil {
return err
}
if s.OnDelete != nil {
s.OnDelete(mesh)
}
s.lock.Lock()
delete(s.Meshes, meshId)
s.lock.Unlock()
s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PreDown...)
if !s.conf.StubWg {
device, err := mesh.GetDevice()
if err != nil {
return err
}
err = s.interfaceManipulator.RemoveInterface(device.Name)
if err != nil {
return err
}
}
s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PostDown...)
return err
}
func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
meshInstance, ok := s.Meshes[meshId]
if !ok {
return nil, errors.New(fmt.Sprintf("mesh %s does not exist", meshId))
return nil, fmt.Errorf("mesh %s does not exist", meshId)
}
snapshot, err := meshInstance.GetMesh()
node, err := meshInstance.GetNode(s.HostParameters.GetPublicKey())
if err != nil {
return nil, err
}
node, ok := snapshot.GetNodes()[s.HostParameters.HostEndpoint]
if !ok {
return nil, errors.New("the node doesn't exist in the mesh")
}
return node, nil
}
func (s *MeshManager) ApplyConfig() error {
return s.configApplyer.ApplyConfig()
func (s *MeshManagerImpl) ApplyConfig() error {
if s.conf.StubWg {
return nil
}
func (s *MeshManager) SetDescription(description string) error {
for _, mesh := range s.Meshes {
err := mesh.SetDescription(s.HostParameters.HostEndpoint, description)
err := s.configApplyer.ApplyConfig()
if err != nil {
return err
}
return nil
}
func (s *MeshManagerImpl) SetDescription(description string) error {
meshes := s.GetMeshes()
for _, mesh := range meshes {
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetDescription(s.HostParameters.GetPublicKey(), description)
if err != nil {
return err
}
}
}
return nil
}
// SetAlias implements MeshManager.
func (s *MeshManagerImpl) SetAlias(alias string) error {
meshes := s.GetMeshes()
for _, mesh := range meshes {
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetAlias(s.HostParameters.GetPublicKey(), alias)
if err != nil {
return err
}
}
}
return nil
}
// UpdateTimeStamp updates the timestamp of this node in all meshes
func (s *MeshManager) UpdateTimeStamp() error {
func (s *MeshManagerImpl) UpdateTimeStamp() error {
meshes := s.GetMeshes()
for _, mesh := range meshes {
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.UpdateTimeStamp(s.HostParameters.GetPublicKey())
if err != nil {
return err
}
}
}
return nil
}
func (s *MeshManagerImpl) GetClient() *wgctrl.Client {
return s.Client
}
func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider {
meshes := make(map[string]MeshProvider)
s.lock.RLock()
for id, mesh := range s.Meshes {
meshes[id] = mesh
}
s.lock.RUnlock()
return meshes
}
// Close the mesh manager
func (s *MeshManagerImpl) Close() error {
if s.conf.StubWg {
return nil
}
for _, mesh := range s.Meshes {
snapshot, err := mesh.GetMesh()
dev, err := mesh.GetDevice()
if err != nil {
return err
}
_, exists := snapshot.GetNodes()[s.HostParameters.HostEndpoint]
if exists {
err = mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint)
err = s.interfaceManipulator.RemoveInterface(dev.Name)
if err != nil {
return err
}
}
}
return nil
}
// NewMeshManagerParams params required to create an instance of a mesh manager
type NewMeshManagerParams struct {
Conf conf.WgMeshConfiguration
Conf conf.DaemonConfiguration
Client *wgctrl.Client
MeshProvider MeshProviderFactory
NodeFactory MeshNodeFactory
IdGenerator lib.IdGenerator
IPAllocator ip.IPAllocator
InterfaceManipulator wg.WgInterfaceManipulator
ConfigApplyer MeshConfigApplyer
RouteManager RouteManager
CommandRunner cmd.CmdRunner
OnDelete func(MeshProvider)
}
// Creates a new instance of a mesh manager with the given parameters
func NewMeshManager(params *NewMeshManagerParams) *MeshManager {
hostParams := HostParameters{}
switch params.Conf.Endpoint {
case "":
hostParams.HostEndpoint = fmt.Sprintf("%s:%s", lib.GetOutboundIP().String(), params.Conf.GrpcPort)
default:
hostParams.HostEndpoint = fmt.Sprintf("%s:%s", params.Conf.Endpoint, params.Conf.GrpcPort)
func NewMeshManager(params *NewMeshManagerParams) MeshManager {
privateKey, _ := wgtypes.GeneratePrivateKey()
hostParams := HostParameters{
PrivateKey: &privateKey,
}
logging.Log.WriteInfof("Endpoint %s", hostParams.HostEndpoint)
m := &MeshManager{
m := &MeshManagerImpl{
Meshes: make(map[string]MeshProvider),
HostParameters: &hostParams,
meshProviderFactory: params.MeshProvider,
@ -301,10 +505,27 @@ func NewMeshManager(params *NewMeshManagerParams) *MeshManager {
Client: params.Client,
conf: &params.Conf,
}
m.configApplyer = NewWgMeshConfigApplyer(m)
m.configApplyer = params.ConfigApplyer
m.RouteManager = params.RouteManager
if m.RouteManager == nil {
m.RouteManager = NewRouteManager(m)
}
if params.CommandRunner == nil {
m.cmdRunner = &cmd.UnixCmdRunner{}
}
m.idGenerator = params.IdGenerator
m.ipAllocator = params.IPAllocator
m.interfaceManipulator = params.InterfaceManipulator
m.Monitor = NewMeshMonitor(m)
aliasManager := NewAliasManager()
m.Monitor.AddUpdateCallback(aliasManager.AddAliases)
m.Monitor.AddRemoveCallback(aliasManager.RemoveAliases)
m.OnDelete = params.OnDelete
return m
}

236
pkg/mesh/manager_test.go Normal file
View File

@ -0,0 +1,236 @@
package mesh
import (
"testing"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/wg"
)
func getMeshConfiguration() *conf.DaemonConfiguration {
return &conf.DaemonConfiguration{
GrpcPort: 8080,
}
}
func getMeshManager() MeshManager {
manager := NewMeshManager(&NewMeshManagerParams{
Conf: *getMeshConfiguration(),
Client: nil,
MeshProvider: &StubMeshProviderFactory{},
NodeFactory: &StubNodeFactory{Config: getMeshConfiguration()},
IdGenerator: &lib.UUIDGenerator{},
IPAllocator: &ip.ULABuilder{},
InterfaceManipulator: &wg.WgInterfaceManipulatorStub{},
ConfigApplyer: &MeshConfigApplyerStub{},
RouteManager: &RouteManagerStub{},
})
return manager
}
func TestCreateMeshCreatesANewMeshProvider(t *testing.T) {
manager := getMeshManager()
meshId, err := manager.CreateMesh("wg0", 5000)
if err != nil {
t.Error(err)
}
if len(meshId) == 0 {
t.Fatal(`meshId should not be empty`)
}
_, exists := manager.GetMeshes()[meshId]
if !exists {
t.Fatal(`mesh was not created when it should be`)
}
}
func TestAddMeshAddsAMesh(t *testing.T) {
manager := getMeshManager()
meshId := "meshid123"
manager.AddMesh(&AddMeshParams{
MeshId: meshId,
WgPort: 6000,
MeshBytes: make([]byte, 0),
})
mesh := manager.GetMesh(meshId)
if mesh == nil || mesh.GetMeshId() != meshId {
t.Fatalf(`mesh has not been added to the list of meshes`)
}
}
func TestAddMeshMeshAlreadyExistsReplacesIt(t *testing.T) {
manager := getMeshManager()
meshId := "meshid123"
for i := 0; i < 2; i++ {
err := manager.AddMesh(&AddMeshParams{
MeshId: meshId,
WgPort: 6000,
MeshBytes: make([]byte, 0),
})
if err != nil {
t.Error(err)
}
}
mesh := manager.GetMesh(meshId)
if mesh == nil || mesh.GetMeshId() != meshId {
t.Fatalf(`mesh has not been added to the list of meshes`)
}
}
func TestAddSelfAddsSelfToTheMesh(t *testing.T) {
manager := getMeshManager()
meshId := "meshid123"
err := manager.AddMesh(&AddMeshParams{
MeshId: meshId,
WgPort: 6000,
MeshBytes: make([]byte, 0),
})
if err != nil {
t.Error(err)
}
err = manager.AddSelf(&AddSelfParams{
MeshId: meshId,
WgPort: 5000,
Endpoint: "abc.com",
})
if err != nil {
t.Error(err)
}
mesh, err := manager.GetMesh(meshId).GetMesh()
if err != nil {
t.Error(err)
}
_, ok := mesh.GetNodes()["abc.com"]
if !ok {
t.Fatalf(`node has not been added`)
}
}
func TestAddSelfToMeshAlreadyInMesh(t *testing.T) {
TestAddSelfAddsSelfToTheMesh(t)
TestAddSelfAddsSelfToTheMesh(t)
}
func TestAddSelfToMeshMeshDoesNotExist(t *testing.T) {
manager := getMeshManager()
meshId := "meshid123"
err := manager.AddSelf(&AddSelfParams{
MeshId: meshId,
WgPort: 5000,
Endpoint: "abc.com",
})
if err == nil {
t.Fatalf(`Expected error to be thrown`)
}
}
func TestLeaveMeshMeshDoesNotExist(t *testing.T) {
manager := getMeshManager()
meshId := "meshid123"
err := manager.LeaveMesh(meshId)
if err == nil {
t.Fatalf(`Expected error to be thrown`)
}
}
func TestLeaveMeshDeletesMesh(t *testing.T) {
manager := getMeshManager()
meshId := "meshid123"
err := manager.AddMesh(&AddMeshParams{
MeshId: meshId,
WgPort: 6000,
MeshBytes: make([]byte, 0),
})
if err != nil {
t.Error(err)
}
err = manager.LeaveMesh(meshId)
if err != nil {
t.Fatalf("%s", err.Error())
}
_, exists := manager.GetMeshes()[meshId]
if exists {
t.Fatalf(`expected mesh to have been deleted`)
}
}
func TestSetDescription(t *testing.T) {
manager := getMeshManager()
description := "wooooo"
meshId1, _ := manager.CreateMesh(5000)
meshId2, _ := manager.CreateMesh(5001)
manager.AddSelf(&AddSelfParams{
MeshId: meshId1,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId2,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
err := manager.SetDescription(description)
if err != nil {
t.Fatalf(`failed to set the descriptions`)
}
}
func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) {
manager := getMeshManager()
meshId1, _ := manager.CreateMesh(5000)
meshId2, _ := manager.CreateMesh(5001)
manager.AddSelf(&AddSelfParams{
MeshId: meshId1,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId2,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
err := manager.UpdateTimeStamp()
if err != nil {
t.Fatalf(`failed to update the timestamp`)
}
}

81
pkg/mesh/monitor.go Normal file
View File

@ -0,0 +1,81 @@
package mesh
type OnChange = func([]MeshNode)
type MeshMonitor interface {
AddUpdateCallback(cb OnChange)
AddRemoveCallback(cb OnChange)
Trigger() error
}
type MeshMonitorImpl struct {
updateCbs []OnChange
removeCbs []OnChange
nodes map[string]MeshNode
manager MeshManager
}
// Trigger causes the mesh monitor to trigger all of
// the callbacks.
func (m *MeshMonitorImpl) Trigger() error {
changedNodes := make([]MeshNode, 0)
removedNodes := make([]MeshNode, 0)
nodes := make(map[string]MeshNode)
for _, mesh := range m.manager.GetMeshes() {
snapshot, err := mesh.GetMesh()
if err != nil {
return err
}
for _, node := range snapshot.GetNodes() {
previous, exists := m.nodes[node.GetWgHost().String()]
if !exists || !NodeEquals(previous, node) {
changedNodes = append(changedNodes, node)
}
nodes[node.GetWgHost().String()] = node
}
}
for _, previous := range m.nodes {
_, ok := nodes[previous.GetWgHost().String()]
if !ok {
removedNodes = append(removedNodes, previous)
}
}
if len(removedNodes) > 0 {
for _, cb := range m.removeCbs {
cb(removedNodes)
}
}
if len(changedNodes) > 0 {
for _, cb := range m.updateCbs {
cb(changedNodes)
}
}
return nil
}
func (m *MeshMonitorImpl) AddUpdateCallback(cb OnChange) {
m.updateCbs = append(m.updateCbs, cb)
}
func (m *MeshMonitorImpl) AddRemoveCallback(cb OnChange) {
m.removeCbs = append(m.removeCbs, cb)
}
func NewMeshMonitor(manager MeshManager) MeshMonitor {
return &MeshMonitorImpl{
updateCbs: make([]OnChange, 0),
nodes: make(map[string]MeshNode),
manager: manager,
}
}

View File

@ -1,9 +1,10 @@
package mesh
import (
"net"
"github.com/tim-beatham/wgmesh/pkg/ip"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/route"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
type RouteManager interface {
@ -11,38 +12,105 @@ type RouteManager interface {
}
type RouteManagerImpl struct {
meshManager *MeshManager
routeInstaller route.RouteInstaller
meshManager MeshManager
}
func (r *RouteManagerImpl) UpdateRoutes() error {
meshes := r.meshManager.Meshes
ulaBuilder := new(ip.ULABuilder)
meshes := r.meshManager.GetMeshes()
routes := make(map[string][]Route)
for _, mesh1 := range meshes {
if !*mesh1.GetConfiguration().AdvertiseRoutes {
continue
}
self, err := r.meshManager.GetSelf(mesh1.GetMeshId())
if err != nil {
return err
}
if _, ok := routes[mesh1.GetMeshId()]; !ok {
routes[mesh1.GetMeshId()] = make([]Route, 0)
}
if *mesh1.GetConfiguration().AdvertiseDefaultRoute {
_, ipv6Default, _ := net.ParseCIDR("::/0")
defaultRoute := &RouteStub{
Destination: ipv6Default,
HopCount: 0,
Path: []string{mesh1.GetMeshId()},
}
mesh1.AddRoutes(NodeID(self), defaultRoute)
routes[mesh1.GetMeshId()] = append(routes[mesh1.GetMeshId()], defaultRoute)
}
routeMap, err := mesh1.GetRoutes(NodeID(self))
if err != nil {
return err
}
for _, mesh2 := range meshes {
routeValues, ok := routes[mesh2.GetMeshId()]
if !ok {
routeValues = make([]Route, 0)
}
if mesh1 == mesh2 {
continue
}
ipNet, err := ulaBuilder.GetIPNet(mesh2.GetMeshId())
mesh1IpNet, _ := (&ip.ULABuilder{}).GetIPNet(mesh1.GetMeshId())
if err != nil {
logging.Log.WriteErrorf(err.Error())
return err
routeValues = append(routeValues, &RouteStub{
Destination: mesh1IpNet,
HopCount: 0,
Path: []string{mesh1.GetMeshId()},
})
routeValues = append(routeValues, lib.MapValues(routeMap)...)
mesh2IpNet, _ := (&ip.ULABuilder{}).GetIPNet(mesh2.GetMeshId())
routeValues = lib.Filter(routeValues, func(r Route) bool {
pathNotMesh := func(s string) bool {
return s == mesh2.GetMeshId()
}
err = mesh1.AddRoutes(r.meshManager.HostParameters.HostEndpoint, ipNet.String())
// Remove any potential routing loops
return !r.GetDestination().IP.Equal(mesh2IpNet.IP) &&
!lib.Contains(r.GetPath()[1:], pathNotMesh)
})
if err != nil {
return err
routes[mesh2.GetMeshId()] = routeValues
}
}
// Calculate the set different of each, working out routes to remove and to keep.
for meshId, meshRoutes := range routes {
mesh := r.meshManager.GetMesh(meshId)
self, _ := r.meshManager.GetSelf(meshId)
toRemove := make([]Route, 0)
prevRoutes, _ := mesh.GetRoutes(NodeID(self))
for _, route := range prevRoutes {
if !lib.Contains(meshRoutes, func(r Route) bool {
return RouteEquals(r, route)
}) {
toRemove = append(toRemove, route)
}
}
mesh.RemoveRoutes(NodeID(self), toRemove...)
mesh.AddRoutes(NodeID(self), meshRoutes...)
}
return nil
}
func NewRouteManager(m *MeshManager) RouteManager {
return &RouteManagerImpl{meshManager: m, routeInstaller: route.NewRouteInstaller()}
func NewRouteManager(m MeshManager) RouteManager {
return &RouteManagerImpl{meshManager: m}
}

16
pkg/mesh/route_stub.go Normal file
View File

@ -0,0 +1,16 @@
package mesh
type RouteManagerStub struct {
}
func (r *RouteManagerStub) UpdateRoutes() error {
return nil
}
func (r *RouteManagerStub) InstallRoutes() error {
return nil
}
func (r *RouteManagerStub) RemoveRoutes(meshId string) error {
return nil
}

340
pkg/mesh/stub_types.go Normal file
View File

@ -0,0 +1,340 @@
package mesh
import (
"fmt"
"net"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type MeshNodeStub struct {
hostEndpoint string
publicKey wgtypes.Key
wgEndpoint string
wgHost *net.IPNet
timeStamp int64
routes []Route
identifier string
description string
}
// GetType implements MeshNode.
func (*MeshNodeStub) GetType() conf.NodeType {
return conf.PEER_ROLE
}
// GetServices implements MeshNode.
func (*MeshNodeStub) GetServices() map[string]string {
return make(map[string]string)
}
// GetAlias implements MeshNode.
func (*MeshNodeStub) GetAlias() string {
return ""
}
func (m *MeshNodeStub) GetHostEndpoint() string {
return m.hostEndpoint
}
func (m *MeshNodeStub) GetPublicKey() (wgtypes.Key, error) {
return m.publicKey, nil
}
func (m *MeshNodeStub) GetWgEndpoint() string {
return m.wgEndpoint
}
func (m *MeshNodeStub) GetWgHost() *net.IPNet {
return m.wgHost
}
func (m *MeshNodeStub) GetTimeStamp() int64 {
return m.timeStamp
}
func (m *MeshNodeStub) GetRoutes() []Route {
return m.routes
}
func (m *MeshNodeStub) GetIdentifier() string {
return m.identifier
}
func (m *MeshNodeStub) GetDescription() string {
return m.description
}
type MeshSnapshotStub struct {
nodes map[string]MeshNode
}
func (s *MeshSnapshotStub) GetNodes() map[string]MeshNode {
return s.nodes
}
type MeshProviderStub struct {
meshId string
snapshot *MeshSnapshotStub
}
// GetConfiguration implements MeshProvider.
func (*MeshProviderStub) GetConfiguration() *conf.WgConfiguration {
panic("unimplemented")
}
// Mark implements MeshProvider.
func (*MeshProviderStub) Mark(nodeId string) {
panic("unimplemented")
}
// RemoveNode implements MeshProvider.
func (*MeshProviderStub) RemoveNode(nodeId string) error {
panic("unimplemented")
}
func (*MeshProviderStub) GetRoutes(targetId string) (map[string]Route, error) {
return nil, nil
}
// GetNodeIds implements MeshProvider.
func (*MeshProviderStub) GetPeers() []string {
return make([]string, 0)
}
// GetNode implements MeshProvider.
func (*MeshProviderStub) GetNode(string) (MeshNode, error) {
return nil, nil
}
// NodeExists implements MeshProvider.
func (*MeshProviderStub) NodeExists(string) bool {
return false
}
// AddService implements MeshProvider.
func (*MeshProviderStub) AddService(nodeId string, key string, value string) error {
return nil
}
// RemoveService implements MeshProvider.
func (*MeshProviderStub) RemoveService(nodeId string, key string) error {
return nil
}
// SetAlias implements MeshProvider.
func (*MeshProviderStub) SetAlias(nodeId string, alias string) error {
return nil
}
// RemoveRoutes implements MeshProvider.
func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...Route) error {
return nil
}
// Prune implements MeshProvider.
func (*MeshProviderStub) Prune() error {
return nil
}
// UpdateTimeStamp implements MeshProvider.
func (*MeshProviderStub) UpdateTimeStamp(nodeId string) error {
return nil
}
func (s *MeshProviderStub) AddNode(node MeshNode) {
s.snapshot.nodes[node.GetHostEndpoint()] = node
}
func (s *MeshProviderStub) GetMesh() (MeshSnapshot, error) {
return s.snapshot, nil
}
func (s *MeshProviderStub) GetMeshId() string {
return s.meshId
}
func (s *MeshProviderStub) Save() []byte {
return make([]byte, 0)
}
func (s *MeshProviderStub) Load(bytes []byte) error {
return nil
}
func (s *MeshProviderStub) GetDevice() (*wgtypes.Device, error) {
pubKey, _ := wgtypes.GenerateKey()
return &wgtypes.Device{
PublicKey: pubKey,
}, nil
}
func (s *MeshProviderStub) SaveChanges() {}
func (s *MeshProviderStub) HasChanges() bool {
return false
}
func (s *MeshProviderStub) AddRoutes(nodeId string, route ...Route) error {
return nil
}
func (s *MeshProviderStub) GetSyncer() MeshSyncer {
return nil
}
func (s *MeshProviderStub) SetDescription(nodeId string, description string) error {
return nil
}
type StubMeshProviderFactory struct{}
func (s *StubMeshProviderFactory) CreateMesh(params *MeshProviderFactoryParams) (MeshProvider, error) {
return &MeshProviderStub{
meshId: params.MeshId,
snapshot: &MeshSnapshotStub{nodes: make(map[string]MeshNode)},
}, nil
}
type StubNodeFactory struct {
Config *conf.DaemonConfiguration
}
func (s *StubNodeFactory) Build(params *MeshNodeFactoryParams) MeshNode {
_, wgHost, _ := net.ParseCIDR(fmt.Sprintf("%s/128", params.NodeIP.String()))
return &MeshNodeStub{
hostEndpoint: params.Endpoint,
publicKey: *params.PublicKey,
wgEndpoint: fmt.Sprintf("%s:%s", params.Endpoint, s.Config.GrpcPort),
wgHost: wgHost,
timeStamp: time.Now().Unix(),
routes: make([]Route, 0),
identifier: "abc",
description: "A Mesh Node Stub",
}
}
type MeshConfigApplyerStub struct{}
func (a *MeshConfigApplyerStub) ApplyConfig() error {
return nil
}
func (a *MeshConfigApplyerStub) RemovePeers(meshId string) error {
return nil
}
func (a *MeshConfigApplyerStub) SetMeshManager(manager MeshManager) {
}
type MeshManagerStub struct {
meshes map[string]MeshProvider
}
// GetRouteManager implements MeshManager.
func (*MeshManagerStub) GetRouteManager() RouteManager {
panic("unimplemented")
}
// GetNode implements MeshManager.
func (*MeshManagerStub) GetNode(string, string) MeshNode {
panic("unimplemented")
}
// RemoveService implements MeshManager.
func (*MeshManagerStub) RemoveService(service string) error {
panic("unimplemented")
}
// SetService implements MeshManager.
func (*MeshManagerStub) SetService(service string, value string) error {
panic("unimplemented")
}
// GetMonitor implements MeshManager.
func (*MeshManagerStub) GetMonitor() MeshMonitor {
panic("unimplemented")
}
// SetAlias implements MeshManager.
func (*MeshManagerStub) SetAlias(alias string) error {
panic("unimplemented")
}
// Close implements MeshManager.
func (*MeshManagerStub) Close() error {
panic("unimplemented")
}
// Prune implements MeshManager.
func (*MeshManagerStub) Prune() error {
return nil
}
func NewMeshManagerStub() MeshManager {
return &MeshManagerStub{meshes: make(map[string]MeshProvider)}
}
func (m *MeshManagerStub) CreateMesh(*CreateMeshParams) (string, error) {
return "tim123", nil
}
func (m *MeshManagerStub) AddMesh(params *AddMeshParams) error {
m.meshes[params.MeshId] = &MeshProviderStub{
params.MeshId,
&MeshSnapshotStub{nodes: make(map[string]MeshNode)},
}
return nil
}
func (m *MeshManagerStub) HasChanges(meshId string) bool {
return false
}
func (m *MeshManagerStub) GetMesh(meshId string) MeshProvider {
return &MeshProviderStub{
meshId: meshId,
snapshot: &MeshSnapshotStub{nodes: make(map[string]MeshNode)}}
}
func (m *MeshManagerStub) GetPublicKey() *wgtypes.Key {
key, _ := wgtypes.GenerateKey()
return &key
}
func (m *MeshManagerStub) AddSelf(params *AddSelfParams) error {
return nil
}
func (m *MeshManagerStub) GetSelf(meshId string) (MeshNode, error) {
return nil, nil
}
func (m *MeshManagerStub) ApplyConfig() error {
return nil
}
func (m *MeshManagerStub) SetDescription(description string) error {
return nil
}
func (m *MeshManagerStub) UpdateTimeStamp() error {
return nil
}
func (m *MeshManagerStub) GetClient() *wgctrl.Client {
return nil
}
func (m *MeshManagerStub) GetMeshes() map[string]MeshProvider {
return m.meshes
}
func (m *MeshManagerStub) LeaveMesh(meshId string) error {
return nil
}

View File

@ -4,12 +4,46 @@ package mesh
import (
"net"
"slices"
"github.com/tim-beatham/wgmesh/pkg/conf"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type Route interface {
// GetDestination: returns the destination of the route
GetDestination() *net.IPNet
// GetHopCount: get the total hopcount of the prefix
GetHopCount() int
// GetPath: get a list of AS paths to get to the destination
GetPath() []string
}
func RouteEquals(r1, r2 Route) bool {
return r1.GetDestination().String() == r2.GetDestination().String() &&
r1.GetHopCount() == r2.GetHopCount() &&
slices.Equal(r1.GetPath(), r2.GetPath())
}
type RouteStub struct {
Destination *net.IPNet
HopCount int
Path []string
}
func (r *RouteStub) GetDestination() *net.IPNet {
return r.Destination
}
func (r *RouteStub) GetHopCount() int {
return r.HopCount
}
func (r *RouteStub) GetPath() []string {
return r.Path
}
// MeshNode represents an implementation of a node in a mesh
type MeshNode interface {
// GetHostEndpoint: gets the gRPC endpoint of the node
@ -23,11 +57,30 @@ type MeshNode interface {
// GetTimestamp: get the UNIX time stamp of the ndoe
GetTimeStamp() int64
// GetRoutes: returns the routes that the nodes provides
GetRoutes() []string
GetRoutes() []Route
// GetIdentifier: returns the identifier of the node
GetIdentifier() string
// GetDescription: returns the description for this node
GetDescription() string
// GetAlias: associates the node with an alias. Potentially used
// for DNS and so forth.
GetAlias() string
// GetServices: returns a list of services offered by the node
GetServices() map[string]string
GetType() conf.NodeType
}
// NodeEquals: determines if two mesh nodes are equivalent to one another
func NodeEquals(node1, node2 MeshNode) bool {
key1, _ := node1.GetPublicKey()
key2, _ := node2.GetPublicKey()
return key1.String() == key2.String()
}
func NodeID(node MeshNode) string {
key, _ := node.GetPublicKey()
return key.String()
}
type MeshSnapshot interface {
@ -46,7 +99,7 @@ type MeshSyncer interface {
type MeshProvider interface {
// AddNode() adds a node to the mesh
AddNode(node MeshNode)
// GetMesh() returns a snapshot of the mesh provided by the mesh provider
// GetMesh() returns a snapshot of the mesh provided by the mesh provider.
GetMesh() (MeshSnapshot, error)
// GetMeshId() returns the ID of the mesh network
GetMeshId() string
@ -58,20 +111,53 @@ type MeshProvider interface {
GetDevice() (*wgtypes.Device, error)
// HasChanges returns true if we have changes since last time we synced
HasChanges() bool
// Record that we have changges and save the corresponding changes
// Record that we have changes and save the corresponding changes
SaveChanges()
// UpdateTimeStamp: update the timestamp of the given node
UpdateTimeStamp(nodeId string) error
// AddRoutes: adds routes to the given node
AddRoutes(nodeId string, route ...string) error
AddRoutes(nodeId string, route ...Route) error
// DeleteRoutes: deletes the routes from the node
RemoveRoutes(nodeId string, route ...Route) error
// GetSyncer: returns the automerge syncer for sync
GetSyncer() MeshSyncer
// GetNode get a particular not within the mesh
GetNode(string) (MeshNode, error)
// NodeExists: returns true if a particular node exists false otherwise
NodeExists(string) bool
// SetDescription: sets the description of this automerge data type
SetDescription(nodeId string, description string) error
// SetAlias: set the alias of the nodeId
SetAlias(nodeId string, alias string) error
// AddService: adds the service to the given node
AddService(nodeId, key, value string) error
// RemoveService: removes the service form the node. throws an error if the service does not exist
RemoveService(nodeId, key string) error
// Prune: prunes all nodes that have not updated their
// vector clock
Prune() error
// GetPeers: get a list of contactable peers
GetPeers() []string
// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen
GetRoutes(targetNode string) (map[string]Route, error)
// RemoveNode(): remove the node from the mesh
RemoveNode(nodeId string) error
// Mark: marks the node as unreachable. This is not broadcast to the entire
// this is not considered when syncing node state
Mark(nodeId string)
// GetConfiguration: gets the configuration parameters specific for this
// mesh network
GetConfiguration() *conf.WgConfiguration
}
// HostParameters contains the IDs of a node
type HostParameters struct {
HostEndpoint string
// TODO: Contain the WireGungracefullyuard identifier in this
PrivateKey *wgtypes.Key
}
// GetPublicKey: gets the public key of the node
func (h *HostParameters) GetPublicKey() string {
return h.PrivateKey.PublicKey().String()
}
// MeshProviderFactoryParams parameters required to build a mesh provider
@ -79,8 +165,10 @@ type MeshProviderFactoryParams struct {
DevName string
MeshId string
Port int
Conf *conf.WgMeshConfiguration
Conf *conf.WgConfiguration
DaemonConf *conf.DaemonConfiguration
Client *wgctrl.Client
NodeID string
}
// MeshProviderFactory creates an instance of a mesh provider
@ -95,6 +183,7 @@ type MeshNodeFactoryParams struct {
NodeIP net.IP
WgPort int
Endpoint string
MeshConfig *conf.WgConfiguration
}
// MeshBuilder build the hosts mesh node for it to be added to the mesh

View File

@ -1,29 +0,0 @@
package middleware
import (
"context"
"errors"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/rpc"
)
// AuthRpcProvider implements the AuthRpcProvider service
type AuthRpcProvider struct {
rpc.UnimplementedAuthenticationServer
}
// JoinMesh handles a JoinMeshRequest. Succeeds by stating the node managed to join the mesh
// or returns an error if it failed
func (a *AuthRpcProvider) JoinMesh(ctx context.Context, in *rpc.JoinAuthMeshRequest) (*rpc.JoinAuthMeshReply, error) {
meshId := in.MeshId
if meshId == "" {
return nil, errors.New("Must specify the meshId")
}
logging.Log.WriteInfof("MeshID: " + in.MeshId)
var token string = ""
return &rpc.JoinAuthMeshReply{Success: true, Token: &token}, nil
}

View File

@ -3,8 +3,10 @@ package query
import (
"encoding/json"
"fmt"
"strings"
"github.com/jmespath/go-jmespath"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
)
@ -16,21 +18,30 @@ type Querier interface {
}
type JmesQuerier struct {
manager *mesh.MeshManager
manager mesh.MeshManager
}
type QueryError struct {
msg string
}
type QueryRoute struct {
Destination string `json:"destination"`
HopCount int `json:"hopCount"`
Path string `json:"path"`
}
type QueryNode struct {
HostEndpoint string `json:"hostEndpoint"`
PublicKey string `json:"publicKey"`
WgEndpoint string `json:"wgEndpoint"`
WgHost string `json:"wgHost"`
Timestamp int64 `json:"timestmap"`
Timestamp int64 `json:"timestamp"`
Description string `json:"description"`
Routes []string `json:"routes"`
Routes []QueryRoute `json:"routes"`
Alias string `json:"alias"`
Services map[string]string `json:"services"`
Type conf.NodeType `json:"type"`
}
func (m *QueryError) Error() string {
@ -39,7 +50,7 @@ func (m *QueryError) Error() string {
// Query: queries the data
func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
mesh, ok := j.manager.Meshes[meshId]
mesh, ok := j.manager.GetMeshes()[meshId]
if !ok {
return nil, &QueryError{msg: fmt.Sprintf("%s does not exist", meshId)}
@ -51,7 +62,7 @@ func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
return nil, err
}
nodes := lib.Map(lib.MapValues(snapshot.GetNodes()), meshNodeToQueryNode)
nodes := lib.Map(lib.MapValues(snapshot.GetNodes()), MeshNodeToQueryNode)
result, err := jmespath.Search(queryParams, nodes)
@ -63,7 +74,7 @@ func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
return bytes, err
}
func meshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
func MeshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode := new(QueryNode)
queryNode.HostEndpoint = node.GetHostEndpoint()
pubKey, _ := node.GetPublicKey()
@ -74,11 +85,21 @@ func meshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode.WgHost = node.GetWgHost().String()
queryNode.Timestamp = node.GetTimeStamp()
queryNode.Routes = node.GetRoutes()
queryNode.Routes = lib.Map(node.GetRoutes(), func(r mesh.Route) QueryRoute {
return QueryRoute{
Destination: r.GetDestination().String(),
HopCount: r.GetHopCount(),
Path: strings.Join(r.GetPath(), ","),
}
})
queryNode.Description = node.GetDescription()
queryNode.Alias = node.GetAlias()
queryNode.Services = node.GetServices()
queryNode.Type = node.GetType()
return queryNode
}
func NewJmesQuerier(manager *mesh.MeshManager) Querier {
func NewJmesQuerier(manager mesh.MeshManager) Querier {
return &JmesQuerier{manager: manager}
}

View File

@ -7,40 +7,74 @@ import (
"strconv"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/rpc"
)
type IpcHandler struct {
Server *ctrlserver.MeshCtrlServer
ipAllocator ip.IPAllocator
Server ctrlserver.CtrlServer
}
func getOverrideConfiguration(args *ipc.WireGuardArgs) conf.WgConfiguration {
overrideConf := conf.WgConfiguration{}
if args.Role != "" {
role := conf.NodeType(args.Role)
overrideConf.Role = &role
}
if args.Endpoint != "" {
overrideConf.Endpoint = &args.Endpoint
}
if args.KeepAliveWg != 0 {
keepAliveWg := args.KeepAliveWg
overrideConf.KeepAliveWg = &keepAliveWg
}
overrideConf.AdvertiseRoutes = &args.AdvertiseRoutes
overrideConf.AdvertiseDefaultRoute = &args.AdvertiseDefaultRoute
return overrideConf
}
func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
meshId, err := n.Server.MeshManager.CreateMesh(args.IfName, args.WgPort)
overrideConf := getOverrideConfiguration(&args.WgArgs)
if overrideConf.Role != nil && *overrideConf.Role == conf.CLIENT_ROLE {
return fmt.Errorf("cannot create a mesh with no public endpoint")
}
meshId, err := n.Server.GetMeshManager().CreateMesh(&mesh.CreateMeshParams{
Port: args.WgArgs.WgPort,
Conf: &overrideConf,
})
if err != nil {
return err
}
err = n.Server.MeshManager.AddSelf(&mesh.AddSelfParams{
err = n.Server.GetMeshManager().AddSelf(&mesh.AddSelfParams{
MeshId: meshId,
WgPort: args.WgPort,
Endpoint: args.Endpoint,
WgPort: args.WgArgs.WgPort,
Endpoint: args.WgArgs.Endpoint,
})
if err != nil {
return err
}
*reply = meshId
return err
}
func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
meshNames := make([]string, len(n.Server.MeshManager.Meshes))
meshNames := make([]string, len(n.Server.GetMeshManager().GetMeshes()))
i := 0
for meshId, _ := range n.Server.MeshManager.Meshes {
for meshId := range n.Server.GetMeshManager().GetMeshes() {
meshNames[i] = meshId
i++
}
@ -50,7 +84,9 @@ func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
}
func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
peerConnection, err := n.Server.ConnectionManager.GetConnection(args.IpAdress)
overrideConf := getOverrideConfiguration(&args.WgArgs)
peerConnection, err := n.Server.GetConnectionManager().GetConnection(args.IpAdress)
if err != nil {
return err
@ -68,7 +104,9 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
configuration := n.Server.GetConfiguration()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(configuration.Timeout))
defer cancel()
meshReply, err := c.GetMesh(ctx, &rpc.GetMeshRequest{MeshId: args.MeshId})
@ -77,21 +115,21 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
return err
}
err = n.Server.MeshManager.AddMesh(&mesh.AddMeshParams{
err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{
MeshId: args.MeshId,
DevName: args.IfName,
WgPort: args.Port,
WgPort: args.WgArgs.WgPort,
MeshBytes: meshReply.Mesh,
Conf: &overrideConf,
})
if err != nil {
return err
}
err = n.Server.MeshManager.AddSelf(&mesh.AddSelfParams{
err = n.Server.GetMeshManager().AddSelf(&mesh.AddSelfParams{
MeshId: args.MeshId,
WgPort: args.Port,
Endpoint: args.Endpoint,
WgPort: args.WgArgs.WgPort,
Endpoint: args.WgArgs.Endpoint,
})
if err != nil {
@ -104,7 +142,7 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
// LeaveMesh leaves a mesh network
func (n *IpcHandler) LeaveMesh(meshId string, reply *string) error {
err := n.Server.MeshManager.LeaveMesh(meshId)
err := n.Server.GetMeshManager().LeaveMesh(meshId)
if err == nil {
*reply = fmt.Sprintf("Left Mesh %s", meshId)
@ -114,36 +152,29 @@ func (n *IpcHandler) LeaveMesh(meshId string, reply *string) error {
}
func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
mesh := n.Server.MeshManager.GetMesh(meshId)
meshSnapshot, err := mesh.GetMesh()
theMesh := n.Server.GetMeshManager().GetMesh(meshId)
if theMesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
meshSnapshot, err := theMesh.GetMesh()
if err != nil {
return err
}
if mesh == nil {
if theMesh == nil {
return errors.New("mesh does not exist")
}
nodes := make([]ctrlserver.MeshNode, len(meshSnapshot.GetNodes()))
i := 0
for _, node := range meshSnapshot.GetNodes() {
pubKey, _ := node.GetPublicKey()
node := ctrlserver.NewCtrlNode(theMesh, node)
if err != nil {
return err
}
node := ctrlserver.MeshNode{
HostEndpoint: node.GetHostEndpoint(),
WgEndpoint: node.GetWgEndpoint(),
PublicKey: pubKey.String(),
WgHost: node.GetWgHost().String(),
Timestamp: node.GetTimeStamp(),
Routes: node.GetRoutes(),
}
nodes[i] = node
nodes[i] = *node
i += 1
}
@ -151,20 +182,8 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
return nil
}
func (n *IpcHandler) EnableInterface(meshId string, reply *string) error {
err := n.Server.MeshManager.EnableInterface(meshId)
if err != nil {
*reply = err.Error()
return err
}
*reply = "up"
return nil
}
func (n *IpcHandler) GetDOT(meshId string, reply *string) error {
g := mesh.NewMeshDotConverter(n.Server.MeshManager)
g := mesh.NewMeshDotConverter(n.Server.GetMeshManager())
result, err := g.Generate(meshId)
@ -177,7 +196,7 @@ func (n *IpcHandler) GetDOT(meshId string, reply *string) error {
}
func (n *IpcHandler) Query(params ipc.QueryMesh, reply *string) error {
queryResponse, err := n.Server.Querier.Query(params.MeshId, params.Query)
queryResponse, err := n.Server.GetQuerier().Query(params.MeshId, params.Query)
if err != nil {
return err
@ -188,7 +207,7 @@ func (n *IpcHandler) Query(params ipc.QueryMesh, reply *string) error {
}
func (n *IpcHandler) PutDescription(description string, reply *string) error {
err := n.Server.MeshManager.SetDescription(description)
err := n.Server.GetMeshManager().SetDescription(description)
if err != nil {
return err
@ -198,8 +217,41 @@ func (n *IpcHandler) PutDescription(description string, reply *string) error {
return nil
}
func (n *IpcHandler) PutAlias(alias string, reply *string) error {
err := n.Server.GetMeshManager().SetAlias(alias)
if err != nil {
return err
}
*reply = fmt.Sprintf("Set alias to %s", alias)
return nil
}
func (n *IpcHandler) PutService(service ipc.PutServiceArgs, reply *string) error {
err := n.Server.GetMeshManager().SetService(service.Service, service.Value)
if err != nil {
return err
}
*reply = "success"
return nil
}
func (n *IpcHandler) DeleteService(service string, reply *string) error {
err := n.Server.GetMeshManager().RemoveService(service)
if err != nil {
return err
}
*reply = "success"
return nil
}
type RobinIpcParams struct {
CtrlServer *ctrlserver.MeshCtrlServer
CtrlServer ctrlserver.CtrlServer
}
func NewRobinIpc(ipcParams RobinIpcParams) IpcHandler {

View File

@ -0,0 +1,73 @@
package robin
import (
"testing"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/mesh"
)
func getRequester() *IpcHandler {
return &IpcHandler{Server: ctrlserver.NewCtrlServerStub()}
}
func TestCreateMeshRepliesMeshId(t *testing.T) {
var reply string
requester := getRequester()
err := requester.CreateMesh(&ipc.NewMeshArgs{
IfName: "wg0",
WgPort: 5000,
Endpoint: "abc.com",
}, &reply)
if err != nil {
t.Error(err)
}
if len(reply) == 0 {
t.Fatalf(`reply should have been returned`)
}
}
func TestListMeshesNoMeshesListsEmpty(t *testing.T) {
var reply ipc.ListMeshReply
requester := getRequester()
err := requester.ListMeshes("", &reply)
if err != nil {
t.Error(err)
}
if len(reply.Meshes) != 0 {
t.Fatalf(`meshes should be empty`)
}
}
func TestListMeshesMeshesNotEmpty(t *testing.T) {
var reply ipc.ListMeshReply
requester := getRequester()
requester.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{
MeshId: "tim123",
DevName: "wg0",
WgPort: 5000,
MeshBytes: make([]byte, 0),
})
err := requester.ListMeshes("", &reply)
if err != nil {
t.Error(err)
}
if len(reply.Meshes) != 1 {
t.Fatalf(`only only mesh exists`)
}
if reply.Meshes[0] != "tim123" {
t.Fatalf(`meshId was %s expected %s`, reply.Meshes[0], "tim123")
}
}

View File

@ -13,29 +13,6 @@ type WgRpc struct {
Server *ctrlserver.MeshCtrlServer
}
func nodeToRpcNode(node ctrlserver.MeshNode) *rpc.MeshNode {
return &rpc.MeshNode{
PublicKey: node.PublicKey,
WgEndpoint: node.WgEndpoint,
WgHost: node.WgHost,
Endpoint: node.HostEndpoint,
}
}
func nodesToRpcNodes(nodes map[string]ctrlserver.MeshNode) []*rpc.MeshNode {
n := len(nodes)
meshNodes := make([]*rpc.MeshNode, n)
var i int = 0
for _, v := range nodes {
meshNodes[i] = nodeToRpcNode(v)
i++
}
return meshNodes
}
func (m *WgRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.GetMeshReply, error) {
mesh := m.Server.MeshManager.GetMesh(request.MeshId)
@ -51,7 +28,3 @@ func (m *WgRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.
return &reply, nil
}
func (m *WgRpc) JoinMesh(ctx context.Context, request *rpc.JoinMeshRequest) (*rpc.JoinMeshReply, error) {
return &rpc.JoinMeshReply{Success: true}, nil
}

View File

@ -0,0 +1 @@
package robin

View File

@ -1,22 +1,32 @@
package route
import (
"net"
"os/exec"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/lib"
"golang.org/x/sys/unix"
)
type RouteInstaller interface {
InstallRoutes(devName string, routes ...*net.IPNet) error
InstallRoutes(devName string, routes ...lib.Route) error
}
type RouteInstallerImpl struct{}
// InstallRoutes: installs a route into the routing table
func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...*net.IPNet) error {
func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...lib.Route) error {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return err
}
err = rtnl.DeleteRoutes(devName, unix.AF_INET6, routes...)
if err != nil {
return err
}
for _, route := range routes {
err := r.installRoute(devName, route)
err := rtnl.AddRoute(devName, route)
if err != nil {
return err
@ -26,22 +36,6 @@ func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...*net.IPNet)
return nil
}
// installRoute: installs a route into the linux table
func (r *RouteInstallerImpl) installRoute(devName string, route *net.IPNet) error {
// TODO: Find a library that automates this
cmd := exec.Command("/usr/bin/ip", "-6", "route", "add", route.String(), "dev", devName)
logging.Log.WriteInfof("%s %s", route.String(), devName)
if msg, err := cmd.CombinedOutput(); err != nil {
logging.Log.WriteErrorf(err.Error())
logging.Log.WriteErrorf(string(msg))
return err
}
return nil
}
func NewRouteInstaller() RouteInstaller {
return &RouteInstallerImpl{}
}

View File

@ -20,77 +20,6 @@ const (
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type MeshNode struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
PublicKey string `protobuf:"bytes,1,opt,name=publicKey,proto3" json:"publicKey,omitempty"`
WgEndpoint string `protobuf:"bytes,2,opt,name=wgEndpoint,proto3" json:"wgEndpoint,omitempty"`
Endpoint string `protobuf:"bytes,3,opt,name=endpoint,proto3" json:"endpoint,omitempty"`
WgHost string `protobuf:"bytes,4,opt,name=wgHost,proto3" json:"wgHost,omitempty"`
}
func (x *MeshNode) Reset() {
*x = MeshNode{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *MeshNode) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*MeshNode) ProtoMessage() {}
func (x *MeshNode) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use MeshNode.ProtoReflect.Descriptor instead.
func (*MeshNode) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{0}
}
func (x *MeshNode) GetPublicKey() string {
if x != nil {
return x.PublicKey
}
return ""
}
func (x *MeshNode) GetWgEndpoint() string {
if x != nil {
return x.WgEndpoint
}
return ""
}
func (x *MeshNode) GetEndpoint() string {
if x != nil {
return x.Endpoint
}
return ""
}
func (x *MeshNode) GetWgHost() string {
if x != nil {
return x.WgHost
}
return ""
}
type GetMeshRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
@ -102,7 +31,7 @@ type GetMeshRequest struct {
func (x *GetMeshRequest) Reset() {
*x = GetMeshRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1]
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@ -115,7 +44,7 @@ func (x *GetMeshRequest) String() string {
func (*GetMeshRequest) ProtoMessage() {}
func (x *GetMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1]
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@ -128,7 +57,7 @@ func (x *GetMeshRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetMeshRequest.ProtoReflect.Descriptor instead.
func (*GetMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{1}
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{0}
}
func (x *GetMeshRequest) GetMeshId() string {
@ -149,7 +78,7 @@ type GetMeshReply struct {
func (x *GetMeshReply) Reset() {
*x = GetMeshReply{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2]
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@ -162,7 +91,7 @@ func (x *GetMeshReply) String() string {
func (*GetMeshReply) ProtoMessage() {}
func (x *GetMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2]
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@ -175,7 +104,7 @@ func (x *GetMeshReply) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetMeshReply.ProtoReflect.Descriptor instead.
func (*GetMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{2}
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{1}
}
func (x *GetMeshReply) GetMesh() []byte {
@ -185,145 +114,24 @@ func (x *GetMeshReply) GetMesh() []byte {
return nil
}
type JoinMeshRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Changes []byte `protobuf:"bytes,1,opt,name=changes,proto3" json:"changes,omitempty"`
MeshId string `protobuf:"bytes,2,opt,name=meshId,proto3" json:"meshId,omitempty"`
}
func (x *JoinMeshRequest) Reset() {
*x = JoinMeshRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *JoinMeshRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JoinMeshRequest) ProtoMessage() {}
func (x *JoinMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JoinMeshRequest.ProtoReflect.Descriptor instead.
func (*JoinMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{3}
}
func (x *JoinMeshRequest) GetChanges() []byte {
if x != nil {
return x.Changes
}
return nil
}
func (x *JoinMeshRequest) GetMeshId() string {
if x != nil {
return x.MeshId
}
return ""
}
type JoinMeshReply struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"`
}
func (x *JoinMeshReply) Reset() {
*x = JoinMeshReply{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *JoinMeshReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JoinMeshReply) ProtoMessage() {}
func (x *JoinMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JoinMeshReply.ProtoReflect.Descriptor instead.
func (*JoinMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{4}
}
func (x *JoinMeshReply) GetSuccess() bool {
if x != nil {
return x.Success
}
return false
}
var File_pkg_grpc_ctrlserver_ctrlserver_proto protoreflect.FileDescriptor
var file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc = []byte{
0x0a, 0x24, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73,
0x65, 0x72, 0x76, 0x65, 0x72, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72,
0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73,
0x22, 0x7c, 0x0a, 0x08, 0x4d, 0x65, 0x73, 0x68, 0x4e, 0x6f, 0x64, 0x65, 0x12, 0x1c, 0x0a, 0x09,
0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52,
0x09, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x77, 0x67,
0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a,
0x77, 0x67, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e,
0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x65, 0x6e,
0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x77, 0x67, 0x48, 0x6f, 0x73, 0x74,
0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x77, 0x67, 0x48, 0x6f, 0x73, 0x74, 0x22, 0x28,
0x0a, 0x0e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09,
0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x22, 0x22, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x4d,
0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x6d, 0x65, 0x73, 0x68,
0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x22, 0x43, 0x0a, 0x0f,
0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12,
0x18, 0x0a, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c,
0x52, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73,
0x68, 0x49, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49,
0x64, 0x22, 0x29, 0x0a, 0x0d, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70,
0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20,
0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x32, 0x91, 0x01, 0x0a,
0x0e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x74, 0x72, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12,
0x3d, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x18, 0x2e, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e,
0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x40,
0x0a, 0x08, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x19, 0x2e, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73,
0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00,
0x42, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x33,
0x22, 0x28, 0x0a, 0x0e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65,
0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01,
0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x22, 0x22, 0x0a, 0x0c, 0x47, 0x65,
0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x6d, 0x65,
0x73, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x32, 0x4f,
0x0a, 0x0e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x74, 0x72, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72,
0x12, 0x3d, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x18, 0x2e, 0x72, 0x70,
0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73,
0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x42,
0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x33,
}
var (
@ -338,21 +146,16 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP() []byte {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescData
}
var file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
var file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_pkg_grpc_ctrlserver_ctrlserver_proto_goTypes = []interface{}{
(*MeshNode)(nil), // 0: rpctypes.MeshNode
(*GetMeshRequest)(nil), // 1: rpctypes.GetMeshRequest
(*GetMeshReply)(nil), // 2: rpctypes.GetMeshReply
(*JoinMeshRequest)(nil), // 3: rpctypes.JoinMeshRequest
(*JoinMeshReply)(nil), // 4: rpctypes.JoinMeshReply
(*GetMeshRequest)(nil), // 0: rpctypes.GetMeshRequest
(*GetMeshReply)(nil), // 1: rpctypes.GetMeshReply
}
var file_pkg_grpc_ctrlserver_ctrlserver_proto_depIdxs = []int32{
1, // 0: rpctypes.MeshCtrlServer.GetMesh:input_type -> rpctypes.GetMeshRequest
3, // 1: rpctypes.MeshCtrlServer.JoinMesh:input_type -> rpctypes.JoinMeshRequest
2, // 2: rpctypes.MeshCtrlServer.GetMesh:output_type -> rpctypes.GetMeshReply
4, // 3: rpctypes.MeshCtrlServer.JoinMesh:output_type -> rpctypes.JoinMeshReply
2, // [2:4] is the sub-list for method output_type
0, // [0:2] is the sub-list for method input_type
0, // 0: rpctypes.MeshCtrlServer.GetMesh:input_type -> rpctypes.GetMeshRequest
1, // 1: rpctypes.MeshCtrlServer.GetMesh:output_type -> rpctypes.GetMeshReply
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
@ -365,18 +168,6 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
}
if !protoimpl.UnsafeEnabled {
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*MeshNode); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetMeshRequest); i {
case 0:
return &v.state
@ -388,7 +179,7 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
return nil
}
}
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetMeshReply); i {
case 0:
return &v.state
@ -400,30 +191,6 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
return nil
}
}
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*JoinMeshRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*JoinMeshReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
@ -431,7 +198,7 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc,
NumEnums: 0,
NumMessages: 5,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},

View File

@ -23,7 +23,6 @@ const _ = grpc.SupportPackageIsVersion7
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type MeshCtrlServerClient interface {
GetMesh(ctx context.Context, in *GetMeshRequest, opts ...grpc.CallOption) (*GetMeshReply, error)
JoinMesh(ctx context.Context, in *JoinMeshRequest, opts ...grpc.CallOption) (*JoinMeshReply, error)
}
type meshCtrlServerClient struct {
@ -43,21 +42,11 @@ func (c *meshCtrlServerClient) GetMesh(ctx context.Context, in *GetMeshRequest,
return out, nil
}
func (c *meshCtrlServerClient) JoinMesh(ctx context.Context, in *JoinMeshRequest, opts ...grpc.CallOption) (*JoinMeshReply, error) {
out := new(JoinMeshReply)
err := c.cc.Invoke(ctx, "/rpctypes.MeshCtrlServer/JoinMesh", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// MeshCtrlServerServer is the server API for MeshCtrlServer service.
// All implementations must embed UnimplementedMeshCtrlServerServer
// for forward compatibility
type MeshCtrlServerServer interface {
GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error)
JoinMesh(context.Context, *JoinMeshRequest) (*JoinMeshReply, error)
mustEmbedUnimplementedMeshCtrlServerServer()
}
@ -68,9 +57,6 @@ type UnimplementedMeshCtrlServerServer struct {
func (UnimplementedMeshCtrlServerServer) GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetMesh not implemented")
}
func (UnimplementedMeshCtrlServerServer) JoinMesh(context.Context, *JoinMeshRequest) (*JoinMeshReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method JoinMesh not implemented")
}
func (UnimplementedMeshCtrlServerServer) mustEmbedUnimplementedMeshCtrlServerServer() {}
// UnsafeMeshCtrlServerServer may be embedded to opt out of forward compatibility for this service.
@ -102,24 +88,6 @@ func _MeshCtrlServer_GetMesh_Handler(srv interface{}, ctx context.Context, dec f
return interceptor(ctx, in, info, handler)
}
func _MeshCtrlServer_JoinMesh_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(JoinMeshRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(MeshCtrlServerServer).JoinMesh(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/rpctypes.MeshCtrlServer/JoinMesh",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(MeshCtrlServerServer).JoinMesh(ctx, req.(*JoinMeshRequest))
}
return interceptor(ctx, in, info, handler)
}
// MeshCtrlServer_ServiceDesc is the grpc.ServiceDesc for MeshCtrlServer service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@ -131,10 +99,6 @@ var MeshCtrlServer_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetMesh",
Handler: _MeshCtrlServer_GetMesh_Handler,
},
{
MethodName: "JoinMesh",
Handler: _MeshCtrlServer_JoinMesh_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "pkg/grpc/ctrlserver/ctrlserver.proto",

View File

@ -2,8 +2,9 @@ package sync
import (
"errors"
"fmt"
"io"
"math/rand"
"sync"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
@ -13,112 +14,181 @@ import (
"github.com/tim-beatham/wgmesh/pkg/mesh"
)
// Syncer: picks random nodes from the mesh
// Syncer: picks random nodes from the meshs
type Syncer interface {
Sync(meshId string) error
SyncMeshes() error
}
type SyncerImpl struct {
manager *mesh.MeshManager
manager mesh.MeshManager
requester SyncRequester
infectionCount int
syncCount int
cluster conn.ConnCluster
conf *conf.WgMeshConfiguration
conf *conf.DaemonConfiguration
lastSync map[string]uint64
}
// Sync: Sync random nodes
func (s *SyncerImpl) Sync(meshId string) error {
logging.Log.WriteInfof("UPDATING WG CONF")
s.manager.ApplyConfig()
// Self can be nil if the node is removed
self, _ := s.manager.GetSelf(meshId)
if !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
correspondingMesh := s.manager.GetMesh(meshId)
correspondingMesh.Prune()
if self != nil && self.GetType() == conf.PEER_ROLE && !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
logging.Log.WriteInfof("No changes for %s", meshId)
// If not synchronised in certain pull from random neighbour
if uint64(time.Now().Unix())-s.lastSync[meshId] > 20 {
return s.Pull(meshId)
}
return nil
}
theMesh := s.manager.GetMesh(meshId)
before := time.Now()
s.manager.GetRouteManager().UpdateRoutes()
if theMesh == nil {
return errors.New("the provided mesh does not exist")
publicKey := s.manager.GetPublicKey()
logging.Log.WriteInfof(publicKey.String())
nodeNames := correspondingMesh.GetPeers()
if self != nil {
nodeNames = lib.Filter(nodeNames, func(s string) bool {
return s != mesh.NodeID(self)
})
}
snapshot, err := theMesh.GetMesh()
var gossipNodes []string
// Clients always pings its peer for configuration
if self != nil && self.GetType() == conf.CLIENT_ROLE {
keyFunc := lib.HashString
bucketFunc := lib.HashString
neighbour := lib.ConsistentHash(nodeNames, publicKey.String(), keyFunc, bucketFunc)
gossipNodes = make([]string, 1)
gossipNodes[0] = neighbour
} else {
neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
gossipNodes = lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
gossipNodes[len(gossipNodes)-1] = s.cluster.GetInterCluster(nodeNames, publicKey.String())
}
}
var succeeded bool = false
// Do this synchronously to conserve bandwidth
for _, node := range gossipNodes {
correspondingPeer := s.manager.GetNode(meshId, node)
if correspondingPeer == nil {
logging.Log.WriteErrorf("node %s does not exist", node)
}
err := s.requester.SyncMesh(meshId, correspondingPeer)
if err == nil || err == io.EOF {
succeeded = true
} else {
// If the synchronisation operation has failed them mark a gravestone
// preventing the peer from being re-contacted until it has updated
// itself
s.manager.GetMesh(meshId).Mark(node)
}
}
s.syncCount++
logging.Log.WriteInfof("SYNC TIME: %v", time.Since(before))
logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount)
s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount)
if !succeeded {
// If could not gossip with anyone then repeat.
s.infectionCount++
}
s.manager.GetMesh(meshId).SaveChanges()
s.lastSync[meshId] = uint64(time.Now().Unix())
logging.Log.WriteInfof("UPDATING WG CONF")
err := s.manager.ApplyConfig()
if err != nil {
logging.Log.WriteInfof("Failed to update config %w", err)
}
return nil
}
// Pull one node in the cluster, if there has not been message dissemination
// in a certain period of time pull a random node within the cluster
func (s *SyncerImpl) Pull(meshId string) error {
mesh := s.manager.GetMesh(meshId)
self, err := s.manager.GetSelf(meshId)
if err != nil {
return err
}
nodes := snapshot.GetNodes()
pubKey, _ := self.GetPublicKey()
if len(nodes) <= 1 {
if mesh == nil {
return errors.New("mesh is nil, invalid operation")
}
peers := mesh.GetPeers()
neighbours := s.cluster.GetNeighbours(peers, pubKey.String())
neighbour := lib.RandomSubsetOfLength(neighbours, 1)
if len(neighbour) == 0 {
logging.Log.WriteInfof("no neighbours")
return nil
}
excludedNodes := map[string]struct{}{
s.manager.HostParameters.HostEndpoint: {},
}
meshNodes := lib.MapValuesWithExclude(nodes, excludedNodes)
logging.Log.WriteInfof("PULLING from node %s", neighbour[0])
getNames := func(node mesh.MeshNode) string {
return node.GetHostEndpoint()
pullNode, err := mesh.GetNode(neighbour[0])
if err != nil || pullNode == nil {
return fmt.Errorf("node %s does not exist in the mesh", neighbour[0])
}
nodeNames := lib.Map(meshNodes, getNames)
err = s.requester.SyncMesh(meshId, pullNode)
neighbours := s.cluster.GetNeighbours(nodeNames, s.manager.HostParameters.HostEndpoint)
randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
for _, node := range randomSubset {
logging.Log.WriteInfof("Random node: %s", node)
}
before := time.Now()
if len(meshNodes) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
logging.Log.WriteInfof("Sending to random cluster")
interCluster := s.cluster.GetInterCluster(nodeNames, s.manager.HostParameters.HostEndpoint)
randomSubset = append(randomSubset, interCluster)
}
var waitGroup sync.WaitGroup
for index := range randomSubset {
waitGroup.Add(1)
go func(i int) error {
defer waitGroup.Done()
err := s.requester.SyncMesh(meshId, randomSubset[i])
if err == nil || err == io.EOF {
s.lastSync[meshId] = uint64(time.Now().Unix())
} else {
return err
}(index)
}
waitGroup.Wait()
s.syncCount++
logging.Log.WriteInfof("SYNC TIME: %v", time.Now().Sub(before))
logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount)
s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount)
return nil
}
// SyncMeshes: Sync all meshes
func (s *SyncerImpl) SyncMeshes() error {
for meshId, _ := range s.manager.Meshes {
for meshId := range s.manager.GetMeshes() {
err := s.Sync(meshId)
if err != nil {
return err
logging.Log.WriteErrorf(err.Error())
}
}
return nil
}
func NewSyncer(m *mesh.MeshManager, conf *conf.WgMeshConfiguration, r SyncRequester) Syncer {
func NewSyncer(m mesh.MeshManager, conf *conf.DaemonConfiguration, r SyncRequester) Syncer {
cluster, _ := conn.NewConnCluster(conf.ClusterSize)
return &SyncerImpl{
manager: m,
@ -126,5 +196,6 @@ func NewSyncer(m *mesh.MeshManager, conf *conf.WgMeshConfiguration, r SyncReques
requester: r,
infectionCount: 0,
syncCount: 0,
cluster: cluster}
cluster: cluster,
lastSync: make(map[string]uint64)}
}

View File

@ -14,32 +14,28 @@ type SyncErrorHandler interface {
// SyncErrorHandlerImpl Is an implementation of the SyncErrorHandler
type SyncErrorHandlerImpl struct {
meshManager *mesh.MeshManager
meshManager mesh.MeshManager
}
func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint string) bool {
func (s *SyncErrorHandlerImpl) handleFailed(meshId string, nodeId string) bool {
mesh := s.meshManager.GetMesh(meshId)
if mesh == nil {
return false
}
mesh.Mark(nodeId)
return true
}
func (s *SyncErrorHandlerImpl) Handle(meshId string, endpoint string, err error) bool {
func (s *SyncErrorHandlerImpl) Handle(meshId string, nodeId string, err error) bool {
errStatus, _ := status.FromError(err)
logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message())
switch errStatus.Code() {
case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound:
return s.incrementFailedCount(meshId, endpoint)
return s.handleFailed(meshId, nodeId)
}
return false
}
func NewSyncErrorHandler(m *mesh.MeshManager) SyncErrorHandler {
func NewSyncErrorHandler(m mesh.MeshManager) SyncErrorHandler {
return &SyncErrorHandlerImpl{meshManager: m}
}

View File

@ -15,7 +15,7 @@ import (
// SyncRequester: coordinates the syncing of meshes
type SyncRequester interface {
GetMesh(meshId string, ifName string, port int, endPoint string) error
SyncMesh(meshid string, endPoint string) error
SyncMesh(meshid string, meshNode mesh.MeshNode) error
}
type SyncRequesterImpl struct {
@ -50,15 +50,14 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, ifName string, port int, endP
err = s.server.MeshManager.AddMesh(&mesh.AddMeshParams{
MeshId: meshId,
DevName: ifName,
WgPort: port,
MeshBytes: reply.Mesh,
})
return err
}
func (s *SyncRequesterImpl) handleErr(meshId, endpoint string, err error) error {
ok := s.errorHdlr.Handle(meshId, endpoint, err)
func (s *SyncRequesterImpl) handleErr(meshId, pubKey string, err error) error {
ok := s.errorHdlr.Handle(meshId, pubKey, err)
if ok {
return nil
@ -68,7 +67,10 @@ func (s *SyncRequesterImpl) handleErr(meshId, endpoint string, err error) error
}
// SyncMesh: Proactively send a sync request to the other mesh
func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) error {
endpoint := meshNode.GetHostEndpoint()
pubKey, _ := meshNode.GetPublicKey()
peerConnection, err := s.server.ConnectionManager.GetConnection(endpoint)
if err != nil {
@ -89,20 +91,22 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
c := rpc.NewSyncServiceClient(client)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
syncTimeOut := float64(s.server.Conf.SyncRate) * float64(time.Second)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(syncTimeOut))
defer cancel()
err = syncMesh(mesh, ctx, c)
err = s.syncMesh(mesh, ctx, c)
if err != nil {
return s.handleErr(meshId, endpoint, err)
return s.handleErr(meshId, pubKey.String(), err)
}
logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId)
return nil
}
func syncMesh(mesh mesh.MeshProvider, ctx context.Context, client rpc.SyncServiceClient) error {
func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context, client rpc.SyncServiceClient) error {
stream, err := client.SyncMesh(ctx)
syncer := mesh.GetSyncer()

View File

@ -1,55 +1,18 @@
package sync
import (
"time"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
// SyncScheduler: Loops through all nodes in the mesh and runs a schedule to
// sync each event
type SyncScheduler interface {
Run() error
Stop() error
}
// SyncSchedulerImpl scheduler for sync scheduling
type SyncSchedulerImpl struct {
quit chan struct{}
server *ctrlserver.MeshCtrlServer
syncer Syncer
}
// Run implements SyncScheduler.
func (s *SyncSchedulerImpl) Run() error {
ticker := time.NewTicker(time.Duration(s.server.Conf.SyncRate) * time.Second)
quit := make(chan struct{})
s.quit = quit
for {
select {
case <-ticker.C:
err := s.syncer.SyncMeshes()
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
break
case <-quit:
break
}
}
}
// Stop implements SyncScheduler.
func (s *SyncSchedulerImpl) Stop() error {
close(s.quit)
func syncFunction(syncer Syncer) lib.TimerFunc {
return func() error {
syncer.SyncMeshes()
return nil
}
func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester) SyncScheduler {
syncer := NewSyncer(s.MeshManager, s.Conf, syncRequester)
return &SyncSchedulerImpl{server: s, syncer: syncer}
}
func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester, syncer Syncer) *lib.Timer {
return lib.NewTimer(syncFunction(syncer), s.Conf.SyncRate)
}

View File

@ -64,11 +64,11 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error
syncer = mesh.GetSyncer()
} else if meshId != in.MeshId {
return errors.New("Differing MeshIDs")
return errors.New("differing meshids")
}
if syncer == nil {
return errors.New("Syncer should not be nil")
return errors.New("syncer should not be nil")
}
msg, moreMessages := syncer.GenerateMessage()

15
pkg/timers/timers.go Normal file
View File

@ -0,0 +1,15 @@
package timer
import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
)
func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
timerFunc := func() error {
logging.Log.WriteInfof("Updated Timestamp")
return ctrlServer.MeshManager.UpdateTimeStamp()
}
return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime)
}

View File

@ -1,48 +0,0 @@
package timestamp
import (
"time"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
)
type TimestampScheduler interface {
Run() error
Stop() error
}
type TimeStampSchedulerImpl struct {
meshManager *mesh.MeshManager
updateRate int
quit chan struct{}
}
func (s *TimeStampSchedulerImpl) Run() error {
ticker := time.NewTicker(time.Duration(s.updateRate) * time.Second)
s.quit = make(chan struct{})
for {
select {
case <-ticker.C:
err := s.meshManager.UpdateTimeStamp()
if err != nil {
logging.Log.WriteErrorf("Update Timestamp Error: %s", err.Error())
}
case <-s.quit:
break
}
}
}
func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) TimestampScheduler {
return &TimeStampSchedulerImpl{meshManager: ctrlServer.MeshManager, updateRate: ctrlServer.Conf.KeepAliveRate}
}
func (s *TimeStampSchedulerImpl) Stop() error {
close(s.quit)
return nil
}

15
pkg/wg/stubs.go Normal file
View File

@ -0,0 +1,15 @@
package wg
type WgInterfaceManipulatorStub struct{}
func (i *WgInterfaceManipulatorStub) CreateInterface(port int) (string, error) {
return "", nil
}
func (i *WgInterfaceManipulatorStub) AddAddress(ifName string, addr string) error {
return nil
}
func (i *WgInterfaceManipulatorStub) RemoveInterface(ifName string) error {
return nil
}

View File

@ -1,5 +1,7 @@
package wg
import "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
type WgError struct {
msg string
}
@ -8,15 +10,11 @@ func (m *WgError) Error() string {
return m.msg
}
type CreateInterfaceParams struct {
IfName string
Port int
}
type WgInterfaceManipulator interface {
// CreateInterface creates a WireGuard interface
CreateInterface(params *CreateInterfaceParams) error
// Enable interface enables the given interface with
// the IP. It overrides the IP at the interface
EnableInterface(ifName string, ip string) error
CreateInterface(port int, privateKey *wgtypes.Key) (string, error)
// AddAddress adds an address to the given interface name
AddAddress(ifName string, addr string) error
// RemoveInterface removes the specified interface
RemoveInterface(ifName string) error
}

View File

@ -1,110 +1,91 @@
package wg
import (
"errors"
"crypto"
"crypto/rand"
"fmt"
"net"
"os/exec"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// createInterface uses ip link to create an interface. If the interface exists
// it returns an error
func createInterface(ifName string) error {
_, err := net.InterfaceByName(ifName)
if err == nil {
err = flushInterface(ifName)
return err
}
// Check if the interface exists
cmd := exec.Command("/usr/bin/ip", "link", "add", "dev", ifName, "type", "wireguard")
if err := cmd.Run(); err != nil {
return err
}
return nil
}
type WgInterfaceManipulatorImpl struct {
client *wgctrl.Client
}
func (m *WgInterfaceManipulatorImpl) CreateInterface(params *CreateInterfaceParams) error {
err := createInterface(params.IfName)
const hashLength = 6
// CreateInterface creates a WireGuard interface
func (m *WgInterfaceManipulatorImpl) CreateInterface(port int, privKey *wgtypes.Key) (string, error) {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return err
return "", fmt.Errorf("failed to access link: %w", err)
}
defer rtnl.Close()
randomBuf := make([]byte, 32)
_, err = rand.Read(randomBuf)
if err != nil {
return "", err
}
privateKey, err := wgtypes.GeneratePrivateKey()
md5 := crypto.MD5.New().Sum(randomBuf)
md5Str := fmt.Sprintf("wg%x", md5)[:hashLength]
err = rtnl.CreateLink(md5Str)
if err != nil {
return err
return "", fmt.Errorf("failed to create link: %w", err)
}
var cfg wgtypes.Config = wgtypes.Config{
PrivateKey: &privateKey,
ListenPort: &params.Port,
PrivateKey: privKey,
ListenPort: &port,
}
m.client.ConfigureDevice(params.IfName, cfg)
return nil
}
// flushInterface flushes the specified interface
func flushInterface(ifName string) error {
_, err := net.InterfaceByName(ifName)
err = m.client.ConfigureDevice(md5Str, cfg)
if err != nil {
return &WgError{msg: fmt.Sprintf("Interface %s does not exist cannot flush", ifName)}
m.RemoveInterface(md5Str)
return "", fmt.Errorf("failed to configure dev: %w", err)
}
cmd := exec.Command("/usr/bin/ip", "addr", "flush", "dev", ifName)
if err := cmd.Run(); err != nil {
logging.Log.WriteErrorf(fmt.Sprintf("%s error flushing interface %s", err.Error(), ifName))
return &WgError{msg: fmt.Sprintf("Failed to flush interface %s", ifName)}
logging.Log.WriteInfof("ip link set up dev %s type wireguard", md5Str)
return md5Str, nil
}
return nil
}
// EnableInterface flushes the interface and sets the ip address of the
// interface
func (m *WgInterfaceManipulatorImpl) EnableInterface(ifName string, ip string) error {
if len(ifName) == 0 {
return errors.New("ifName not provided")
}
err := flushInterface(ifName)
// Add an address to the given interface
func (m *WgInterfaceManipulatorImpl) AddAddress(ifName string, addr string) error {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return err
return fmt.Errorf("failed to create config: %w", err)
}
defer rtnl.Close()
cmd := exec.Command("/usr/bin/ip", "link", "set", "up", "dev", ifName)
if err := cmd.Run(); err != nil {
return err
}
hostIp, _, err := net.ParseCIDR(ip)
err = rtnl.AddAddress(ifName, addr)
if err != nil {
err = fmt.Errorf("failed to add address: %w", err)
}
return err
}
cmd = exec.Command("/usr/bin/ip", "addr", "add", hostIp.String()+"/64", "dev", ifName)
// RemoveInterface implements WgInterfaceManipulator.
func (*WgInterfaceManipulatorImpl) RemoveInterface(ifName string) error {
rtnl, err := lib.NewRtNetlinkConfig()
if err := cmd.Run(); err != nil {
return err
if err != nil {
return fmt.Errorf("failed to create config: %w", err)
}
return nil
defer rtnl.Close()
return rtnl.DeleteLink(ifName)
}
func NewWgInterfaceManipulator(client *wgctrl.Client) WgInterfaceManipulator {

View File

@ -0,0 +1,92 @@
// Package to convert an IPV6 addres into 8 words
package what8words
import (
"bufio"
"bytes"
"fmt"
"net"
"os"
"strings"
)
type What8Words struct {
words []string
}
// Convert implements What8Words.
func (w *What8Words) Convert(ipStr string) (string, error) {
ip, ipNet, err := net.ParseCIDR(ipStr)
if err != nil {
return "", err
}
ip16 := ip.To16()
if ip16 == nil {
return "", fmt.Errorf("cannot convert ip to 16 representation")
}
representation := make([]string, 7)
for i := 2; i <= net.IPv6len-2; i += 2 {
word1 := w.words[ip16[i]]
word2 := w.words[ip16[i+1]]
representation[i/2-1] = fmt.Sprintf("%s-%s", word1, word2)
}
prefixSize, _ := ipNet.Mask.Size()
return strings.Join(representation[:prefixSize/16-1], "."), nil
}
// Convert implements What8Words.
func (w *What8Words) ConvertIdentifier(ipStr string) (string, error) {
ip, err := w.Convert(ipStr)
if err != nil {
return "", err
}
constituents := strings.Split(ip, ".")
return strings.Join(constituents[3:], "."), nil
}
func NewWhat8Words(pathToWords string) (*What8Words, error) {
words, err := ReadWords(pathToWords)
if err != nil {
return nil, err
}
return &What8Words{words: words}, nil
}
// ReadWords reads the what 8 words txt file
func ReadWords(wordFile string) ([]string, error) {
f, err := os.ReadFile(wordFile)
if err != nil {
return nil, err
}
words := make([]string, 257)
reader := bufio.NewScanner(bytes.NewReader(f))
counter := 0
for reader.Scan() && counter <= len(words) {
text := reader.Text()
words[counter] = text
counter++
if reader.Err() != nil {
return nil, reader.Err()
}
}
return words, nil
}

1
smegmesh-web Submodule

Submodule smegmesh-web added at c1128bcd98