Compare commits

...

140 Commits

Author SHA1 Message Date
9a30f4d5cb Submitting 2024-01-05 18:22:05 +00:00
f647c1b806 81-seperate-processes
Prep for submission
2024-01-05 16:59:02 +00:00
a55dadf088 81-seperate-synchronisation-into-independent-procs
- Neaten code
2024-01-05 12:59:13 +00:00
0ec5156e59 81-procs
- fixed issue where route not deleting if mesh only one
2024-01-05 00:14:25 +00:00
2b73d241b6 81-serparate-procs
- nil dereference again
2024-01-04 22:29:30 +00:00
69b1790bb6 81-processes
- issue with client client traversal
2024-01-04 22:08:14 +00:00
4a92743880 81-seperate-sync
- build error
2024-01-04 21:48:54 +00:00
038393052c 81-seperate-synchronisation-into-independent-proc
- build error
2024-01-04 21:47:29 +00:00
5efff2314b 81-separate-synchronisation-into-independent-process
- nil dereference when no joins
2024-01-04 21:45:28 +00:00
1f8d229076 81-seperate-synchronisation-into-independent-process
- nil dereference due to concurrency issues (the method shouldn't be
  concurrent)
2024-01-04 21:16:33 +00:00
a0e7a4a644 81-seperateprocesses-into-independent-processes
- Fixed errors
2024-01-04 13:15:29 +00:00
f9b8b85ec3 81-seperate-synchronisation
- Removed authentication.proto
2024-01-04 13:12:33 +00:00
59d8ae4334 81-seperate-synchronisation
- More code comments
2024-01-04 13:12:07 +00:00
02dfd73e08 81-seperate-synchronisation-into-independent
- Separated synchronisation calls into independent processes
- Commented code for submission
2024-01-04 13:10:08 +00:00
9818645299 Merge pull request #82 from tim-beatham/bugfix-node-not-leving
bugfix-node-not-leaving
2024-01-04 00:24:58 +00:00
1f0914e2df bugfix-node-not-leaving
- Add lock when perform synchronisation on concurrent access
2024-01-04 00:23:20 +00:00
efb40d65de Merge pull request #80 from tim-beatham/bugfix-node-not-leving
main
2024-01-02 20:32:09 +00:00
27e00196cd main
- Not waiting in the waitgroup
2024-01-02 20:31:24 +00:00
4543205703 Merge pull request #79 from tim-beatham/bugfix-node-not-leving
main
2024-01-02 20:21:27 +00:00
dea6f1a22d main
- error in code invalid check for nil
2024-01-02 20:19:34 +00:00
4d19da6727 Merge pull request #78 from tim-beatham/bugfix-node-not-leving
main
2024-01-02 20:12:10 +00:00
913de57568 main
- Fixed bug
2024-01-02 20:11:11 +00:00
8a5673e303 Merge pull request #77 from tim-beatham/bugfix-node-not-leving
bugfix node not leaving
2024-01-02 19:43:04 +00:00
ce829114b1 bugfix
- on synchornisation node is not leaving mesh
2024-01-02 19:41:20 +00:00
05cc287e31 Merge pull request #76 from tim-beatham/74-perform-dad
- Fixing DNS error
2024-01-02 00:16:45 +00:00
cd844ff46e - Fixing DNS error 2024-01-02 00:15:23 +00:00
4b9406a920 Merge pull request #75 from tim-beatham/74-perform-dad
74-perform-dad
2024-01-02 00:14:37 +00:00
d0b1913796 74-perform-dad
- Fixing nil pointer dereference
2024-01-02 00:13:04 +00:00
90cfe820d2 - Fixing errors with stale paths 2024-01-02 00:09:31 +00:00
8a49809855 74-perform-dad
- Adding go.sum to fix errors
2024-01-01 23:59:04 +00:00
dbc18bddc6 74-perform-dad
- Performing DAD to check if IPv6 address present before adding
  outselves to mesh
- Changing name from wgmesh to smegmesh
2024-01-01 23:55:50 +00:00
14f335af74 Merge pull request #73 from tim-beatham/72-pull-rate-in-configuration
72 pull rate in configuration
2023-12-31 14:26:34 +00:00
36e82dba47 72-pull-rate-in-configuration
- Refactored pull rate into the configuration
- code freeze so no more code changes
2023-12-31 14:25:06 +00:00
3cc87bc252 72-pull-rate-in-configuration
- Updated examples
2023-12-31 12:47:45 +00:00
a9ed7c0a20 72-pull-rate-in-configuration
- Removing libp2p reference
2023-12-31 12:47:45 +00:00
fd29af73e3 72-pull-rate-in-configuration
- Added pull rate to configuration (finally) so this can
be modified by an administrator.
2023-12-31 12:47:45 +00:00
9e1058e0f2 72-pull-rate-in-configuration
- Added the pull rate to the configuration file
2023-12-31 12:47:45 +00:00
c29eb197f3 Merge pull request #71 from tim-beatham/66-ipv6-address-not-conforming-to-spec
66 ipv6 address not conforming to spec
2023-12-30 22:26:53 +00:00
1a9d9d61ad 66-ipv6-address-not-conforming-to-spec
- Missing commit
2023-12-30 22:26:08 +00:00
6954608c32 66-ipv6-address-not-confirming-to-spec
- UUID is not random just a name generator needs changing to shortuuid
- When in multiple meshes there is no wait group
2023-12-30 22:24:43 +00:00
2e6aed6f93 main
- Fixing issue with nil pointer de-reference due to bad design of mesh
  manager.
- Going forward all references to GetSelf should be depracated. It
  introduces a race condition when leaving a mesh network
2023-12-30 00:44:57 +00:00
b0893a0b8e Merge pull request #69 from tim-beatham/60-unit-test-crdt-data-store
60-unit-test-crdt-data-store
2023-12-29 22:06:20 +00:00
e7d6055fa3 60-unit-test-crdt-data-store
Provided unit tests for datastore.go
And fixed unit tets failing by different way of providing CA
2023-12-29 22:05:05 +00:00
e0f3f116b9 main
- Stale serverConfig entry causing certificate authorities
to not become authorised
2023-12-29 19:54:08 +00:00
352648b7cb main
- Fixed problem where connection not removed on error
2023-12-29 11:12:40 +00:00
2d5df25b1d main
- If deadline exceeded error remove connection from
connection manager
2023-12-29 01:29:11 +00:00
cabe173831 main
Adding retry parameter
2023-12-29 01:10:26 +00:00
d2c8a52ec6 main
- Adding retry policy for mobility
2023-12-29 00:58:43 +00:00
bf53108384 main
- Bugfix, fix consistent hash problem where
if failure happens then causes panic
2023-12-28 23:24:38 +00:00
77aac5534b main
- Bugfix in client where "-" was attempted to be parsed as a UDP addr
2023-12-28 17:46:04 +00:00
58439fcd56 main
- Bugfix when keepalivewg is not set causes segmentation fault
- give keepalive a default value of 0 if not set
2023-12-28 17:32:54 +00:00
311a15363a Merge pull request #67 from tim-beatham/66-improve-graph-dot-tool
66 improve graph dot tool
2023-12-25 01:26:15 +00:00
255d3c8b39 66-improve-graph-dot-tool
- Showing services a node provides
- Showing all meshes not just one
- Showing the default route
2023-12-25 01:25:20 +00:00
41899c5831 66-improve-graph-dot-tool
Improving the graph dot tool so that it shows all
meshes
2023-12-25 01:10:11 +00:00
fe4ca66ff6 Merge pull request #65 from tim-beatham/64-2p-set-unit-test
64 2p set unit test
2023-12-22 23:58:59 +00:00
0b91ba744a 61-improve-unit-test-coverage
- Provided unit tests for g_map and 2p_map
2023-12-22 23:57:10 +00:00
67483c2a90 64-unit-test-two-phase-set
Provide unit tests for two phase set to make it more
transparent what exactly they are doing.
2023-12-22 23:57:10 +00:00
af26e81bd3 Merge pull request #63 from tim-beatham/61-improve-unit-testing-coverage
61-improve-unit-testing-coverage
2023-12-22 21:52:46 +00:00
0cc3141b58 61-improve-unit-testing-coverage
- Added missing files to commit
2023-12-22 21:49:47 +00:00
186acbe915 Merge pull request #62 from tim-beatham/61-improve-unit-testing-coverage
61-improve-unit-testing-coverage
2023-12-22 21:49:06 +00:00
ceb43a1db1 61-improve-unit-testing-coverage
- Got unit tests passing
- Improved manager unit tests
2023-12-22 21:47:56 +00:00
bed59f120f Merge pull request #60 from tim-beatham/59-error-when-peer-not-selected
59-error-when-peer-not-selected
2023-12-22 19:12:30 +00:00
8aab4e99d8 59-error-when-peer-not-selected
In the CLI when the peer is not selected
as the type throwing an error stating
either client or peer must be selected
2023-12-22 19:08:20 +00:00
cf4be1ccab Merge pull request #58 from tim-beatham/bugfix-pull-only
Bugfix pull only
2023-12-22 18:49:09 +00:00
6ed32f3a79 bugfix-push-pull
Organised groups as a tree so that there
isn't a limit to dissemination
2023-12-19 00:50:17 +00:00
b6199892f0 bugfix-pull-only
Bugfix with inter-cluster communication pull not working
2023-12-18 22:17:46 +00:00
ad22f04b0d bugfix-pull-only
After certain period of time if no changes have
occurred then pull
2023-12-18 20:45:56 +00:00
092d9a4af5 checking-latency-for-pull-only 2023-12-17 09:44:32 +00:00
19abf712a6 Fixing bug with nodes being removed 2023-12-12 12:45:41 +00:00
b296e1f45a Merge pull request #57 from tim-beatham/55-cli-option-for-peer-type
55-cli-optionifor-peer-type
2023-12-12 12:00:42 +00:00
2dc89d171b 55-cli-optionifor-peer-type
- Ability to specify WireGuard keepalive in the CLI formatter
- Ability to specify publicly routeable endpoint
- Ability to specify whether to advetise routes into the mesh,
and whether to advertise default routes.
2023-12-12 11:58:47 +00:00
13bea10638 main - bugfix
- Nodes not being removed when deleted because when node gossips again
  it is readded.
- Keep track of highest vector clock we have removed and used this as a
  mark for determining if something is stale.
2023-12-11 11:09:02 +00:00
3222d7e388 main - adding WireGuard stats to JSON objects
- Adding WireGuard stats through to IPC calls so that they can be used
by the API
2023-12-11 09:55:25 +00:00
1789d203f6 main - fix default routing being deleted
Default route keeps fluctuating on configuration
update.
2023-12-10 23:35:00 +00:00
a5074a536e main - BUGFIX
- segfault BUGFIX
2023-12-10 22:31:24 +00:00
acb90a5679 main - go.sum should be tracked into the git
- go.sum should be contained in the git history
2023-12-10 22:11:09 +00:00
27ec23f133 Merge pull request #54 from tim-beatham/53-run-commands-pre-up-and-post-down
53-run-commands-pre-up-and-post-down
2023-12-10 19:22:59 +00:00
fe14f63217 53-run-commands-pre-up-and-post-down
- Ability to run a command pre up and post down
- Ability to be a client in one mesh and a peer in the other
- Added dev card to specify different sync rate, keepalive rate per
  mesh.
2023-12-10 19:21:54 +00:00
4a8a39601f Merge pull request #52 from tim-beatham/51-bufix-not-removing-when-withdrawn
51-bugfix-routes-not-removing-when-withdrawn
2023-12-10 15:13:57 +00:00
1e263cc6a8 51-bugfix-routes-not-removing-when-withdrawn
- Routes are not being removed despite being withdrawn from the
configuration.
- Best path routes are not shared across interfaces
- Bug in consistent hashing wrong parameter passed caused by
refactorings.
2023-12-10 15:10:36 +00:00
dae9cd31a1 Merge pull request #50 from tim-beatham/50-give-client-ability-to-bridge-meshes
50-give-client-ability-to-bridge-meshes
2023-12-08 23:58:32 +00:00
f855f53fbf 50-give-client-ability-to-bridge-meshes
Client can act as a route bridging meshes. Cient send keepalives
to all of it's peers in the different meshes act as a bridge between
the meshes
2023-12-08 23:56:07 +00:00
52feb5767b Merge pull request #48 from tim-beatham/47-default-routing
47 default routing
2023-12-08 20:03:45 +00:00
815c4484ee 47-default-routing
Implemented default routing and improved size of gossip. Using 64 bit
hash funciton to identify vector.
2023-12-08 20:02:57 +00:00
0058c9f4c9 47-default-routing
Implementing default routing so that all traffic goes out of an
exit point.
2023-12-08 11:49:24 +00:00
92c0805275 Merge pull request #46 from tim-beatham/45-use-statistical-testing
45 use statistical testing
2023-12-07 18:20:25 +00:00
661fb0d54c 45-use-statistical-testing
Keepalive is based on per mesh and not per node.
Using total ordering mechanism similar to paxos to elect a leader
if leader doesn't update it's timestamp within 3 * keepAlive then
give the leader a gravestone and elect the next leader.
Leader is bassed on lexicographically ordered public key.
2023-12-07 18:18:13 +00:00
64885f1055 45-use-statistical-testing
Using statistical testing to test whether the node has failed.
2023-12-07 01:44:54 +00:00
2169f7796f Merge pull request #44 from tim-beatham/43-gravestones
43-use-gravestones
2023-12-06 22:46:05 +00:00
a3ceff019d 43-use-gravestones
Change of approach from keepalive to a noiseless protocol
2023-12-06 22:45:04 +00:00
b78d96986c Merge pull request #42 from tim-beatham/41-bugfix-fluctuating-ips
41 bugfix fluctuating ips
2023-12-06 14:37:14 +00:00
1b18d89c9f 41-bugfix-fluctuating-ips
Fluctuating ips creating hub and spoke.
2023-12-05 02:00:16 +00:00
245a2c5f58 41-bugfix-fluctuating-ips
If the node is a peer then add the client in the WG
configuration.
2023-12-04 17:40:24 +00:00
c40f7510b8 41-bugfix-fluctuating-ips
IPs of clients fluctuating because there isn't a strict order on
clients. Client's need to be processed before the peers.
2023-12-04 17:32:50 +00:00
78d748770c BUGIX Hash client by public key 2023-12-04 17:13:51 +00:00
0ff2a8eef9 BUGFIX: Allowed IPs fluctuating 2023-12-04 17:11:37 +00:00
fd7bd80485 BUGFIX
Don't get device each time it is an expensive operation.
2023-12-04 16:40:15 +00:00
3ef1b68ba5 BUGFIX: Hashing datastore to work out changes
Changed hashing implementation to work out if there are changes
in the data store
2023-11-30 15:58:26 +00:00
b9ba836ae3 Merge pull request #40 from tim-beatham/39-implement-two-phase-map
39-implement-two-phase-map
2023-11-30 02:03:36 +00:00
650901aba1 39-implement-two-phase-map
Implemented my own two phase map based on vector clocks
2023-11-30 02:02:38 +00:00
a82eab0686 Bugfix
Added replace peers so that deleted nodes are automatically removed
2023-11-28 14:43:55 +00:00
32e7e4c7df main
Bugfix. Fixed issue where consistent hashing was not working.
2023-11-28 14:42:09 +00:00
1fae0a6c2c Merge pull request #37 from tim-beatham/36-add-route-path-into-route-object
36-add-route-path-into-route-object
2023-11-27 21:03:56 +00:00
d8e156f13f 36-add-route-path-into-route-object
Added the route path into the route object so that we can
see what meshes packets are routed across.
2023-11-27 18:55:41 +00:00
3fca49a1c9 Merge pull request #35 from tim-beatham/34-fix-routing
34 fix routing
2023-11-27 16:05:06 +00:00
a2517a1e72 34-fix-routing
- Added mesh-to-mesh routing of hop count > 1
- If there is a tie-breaker with respect to the hop-count use consistent
hashing to determine the route to take based on the public key.
2023-11-27 15:56:30 +00:00
aef8b59f22 32-fix-routing
Flooding routes into other meshes a bit like BGP.
2023-11-25 03:15:58 +00:00
4030d17b41 Fixed routing issue 2023-11-24 17:49:06 +00:00
73db65660b Merge pull request #33 from tim-beatham/32-incorporate-dns
32-incorporate-dns
2023-11-24 15:05:40 +00:00
d1a74a7b95 32-incorporate-dns
Incorporated a DNS server. A DNS server can be run to resolve host
names.
2023-11-24 15:04:07 +00:00
f28ed8260d Merge pull request #30 from tim-beatham/29-only-ping-clients-who-have-updated-their-config
29-only-ping-clients-who-have-updated-their-config
2023-11-24 12:39:14 +00:00
2c406718df 29-only-ping-clients-who-have-updated-their-config
Only consider clients who have updated their config when synchronising
with peers. Consider a dead time where we don't have a handshake and
a prune time when we remove them from the WireGuard configuration.
2023-11-24 12:37:54 +00:00
11b003b549 Merge pull request #28 from tim-beatham/27-remove-client-grpc-endpoint
27-remove-client-grpc-endpoint
2023-11-24 12:08:42 +00:00
7be11dbaa3 27-remove-client-grpc-endpoint
Removed a client's grpc endpoint value. Client's aren't publicly
available so there is no need for a client's gRPC endpoint.
Also changed a node ID's to their public key. A node id's public
address is an issue for mobility of clients as their endpoint
is subject to change
2023-11-24 12:07:03 +00:00
e7ac8c5542 Only updating WireGuard config if node exists 2023-11-22 13:08:02 +00:00
09c64c4628 Fixed container file 2023-11-22 12:45:01 +00:00
2c4f18f52b Merge pull request #26 from tim-beatham/25-modify-code-to-use-public-api
25-modify-code-to-use-public-api
2023-11-22 10:42:48 +00:00
4c54022f63 25-modify-code-to-use-public-api
Modify the code to use a public IP address by default if none is
specified
2023-11-22 10:41:54 +00:00
bf0724f6e5 Merge pull request #24 from tim-beatham/24-keepalive-holepunch
24 keepalive holepunch
2023-11-21 21:28:16 +00:00
624bd6e921 24-keepalive
Persistent keep alive working
2023-11-21 21:26:31 +00:00
7b939e0468 24-keepalive-holepunch
Added the ability to hole punch NAT
2023-11-21 20:42:43 +00:00
6e201ebaf5 24-keepalive-holepunch
Nodes acting as peers and nodes acting as clients
2023-11-21 16:42:49 +00:00
06542da03c main
Fixed problems with timestamp not updating
2023-11-21 13:31:34 +00:00
0d63cd6624 main
Adding words.txt for what words
2023-11-20 18:12:58 +00:00
f13319cfc1 Merge pull request #22 from tim-beatham/21-phonetic-words-ipv6
21 phonetic words ipv6
2023-11-20 18:08:49 +00:00
95f4495b0b 21-phonetic-words-ipv6
Simple what 8 words implementation
2023-11-20 18:07:52 +00:00
330fa74ef4 IPv6 What 8 Words
what 8 words for ipv6 started
2023-11-20 15:22:32 +00:00
3e5b57e41f Merge pull request #20 from tim-beatham/19-hash-wg-interface
Hashing the WireGuard interface
2023-11-20 13:04:19 +00:00
b179cd3cf4 Hashing the WireGuard interface
Hashing the interface and using ephmeral ports so that the admin doesn't
choose an interface and port combination. An administrator can alteranatively
decide to provide port but this isn't critical.
2023-11-20 13:03:42 +00:00
8f211aa116 Merge pull request #18 from tim-beatham/26-performance-testing
Stubbing out WireGuard components
2023-11-20 11:29:37 +00:00
388153e706 Stubbing out WireGuard components
Stubbing our WireGuard components so that I can use docker/podman
network_mode=host. This is much more efficient than the docker/podman
userspace network.
2023-11-20 11:28:12 +00:00
023565d985 Merge pull request #17 from tim-beatham/25-ability-to-aliases
25 ability to aliases
2023-11-17 22:20:57 +00:00
36c264b38e 25-ability-aliases
Fixed unit tests failing
2023-11-17 22:18:53 +00:00
68db795f47 Ability to specify aliases
Ability to specify aliases that automatically append to /etc/hosts
2023-11-17 22:13:51 +00:00
f6160fe138 Adding aliases that automatically gets added 2023-11-17 19:13:20 +00:00
2c5289afb0 Merge pull request #16 from tim-beatham/15-add-rest-api
Developed a rest API
2023-11-15 12:57:05 +00:00
7199d07a76 Added smegmesh submodule 2023-11-13 10:46:52 +00:00
5f176e731f Developed a rest API 2023-11-13 10:44:14 +00:00
44f119b45c Updating examples 2023-11-08 09:19:24 +00:00
5215d5d54d Merge pull request #14 from tim-beatham/13-netlink-api
Removed interface manipulation via os.Exec into
2023-11-07 19:53:39 +00:00
95 changed files with 8414 additions and 2761 deletions

3
.gitmodules vendored Normal file
View File

@ -0,0 +1,3 @@
[submodule "smegmesh-web"]
path = smegmesh-web
url = git@github.com:tim-beatham/smegmesh-web.git

12
Containerfile Normal file
View File

@ -0,0 +1,12 @@
FROM docker.io/library/golang:bookworm
COPY ./ /wgmesh
RUN apt-get update && apt-get install -y \
wireguard \
wireguard-tools \
iproute2 \
iputils-ping \
tmux \
vim
WORKDIR /wgmesh
RUN go mod tidy
RUN go build -o /usr/local/bin ./...

1
Dockerfile Symbolic link
View File

@ -0,0 +1 @@
Containerfile

19
cmd/api/main.go Normal file
View File

@ -0,0 +1,19 @@
package main
import (
"log"
"github.com/tim-beatham/smegmesh/pkg/api"
)
func main() {
apiServer, err := api.NewSmegServer(api.ApiServerConf{
WordsFile: "./cmd/api/words.txt",
})
if err != nil {
log.Fatal(err.Error())
}
apiServer.Run(":8080")
}

257
cmd/api/words.txt Normal file
View File

@ -0,0 +1,257 @@
be
to
of
it
in
we
do
he
on
go
at
if
or
up
by
hi
the
and
you
not
for
but
say
get
she
one
all
can
out
who
now
see
way
how
lot
yes
use
any
day
try
put
let
why
new
off
big
too
ask
man
bit
end
may
own
run
pay
job
old
kid
bad
few
ago
far
buy
set
guy
car
sit
war
win
yet
top
law
cut
low
die
eat
age
hit
air
add
boy
act
tax
oil
eye
son
key
fun
dad
dog
arm
fly
box
gas
lie
hot
gun
per
art
red
fit
bed
fan
mix
mom
sex
bus
fix
bar
lay
ice
bet
bag
due
aid
tie
leg
ban
odd
cup
dry
cry
rid
pop
sir
cat
map
sad
sea
aim
sun
fat
row
egg
tea
god
wed
tip
ear
hat
net
ill
dig
fee
mad
gap
nor
bid
era
toy
sky
bin
owe
wet
tap
pro
ski
cow
pen
van
web
pot
sum
cap
log
pub
pig
joy
raw
rat
via
lip
two
six
ten
lab
ton
mid
bat
hip
gut
sin
non
rub
sub
par
pre
ray
cue
dye
fin
ion
neo
hey
wow
mum
bye
aye
jet
sue
pet
flu
cop
ooh
rip
spy
pie
bug
gum
wan
rap
nut
beg
pin
pit
jam
tag
fax
vet
fry
pad
lad
mud
bay
con
pan
gee
toe
dip
shy
gym
zoo
fox
bow
tin
hop
wee
kit
opt
vow
sew
cab
bee
rob
rig
yep
ego
rib
nod
hug
lap
ash
hum
dam
bum
yen
jar

18
cmd/dns/main.go Normal file
View File

@ -0,0 +1,18 @@
package main
import (
"log"
smegdns "github.com/tim-beatham/smegmesh/pkg/dns"
)
func main() {
server, err := smegdns.NewDns(53)
if err != nil {
log.Fatal(err.Error())
}
defer server.Close()
server.Listen()
}

409
cmd/smegctl/main.go Normal file
View File

@ -0,0 +1,409 @@
package main
import (
"fmt"
ipcRpc "net/rpc"
"os"
"github.com/akamensky/argparse"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
graph "github.com/tim-beatham/smegmesh/pkg/dot"
"github.com/tim-beatham/smegmesh/pkg/ipc"
logging "github.com/tim-beatham/smegmesh/pkg/log"
)
const SockAddr = "/tmp/wgmesh_ipc.sock"
type CreateMeshParams struct {
Client *ipcRpc.Client
Endpoint string
WgArgs ipc.WireGuardArgs
AdvertiseRoutes bool
AdvertiseDefault bool
}
func createMesh(client *ipc.SmegmeshIpc, args *ipc.NewMeshArgs) {
var reply string
err := client.CreateMesh(args, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func listMeshes(client *ipc.SmegmeshIpc) {
reply := new(ipc.ListMeshReply)
err := client.ListMeshes(reply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
return
}
for _, meshId := range reply.Meshes {
fmt.Println(meshId)
}
}
func joinMesh(client *ipc.SmegmeshIpc, args ipc.JoinMeshArgs) {
var reply string
err := client.JoinMesh(args, &reply)
if err != nil {
fmt.Println(err.Error())
}
fmt.Println(reply)
}
func leaveMesh(client *ipc.SmegmeshIpc, meshId string) {
var reply string
err := client.LeaveMesh(meshId, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func getGraph(client *ipc.SmegmeshIpc) {
listMeshesReply := new(ipc.ListMeshReply)
err := client.ListMeshes(listMeshesReply)
if err != nil {
fmt.Println(err.Error())
return
}
meshes := make(map[string][]ctrlserver.MeshNode)
for _, meshId := range listMeshesReply.Meshes {
var meshReply ipc.GetMeshReply
err := client.GetMesh(meshId, &meshReply)
if err != nil {
fmt.Println(err.Error())
return
}
meshes[meshId] = meshReply.Nodes
}
dotGenerator := graph.NewMeshGraphConverter(meshes)
dot, err := dotGenerator.Generate()
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(dot)
}
func queryMesh(client *ipc.SmegmeshIpc, meshId, query string) {
var reply string
args := ipc.QueryMesh{
MeshId: meshId,
Query: query,
}
err := client.Query(args, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func putDescription(client *ipc.SmegmeshIpc, meshId, description string) {
var reply string
err := client.PutDescription(ipc.PutDescriptionArgs{
MeshId: meshId,
Description: description,
}, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
// putAlias: puts an alias for the node
func putAlias(client *ipc.SmegmeshIpc, meshid, alias string) {
var reply string
err := client.PutAlias(ipc.PutAliasArgs{
MeshId: meshid,
Alias: alias,
}, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func setService(client *ipc.SmegmeshIpc, meshId, service, value string) {
var reply string
err := client.PutService(ipc.PutServiceArgs{
MeshId: meshId,
Service: service,
Value: value,
}, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func deleteService(client *ipc.SmegmeshIpc, meshId, service string) {
var reply string
err := client.DeleteService(ipc.DeleteServiceArgs{
MeshId: meshId,
Service: service,
}, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func main() {
parser := argparse.NewParser("smgctl",
"smegctl Manipulate WireGuard mesh networks")
newMeshCmd := parser.NewCommand("new-mesh", "Create a new mesh")
listMeshCmd := parser.NewCommand("list-meshes", "List meshes the node is connected to")
joinMeshCmd := parser.NewCommand("join-mesh", "Join a mesh network")
getGraphCmd := parser.NewCommand("get-graph", "Convert a mesh into DOT format")
leaveMeshCmd := parser.NewCommand("leave-mesh", "Leave a mesh network")
queryMeshCmd := parser.NewCommand("query-mesh", "Query a mesh network using JMESPath")
putDescriptionCmd := parser.NewCommand("put-description", "Place a description for the node")
putAliasCmd := parser.NewCommand("put-alias", "Place an alias for the node")
setServiceCmd := parser.NewCommand("set-service", "Place a service into your advertisements")
deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements")
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{
Default: 0,
Help: "WireGuard port to use to the interface. A default of 0 uses an unused ephmeral port.",
})
var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{
Help: "Publicly routeable endpoint to advertise within the mesh",
})
var newMeshRole *string = newMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{
Help: "Role in the mesh network. A value of peer means that the node is publicly routeable and thus considered" +
" in the gossip protocol. Client means that the node is not publicly routeable and is not a candidate in the gossip" +
" protocol",
})
var newMeshKeepAliveWg *int = newMeshCmd.Int("k", "KeepAliveWg", &argparse.Options{
Default: 0,
Help: "WireGuard KeepAlive value for NAT traversal and firewall holepunching",
})
var newMeshAdvertiseRoutes *bool = newMeshCmd.Flag("a", "advertise", &argparse.Options{
Help: "Advertise routes to other mesh network into the mesh",
})
var newMeshAdvertiseDefaults *bool = newMeshCmd.Flag("d", "defaults", &argparse.Options{
Help: "Advertise ::/0 into the mesh network",
})
var joinMeshId *string = joinMeshCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{
Required: true,
Help: "IP address of the bootstrapping node to join through",
})
var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{
Help: "Publicly routeable endpoint to advertise within the mesh",
})
var joinMeshRole *string = joinMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{
Help: "Role in the mesh network. A value of peer means that the node is publicly routeable and thus considered" +
" in the gossip protocol. Client means that the node is not publicly routeable and is not a candidate in the gossip" +
" protocol",
})
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{
Default: 0,
Help: "WireGuard port to use to the interface. A default of 0 uses an unused ephmeral port.",
})
var joinMeshKeepAliveWg *int = joinMeshCmd.Int("k", "KeepAliveWg", &argparse.Options{
Default: 0,
Help: "WireGuard KeepAlive value for NAT traversal and firewall ho;lepunching",
})
var joinMeshAdvertiseRoutes *bool = joinMeshCmd.Flag("a", "advertise", &argparse.Options{
Help: "Advertise routes to other mesh network into the mesh",
})
var joinMeshAdvertiseDefaults *bool = joinMeshCmd.Flag("d", "defaults", &argparse.Options{
Help: "Advertise ::/0 into the mesh network",
})
var leaveMeshMeshId *string = leaveMeshCmd.String("m", "mesh", &argparse.Options{
Required: true,
Help: "MeshID of the mesh to leave",
})
var queryMeshMeshId *string = queryMeshCmd.String("m", "mesh", &argparse.Options{
Required: true,
Help: "MeshID of the mesh to query",
})
var queryMeshQuery *string = queryMeshCmd.String("q", "query", &argparse.Options{
Required: true,
Help: "JMESPath Query Of The Mesh Network To Query",
})
var description *string = putDescriptionCmd.String("d", "description", &argparse.Options{
Required: true,
Help: "Description of the node in the mesh",
})
var descriptionMeshId *string = putDescriptionCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
var aliasMeshId *string = putAliasCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
var alias *string = putAliasCmd.String("a", "alias", &argparse.Options{
Required: true,
Help: "Alias of the node to set can be used in DNS to lookup an IP address",
})
var serviceKey *string = setServiceCmd.String("s", "service", &argparse.Options{
Required: true,
Help: "Key of the service to advertise in the mesh network",
})
var serviceValue *string = setServiceCmd.String("v", "value", &argparse.Options{
Required: true,
Help: "Value of the service to advertise in the mesh network",
})
var serviceMeshId *string = setServiceCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{
Required: true,
Help: "Key of the service to remove",
})
var deleteServiceMeshid *string = deleteServiceCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
err := parser.Parse(os.Args)
if err != nil {
fmt.Print(parser.Usage(err))
return
}
client, err := ipc.NewClientIpc()
if err != nil {
panic(err)
}
if newMeshCmd.Happened() {
args := &ipc.NewMeshArgs{
WgArgs: ipc.WireGuardArgs{
Endpoint: *newMeshEndpoint,
Role: *newMeshRole,
WgPort: *newMeshPort,
KeepAliveWg: *newMeshKeepAliveWg,
AdvertiseDefaultRoute: *newMeshAdvertiseDefaults,
AdvertiseRoutes: *newMeshAdvertiseRoutes,
},
}
createMesh(client, args)
}
if listMeshCmd.Happened() {
listMeshes(client)
}
if joinMeshCmd.Happened() {
args := ipc.JoinMeshArgs{
IpAddress: *joinMeshIpAddress,
MeshId: *joinMeshId,
WgArgs: ipc.WireGuardArgs{
Endpoint: *joinMeshEndpoint,
Role: *joinMeshRole,
WgPort: *joinMeshPort,
KeepAliveWg: *joinMeshKeepAliveWg,
AdvertiseDefaultRoute: *joinMeshAdvertiseDefaults,
AdvertiseRoutes: *joinMeshAdvertiseRoutes,
},
}
joinMesh(client, args)
}
if getGraphCmd.Happened() {
getGraph(client)
}
if leaveMeshCmd.Happened() {
leaveMesh(client, *leaveMeshMeshId)
}
if queryMeshCmd.Happened() {
queryMesh(client, *queryMeshMeshId, *queryMeshQuery)
}
if putDescriptionCmd.Happened() {
putDescription(client, *descriptionMeshId, *description)
}
if putAliasCmd.Happened() {
putAlias(client, *aliasMeshId, *alias)
}
if setServiceCmd.Happened() {
setService(client, *serviceMeshId, *serviceKey, *serviceValue)
}
if deleteServiceCmd.Happened() {
deleteService(client, *deleteServiceMeshid, *deleteServiceKey)
}
}

View File

@ -1,33 +1,35 @@
package main package main
import ( import (
"log" "net/http"
_ "net/http/pprof"
"os" "os"
"os/signal" "os/signal"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
ctrlserver "github.com/tim-beatham/wgmesh/pkg/ctrlserver" ctrlserver "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/smegmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/robin"
"github.com/tim-beatham/wgmesh/pkg/robin" "github.com/tim-beatham/smegmesh/pkg/sync"
"github.com/tim-beatham/wgmesh/pkg/sync"
"github.com/tim-beatham/wgmesh/pkg/timestamp"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
) )
func main() { func main() {
if len(os.Args) != 2 { if len(os.Args) != 2 {
logging.Log.WriteErrorf("Need to provide configuration.yaml") logging.Log.WriteErrorf("Did not provide configuration")
return return
} }
conf, err := conf.ParseConfiguration(os.Args[1]) configuration, err := conf.ParseDaemonConfiguration(os.Args[1])
if err != nil { if err != nil {
logging.Log.WriteInfof("Could not parse configuration") logging.Log.WriteErrorf("Could not parse configuration: %s", err.Error())
return return
} }
logging.SetLogger(logging.NewLogrusLogger(configuration.LogLevel))
client, err := wgctrl.New() client, err := wgctrl.New()
if err != nil { if err != nil {
@ -35,23 +37,30 @@ func main() {
return return
} }
if configuration.Profile {
go func() {
http.ListenAndServe("localhost:6060", nil)
}()
}
var robinRpc robin.WgRpc var robinRpc robin.WgRpc
var robinIpc robin.IpcHandler var robinIpc robin.IpcHandler
var syncProvider sync.SyncServiceImpl var syncProvider sync.SyncServiceImpl
ctrlServerParams := ctrlserver.NewCtrlServerParams{ ctrlServerParams := ctrlserver.NewCtrlServerParams{
Conf: conf, Conf: configuration,
CtrlProvider: &robinRpc, CtrlProvider: &robinRpc,
SyncProvider: &syncProvider, SyncProvider: &syncProvider,
Client: client, Client: client,
} }
ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams) ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams)
syncProvider.Server = ctrlServer if err != nil {
syncRequester := sync.NewSyncRequester(ctrlServer) panic(err)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester) }
timestampScheduler := timestamp.NewTimestampScheduler(ctrlServer)
pruneScheduler := mesh.NewPruner(ctrlServer.MeshManager, *conf) syncProvider.MeshManager = ctrlServer.MeshManager
robinIpcParams := robin.RobinIpcParams{ robinIpcParams := robin.RobinIpcParams{
CtrlServer: ctrlServer, CtrlServer: ctrlServer,
@ -65,17 +74,11 @@ func main() {
return return
} }
log.Println("Running IPC Handler") logging.Log.WriteInfof("running ipc handler")
go ipc.RunIpcHandler(&robinIpc) go ipc.RunIpcHandler(&robinIpc)
go syncScheduler.Run()
go timestampScheduler.Run()
go pruneScheduler.Run()
closeResources := func() { closeResources := func() {
logging.Log.WriteInfof("Closing resources") logging.Log.WriteInfof("closing resources")
syncScheduler.Stop()
timestampScheduler.Stop()
ctrlServer.Close() ctrlServer.Close()
client.Close() client.Close()
} }

View File

@ -1,271 +0,0 @@
package main
import (
"fmt"
ipcRpc "net/rpc"
"os"
"strings"
"time"
"github.com/akamensky/argparse"
"github.com/tim-beatham/wgmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log"
)
const SockAddr = "/tmp/wgmesh_ipc.sock"
type CreateMeshParams struct {
Client *ipcRpc.Client
IfName string
WgPort int
Endpoint string
}
func createMesh(args *CreateMeshParams) string {
var reply string
newMeshParams := ipc.NewMeshArgs{
IfName: args.IfName,
WgPort: args.WgPort,
Endpoint: args.Endpoint,
}
err := args.Client.Call("IpcHandler.CreateMesh", &newMeshParams, &reply)
if err != nil {
return err.Error()
}
return reply
}
func listMeshes(client *ipcRpc.Client) {
reply := new(ipc.ListMeshReply)
err := client.Call("IpcHandler.ListMeshes", "", &reply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
return
}
for _, meshId := range reply.Meshes {
fmt.Println(meshId)
}
}
type JoinMeshParams struct {
Client *ipcRpc.Client
MeshId string
IpAddress string
IfName string
WgPort int
Endpoint string
}
func joinMesh(params *JoinMeshParams) string {
var reply string
args := ipc.JoinMeshArgs{
MeshId: params.MeshId,
IpAdress: params.IpAddress,
IfName: params.IfName,
Port: params.WgPort,
}
err := params.Client.Call("IpcHandler.JoinMesh", &args, &reply)
if err != nil {
return err.Error()
}
return reply
}
func getMesh(client *ipcRpc.Client, meshId string) {
reply := new(ipc.GetMeshReply)
err := client.Call("IpcHandler.GetMesh", &meshId, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
for _, node := range reply.Nodes {
fmt.Println("Public Key: " + node.PublicKey)
fmt.Println("Control Endpoint: " + node.HostEndpoint)
fmt.Println("WireGuard Endpoint: " + node.WgEndpoint)
fmt.Println("Wg IP: " + node.WgHost)
fmt.Println(fmt.Sprintf("Timestamp: %s", time.Unix(node.Timestamp, 0).String()))
advertiseRoutes := strings.Join(node.Routes, ",")
fmt.Printf("Routes: %s\n", advertiseRoutes)
fmt.Println("---")
}
}
func leaveMesh(client *ipcRpc.Client, meshId string) {
var reply string
err := client.Call("IpcHandler.LeaveMesh", &meshId, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func enableInterface(client *ipcRpc.Client, meshId string) {
var reply string
err := client.Call("IpcHandler.EnableInterface", &meshId, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func getGraph(client *ipcRpc.Client, meshId string) {
var reply string
err := client.Call("IpcHandler.GetDOT", &meshId, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func queryMesh(client *ipcRpc.Client, meshId, query string) {
var reply string
err := client.Call("IpcHandler.Query", &ipc.QueryMesh{MeshId: meshId, Query: query}, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
// putDescription: puts updates the description about the node to the meshes
func putDescription(client *ipcRpc.Client, description string) {
var reply string
err := client.Call("IpcHandler.PutDescription", &description, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func main() {
parser := argparse.NewParser("wg-mesh",
"wg-mesh Manipulate WireGuard meshes")
newMeshCmd := parser.NewCommand("new-mesh", "Create a new mesh")
listMeshCmd := parser.NewCommand("list-meshes", "List meshes the node is connected to")
joinMeshCmd := parser.NewCommand("join-mesh", "Join a mesh network")
// getMeshCmd := parser.NewCommand("get-mesh", "Get a mesh network")
enableInterfaceCmd := parser.NewCommand("enable-interface", "Enable A Specific Mesh Interface")
getGraphCmd := parser.NewCommand("get-graph", "Convert a mesh into DOT format")
leaveMeshCmd := parser.NewCommand("leave-mesh", "Leave a mesh network")
queryMeshCmd := parser.NewCommand("query-mesh", "Query a mesh network using JMESPath")
putDescriptionCmd := parser.NewCommand("put-description", "Place a description for the node")
var newMeshIfName *string = newMeshCmd.String("f", "ifname", &argparse.Options{Required: true})
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{})
var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true})
var joinMeshIfName *string = joinMeshCmd.String("f", "ifname", &argparse.Options{Required: true})
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{})
// var getMeshId *string = getMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true})
var getGraphMeshId *string = getGraphCmd.String("m", "mesh", &argparse.Options{Required: true})
var leaveMeshMeshId *string = leaveMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var queryMeshMeshId *string = queryMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var queryMeshQuery *string = queryMeshCmd.String("q", "query", &argparse.Options{Required: true})
var description *string = putDescriptionCmd.String("d", "description", &argparse.Options{Required: true})
err := parser.Parse(os.Args)
if err != nil {
fmt.Print(parser.Usage(err))
return
}
client, err := ipcRpc.DialHTTP("unix", SockAddr)
if err != nil {
fmt.Println(err.Error())
return
}
if newMeshCmd.Happened() {
fmt.Println(createMesh(&CreateMeshParams{
Client: client,
IfName: *newMeshIfName,
WgPort: *newMeshPort,
Endpoint: *newMeshEndpoint,
}))
}
if listMeshCmd.Happened() {
listMeshes(client)
}
if joinMeshCmd.Happened() {
fmt.Println(joinMesh(&JoinMeshParams{
Client: client,
IfName: *joinMeshIfName,
WgPort: *joinMeshPort,
IpAddress: *joinMeshIpAddress,
MeshId: *joinMeshId,
Endpoint: *joinMeshEndpoint,
}))
}
// if getMeshCmd.Happened() {
// getMesh(client, *getMeshId)
// }
if getGraphCmd.Happened() {
getGraph(client, *getGraphMeshId)
}
if enableInterfaceCmd.Happened() {
enableInterface(client, *enableInterfaceMeshId)
}
if leaveMeshCmd.Happened() {
leaveMesh(client, *leaveMeshMeshId)
}
if queryMeshCmd.Happened() {
queryMesh(client, *queryMeshMeshId, *queryMeshQuery)
}
if putDescriptionCmd.Happened() {
putDescription(client, *description)
}
}

View File

@ -10,5 +10,5 @@ syncRate: 1
interClusterChance: 0.15 interClusterChance: 0.15
branchRate: 3 branchRate: 3
infectionCount: 3 infectionCount: 3
keepAliveTime: 10 heartBeatTime: 10
pruneTime: 20 pruneTime: 20

View File

@ -0,0 +1,43 @@
version: '3'
networks:
net-1:
services:
wg-1:
image: localhost/smegmesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "smegd /shared/configuration.yaml"
sysctls:
- net.ipv6.conf.all.forwarding=1
wg-2:
image: localhost/smegmesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "smegd /shared/configuration.yaml"
sysctls:
- net.ipv6.conf.all.forwarding=1
wg-3:
image: localhost/smegmesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "smegd /shared/configuration.yaml"
sysctls:
- net.ipv6.conf.all.forwarding=1

View File

@ -0,0 +1,34 @@
# Paths to the certificates modify
# if not running from Smegmesh
certificatePath: "./cert/cert.pem"
privateKeyPath: "./cert/priv.pem"
caCertificatePath: "./cert/cacert.pem"
skipCertVerification: true
# timeout is the configured grpc timeout
timeout: 5
# gRPC port to run the solution
gRPCPort: 4000
# whether or not to run go profiler
profile: false
# stubWg: whether to install WireGuard configurations
# if true just tests the control plane
stubWg: false
heartbeatInterval: 60
branch: 3
pullInterval: 20
infectionCount: 3
interClusterChance: 0.15
syncInterval: 2
clusterSize: 64
logLevel: "info"
baseConfiguration:
# ipDiscovery: specifies how to find your IP address
ipDiscovery: "outgoing"
# alternative to ipDiscovery specify an actual endpoint yourself with publicEndpoint: "xxxx"
# role is the role that you are playing (peer | client)
# peers can only bootstrap meshes
role: "peer"
# advertise meshes to other meshes
advertiseRoute: true
# advertise default routes
advertiseDefaults: true

38
go.mod
View File

@ -1,13 +1,20 @@
module github.com/tim-beatham/wgmesh module github.com/tim-beatham/smegmesh
go 1.21.3 go 1.21.3
require ( require (
github.com/akamensky/argparse v1.4.0 github.com/akamensky/argparse v1.4.0
github.com/anandvarma/namegen v0.0.0-20230727084436-5197c6ea3255
github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9 github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9
github.com/gin-gonic/gin v1.9.1
github.com/go-playground/validator/v10 v10.16.0
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/jmespath/go-jmespath v0.4.0 github.com/jmespath/go-jmespath v0.4.0
github.com/jsimonetti/rtnetlink v1.3.5
github.com/lithammer/shortuuid v3.0.0+incompatible
github.com/miekg/dns v1.1.57
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
golang.org/x/sys v0.14.0
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
google.golang.org/grpc v1.58.1 google.golang.org/grpc v1.58.1
google.golang.org/protobuf v1.31.0 google.golang.org/protobuf v1.31.0
@ -15,19 +22,36 @@ require (
) )
require ( require (
github.com/bytedance/sonic v1.9.1 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/golang/protobuf v1.5.3 // indirect github.com/golang/protobuf v1.5.3 // indirect
github.com/google/go-cmp v0.5.9 // indirect github.com/google/go-cmp v0.5.9 // indirect
github.com/josharian/native v1.1.0 // indirect github.com/josharian/native v1.1.0 // indirect
github.com/jsimonetti/rtnetlink v1.3.5 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.5.0 // indirect github.com/mdlayher/socket v0.5.0 // indirect
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
golang.org/x/crypto v0.13.0 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
golang.org/x/net v0.15.0 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect
golang.org/x/sync v0.3.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
golang.org/x/sys v0.12.0 // indirect github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.14.0 // indirect
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 // indirect
golang.org/x/mod v0.12.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sync v0.4.0 // indirect
golang.org/x/text v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect
golang.org/x/tools v0.13.0 // indirect
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 // indirect golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect
) )

141
go.sum Normal file
View File

@ -0,0 +1,141 @@
github.com/akamensky/argparse v1.4.0 h1:YGzvsTqCvbEZhL8zZu2AiA5nq805NZh75JNj4ajn1xc=
github.com/akamensky/argparse v1.4.0/go.mod h1:S5kwC7IuDcEr5VeXtGPRVZ5o/FdhcMlQz4IZQuw64xA=
github.com/anandvarma/namegen v0.0.0-20230727084436-5197c6ea3255 h1:aIAyyj4XPrke9Tc/umbBCzP5SKX/CHf3dKrL/PhH2lo=
github.com/anandvarma/namegen v0.0.0-20230727084436-5197c6ea3255/go.mod h1:MFyILur9tG8PxaCXGZVr/2BOnHtRIgxYejYFZdWLxr0=
github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9 h1:+6JSfuxZgmURoIlGdnYnY/FLRGWGagLyiBjt/VLtwi4=
github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9/go.mod h1:6UxoDE+thWsISXK93pxaOuOfkcAfCvDbg0eAnFmxL5E=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/cilium/ebpf v0.11.0 h1:V8gS/bTCCjX9uUnkUFUpPsksM8n1lXBAvHcpiFk1X2Y=
github.com/cilium/ebpf v0.11.0/go.mod h1:WE7CZAnqOL2RouJ4f1uyNhqr2P4CCvXFIqdRDUgWsVs=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE=
github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/jsimonetti/rtnetlink v1.3.5 h1:hVlNQNRlLDGZz31gBPicsG7Q53rnlsz1l1Ix/9XlpVA=
github.com/jsimonetti/rtnetlink v1.3.5/go.mod h1:0LFedyiTkebnd43tE4YAkWGIq9jQphow4CcwxaT2Y00=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/lithammer/shortuuid v3.0.0+incompatible h1:NcD0xWW/MZYXEHa6ITy6kaXN5nwm/V115vj2YXfhS0w=
github.com/lithammer/shortuuid v3.0.0+incompatible/go.mod h1:FR74pbAuElzOUuenUHTK2Tciko1/vKuIKS9dSkDrA4w=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.5.0 h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI=
github.com/mdlayher/socket v0.5.0/go.mod h1:WkcBFfvyG8QENs5+hfQPl1X6Jpd2yeLIYgrGFmJiJxI=
github.com/miekg/dns v1.1.57 h1:Jzi7ApEIzwEPLHWRcafCN9LZSBbqQpxjt/wpgvg7wcM=
github.com/miekg/dns v1.1.57/go.mod h1:uqRjCRUuEAA6qsOiJvDd+CFo/vW+y5WR6SNmHE55hZk=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug=
golang.org/x/exp v0.0.0-20230321023759-10a507213a29/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ=
golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 h1:EY138uSo1JYlDq+97u1FtcOUwPpIU6WL1Lkt7WpYjPA=
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 h1:bVf09lpb+OJbByTj913DRJioFFAjf/ZGxEz7MajTp2U=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98/go.mod h1:TUfxEVdsvPg18p6AslUXFoLdpED4oBnGwyqk3dV1XzM=
google.golang.org/grpc v1.58.1 h1:OL+Vz23DTtrrldqHK49FUOPHyY75rvFqJfXC84NYW58=
google.golang.org/grpc v1.58.1/go.mod h1:tgX3ZQDlNJGU96V6yHh1T/JeoBQ2TXdr43YbYSsCJk0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

295
pkg/api/apiserver.go Normal file
View File

@ -0,0 +1,295 @@
package api
import (
"fmt"
"net/http"
"github.com/gin-gonic/gin"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/smegmesh/pkg/ipc"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/what8words"
)
// routesToApiRoute: convert the returned type to a JSON object
func (s *SmegServer) routeToApiRoute(meshNode ctrlserver.MeshNode) []Route {
routes := make([]Route, len(meshNode.Routes))
for index, route := range meshNode.Routes {
if route.Path == nil {
route.Path = make([]string, 0)
}
routes[index] = Route{
Prefix: route.Destination,
Path: route.Path,
}
}
return routes
}
// meshNodeToAPImeshNode: convert daemon node to a JSON node
func (s *SmegServer) meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNode {
if meshNode.Routes == nil {
meshNode.Routes = make([]ctrlserver.MeshRoute, 0)
}
alias := meshNode.Alias
if alias == "" {
alias, _ = s.words.ConvertIdentifier(meshNode.WgHost)
}
return &SmegNode{
WgHost: meshNode.WgHost,
WgEndpoint: meshNode.WgEndpoint,
Endpoint: meshNode.HostEndpoint,
Timestamp: int(meshNode.Timestamp),
Description: meshNode.Description,
Routes: s.routeToApiRoute(meshNode),
PublicKey: meshNode.PublicKey,
Alias: alias,
Services: meshNode.Services,
Stats: SmegStats{
TotalTransmit: meshNode.Stats.TransmitBytes,
TotalReceived: meshNode.Stats.ReceivedBytes,
KeepAliveInterval: meshNode.Stats.PersistentKeepAliveInterval,
AllowedIps: meshNode.Stats.AllowedIPs,
},
}
}
// meshToAPIMesh: Convert daemon mesh network to a JSON mesh network
func (s *SmegServer) meshToAPIMesh(meshId string, nodes []ctrlserver.MeshNode) SmegMesh {
var smegMesh SmegMesh
smegMesh.MeshId = meshId
smegMesh.Nodes = make(map[string]SmegNode)
for _, node := range nodes {
smegMesh.Nodes[node.WgHost] = *s.meshNodeToAPIMeshNode(node)
}
return smegMesh
}
// putAlias: place an alias in the mesh
func (s *SmegServer) putAlias(meshId, alias string) error {
var reply string
return s.client.PutAlias(ipc.PutAliasArgs{
Alias: alias,
MeshId: meshId,
}, &reply)
}
func (s *SmegServer) putDescription(meshId, description string) error {
var reply string
return s.client.PutDescription(ipc.PutDescriptionArgs{
Description: description,
MeshId: meshId,
}, &reply)
}
// CreateMesh: creates a mesh network
func (s *SmegServer) CreateMesh(c *gin.Context) {
var createMesh CreateMeshRequest
if err := c.ShouldBindJSON(&createMesh); err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
fmt.Printf("%+v\n", createMesh)
ipcRequest := ipc.NewMeshArgs{
WgArgs: ipc.WireGuardArgs{
WgPort: createMesh.WgPort,
Role: createMesh.Role,
Endpoint: createMesh.PublicEndpoint,
AdvertiseRoutes: createMesh.AdvertiseRoutes,
AdvertiseDefaultRoute: createMesh.AdvertiseDefaults,
},
}
var reply string
err := s.client.CreateMesh(&ipcRequest, &reply)
if err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
if createMesh.Alias != "" {
s.putAlias(reply, createMesh.Alias)
}
if createMesh.Description != "" {
s.putDescription(reply, createMesh.Description)
}
c.JSON(http.StatusOK, &gin.H{
"meshid": reply,
})
}
// JoinMesh: joins a mesh network
func (s *SmegServer) JoinMesh(c *gin.Context) {
var joinMesh JoinMeshRequest
if err := c.ShouldBindJSON(&joinMesh); err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
ipcRequest := ipc.JoinMeshArgs{
MeshId: joinMesh.MeshId,
IpAddress: joinMesh.Bootstrap,
WgArgs: ipc.WireGuardArgs{
WgPort: joinMesh.WgPort,
Endpoint: joinMesh.PublicEndpoint,
Role: joinMesh.Role,
AdvertiseRoutes: joinMesh.AdvertiseRoutes,
AdvertiseDefaultRoute: joinMesh.AdvertiseDefaults,
},
}
var reply string
err := s.client.JoinMesh(ipcRequest, &reply)
if err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
if joinMesh.Alias != "" {
s.putAlias(reply, joinMesh.Alias)
}
if joinMesh.Description != "" {
s.putDescription(reply, joinMesh.Description)
}
c.JSON(http.StatusOK, &gin.H{
"status": "success",
})
}
// GetMesh: given a meshId returns the corresponding mesh
// network.
func (s *SmegServer) GetMesh(c *gin.Context) {
meshidParam := c.Param("meshid")
var meshid string = meshidParam
getMeshReply := new(ipc.GetMeshReply)
err := s.client.GetMesh(meshid, getMeshReply)
if err != nil {
c.JSON(http.StatusNotFound,
&gin.H{
"error": fmt.Sprintf("could not find mesh %s", meshidParam),
})
return
}
mesh := s.meshToAPIMesh(meshidParam, getMeshReply.Nodes)
c.JSON(http.StatusOK, mesh)
}
// GetMeshes: return all the mesh networks that the
// user is a part of
func (s *SmegServer) GetMeshes(c *gin.Context) {
listMeshesReply := new(ipc.ListMeshReply)
err := s.client.ListMeshes(listMeshesReply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
c.JSON(http.StatusBadRequest, nil)
return
}
meshes := make([]SmegMesh, 0)
for _, mesh := range listMeshesReply.Meshes {
getMeshReply := new(ipc.GetMeshReply)
err := s.client.GetMesh(mesh, getMeshReply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
c.JSON(http.StatusBadRequest, nil)
return
}
meshes = append(meshes, s.meshToAPIMesh(mesh, getMeshReply.Nodes))
}
c.JSON(http.StatusOK, meshes)
}
// Run: run the API server
func (s *SmegServer) Run(addr string) error {
logging.Log.WriteInfof("Running API server")
return s.router.Run(addr)
}
// NewSmegServer: creates an instance of a new API server
// returns an error if something went wrong
func NewSmegServer(conf ApiServerConf) (ApiServer, error) {
client, err := ipc.NewClientIpc()
if err != nil {
return nil, err
}
words, err := what8words.NewWhat8Words(conf.WordsFile)
if err != nil {
return nil, err
}
router := gin.Default()
router.Use(gin.LoggerWithConfig(gin.LoggerConfig{
Output: logging.Log.Writer(),
}))
smegServer := &SmegServer{
router: router,
client: client,
words: words,
}
v1 := router.Group("/api/v1")
{
meshes := v1.Group("/meshes")
{
meshes.GET("/", smegServer.GetMeshes)
}
mesh := v1.Group("/mesh")
{
mesh.GET("/:meshid", smegServer.GetMesh)
mesh.POST("/create", smegServer.CreateMesh)
mesh.POST("/join", smegServer.JoinMesh)
}
}
return smegServer, nil
}

129
pkg/api/types.go Normal file
View File

@ -0,0 +1,129 @@
package api
import (
"time"
"github.com/gin-gonic/gin"
"github.com/tim-beatham/smegmesh/pkg/ipc"
"github.com/tim-beatham/smegmesh/pkg/what8words"
)
// Route is an advertised route in the data store
type Route struct {
// Prefix is the advertised route prefix
Prefix string `json:"prefix"`
// Path is the hops the destination
Path []string `json:"path"`
}
// SmegStats is the WireGuard stats that the underlying host
// has sent to the peer
type SmegStats struct {
// TotalTransmit number of bytes sent to the peer
TotalTransmit int64 `json:"totalTransmit"`
// TotalReceived number of bytes received from the peer
TotalReceived int64 `json:"totalReceived"`
// KeepAliveInterval WireGuard keepalive interval that is sent to the host
KeepAliveInterval time.Duration `json:"keepaliveInterval"`
// AllowsIps is the allowed path to the destination
AllowedIps []string `json:"allowedIps"`
}
// SmegNode is a node in the mesh network
type SmegNode struct {
// Alias is the human readable name that the node is assocaited with
Alias string `json:"alias"`
// WgHost is the WireGuard IP address of the node. This is an IPv6
// address
WgHost string `json:"wgHost"`
// WgEndpoint is the physical endpoint of the host that packets
// are forwarded to
WgEndpoint string `json:"wgEndpoint"`
// Endpoint is the control plane endpoint of the host which
// grpc connections are to be sent along
Endpoint string `json:"endpoint"`
// Timestamp is the last time the signified it was alive.
// if the node is the leader this is evert heartBeatInterval
// otherwise this is the time the node joined the network
Timestamp int `json:"timestamp"`
// Description is the human readable description of the node
Description string `json:"description"`
// PublicKey is the WireGuard public key of the node
PublicKey string `json:"publicKey"`
// Routes is the routes that the node is advertising
Routes []Route `json:"routes"`
// Services is information about services that the node offers
Services map[string]string `json:"services"`
// Stats is the WireGuard stats of the node (if any)
Stats SmegStats `json:"stats"`
}
// SmegMesh encapsulates a single mesh in the API
type SmegMesh struct {
// MeshId is the mesh id of the network
MeshId string `json:"meshid"`
// Nodes is the nodes in the network keyed by their public
// key
Nodes map[string]SmegNode `json:"nodes"`
}
// CreateMeshRequest encapsulates a request to create a mesh network
type CreateMeshRequest struct {
// WgPort is the WireGuard to create the mesh in
WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"`
// Role is the role to take on in the mesh
Role string `json:"role" binding:"required,eq=client|eq=peer"`
// AdvertiseRoutes: advertise thi mesh to other meshes
AdvertiseRoutes bool `json:"advertiseRoutes"`
// AdvertiseDefaults: advertise an exit point
AdvertiseDefaults bool `json:"advertiseDefaults"`
// Alias: alias of the node in the mesh
Alias string `json:"alias"`
// Description: description of the node in the mesh
Description string `json:"description"`
// PublicEndpoint: an alternative public endpoint to advertise
PublicEndpoint string `json:"publicEndpoint"`
}
// JoinMeshRequests encapsulates a request to create a mesh network
type JoinMeshRequest struct {
// WgPort is the WireGuard port to run the service on
WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"`
// Bootstrap is a bootstrap node to use to join the network
Bootstrap string `json:"bootstrap" binding:"required"`
// MeshId is the ID of the mesh to join
MeshId string `json:"meshid" binding:"required"`
// Role is the role to take on in the mesh
Role string `json:"role" binding:"required,eq=client|eq=peer"`
// AdvertiseRoutes: advertise thi mesh to other meshes
AdvertiseRoutes bool `json:"advertiseRoutes"`
// AdvertiseDefaults: advertise an exit point
AdvertiseDefaults bool `json:"advertiseDefaults"`
// Alias: alias of the node in the mesh
Alias string `json:"alias"`
// Description: description of the node in the mesh
Description string `json:"description"`
// PublicEndpoint: an alternative public endpoint to advertise
PublicEndpoint string `json:"publicEndpoint"`
}
// ApiServerConf configuration to instantiate the API server
type ApiServerConf struct {
// WordsFile to use to map IP to words
WordsFile string
}
// SmegSever is the GIN api server that runs the service
type SmegServer struct {
// gin router to use
router *gin.Engine
// client to invoke operations
client *ipc.SmegmeshIpc
// what8words to use to convert IP to an alias
words *what8words.What8Words
}
// ApiSever absrtacts the API server
type ApiServer interface {
Run(addr string) error
}

View File

@ -1,32 +1,46 @@
package crdt // automerge: package is depracated and unused. Please refer to crdt
// for crdt operations in the mesh
package automerge
import ( import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"slices"
"strings" "strings"
"time" "time"
"github.com/automerge/automerge-go" "github.com/automerge/automerge-go"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// CrdtMeshManager manages nodes in the crdt mesh // CrdtMeshManager manage the CRDT datastore
type CrdtMeshManager struct { type CrdtMeshManager struct {
MeshId string // MeshID of the mesh the datastore represents
IfName string MeshId string
NodeId string // IfName: corresponding ifName
Client *wgctrl.Client IfName string
doc *automerge.Doc // Client: corresponding wireguard control client
Client *wgctrl.Client
// doc: autommerge document
doc *automerge.Doc
// LastHash: last hash that the changes were made to
LastHash automerge.ChangeHash LastHash automerge.ChangeHash
conf *conf.WgMeshConfiguration // conf: WireGuard configuration
conf *conf.WgConfiguration
// cache: stored cache of the list automerge document
// so that the store does not have to be repopulated each time
cache *MeshCrdt
// lastCachehash: hash of when the document was last changed
lastCacheHash automerge.ChangeHash
} }
// AddNode as a node to the datastore
func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) { func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
crdt, ok := node.(*MeshNodeCrdt) crdt, ok := node.(*MeshNodeCrdt)
@ -34,18 +48,88 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
panic("node must be of type *MeshNodeCrdt") panic("node must be of type *MeshNodeCrdt")
} }
crdt.Routes = make(map[string]interface{}) crdt.Routes = make(map[string]Route)
crdt.Services = make(map[string]string)
crdt.Timestamp = time.Now().Unix() crdt.Timestamp = time.Now().Unix()
c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt)
err := c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt)
if err != nil {
logging.Log.WriteInfof("error")
}
} }
// GetMesh(): Converts the document into a struct // isPeer: returns true if the given node has type peer
func (c *CrdtMeshManager) isPeer(nodeId string) bool {
node, err := c.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return false
}
nodeType, err := node.Map().Get("type")
if err != nil || nodeType.Kind() != automerge.KindStr {
return false
}
return nodeType.Str() == string(conf.PEER_ROLE)
}
// isAlive: checks that the node's configuration has been updated
// since the rquired keep alive time. Depracated no longer works
// due to changes in approach
func (c *CrdtMeshManager) isAlive(nodeId string) bool {
node, err := c.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return false
}
timestamp, err := node.Map().Get("timestamp")
if err != nil || timestamp.Kind() != automerge.KindInt64 {
return false
}
// return (time.Now().Unix() - keepAliveTime) < int64(c.conf.DeadTime)
return true
}
// GetPeers: get all the peers in the mesh
func (c *CrdtMeshManager) GetPeers() []string {
keys, _ := c.doc.Path("nodes").Map().Keys()
keys = lib.Filter(keys, func(publicKey string) bool {
return c.isPeer(publicKey) && c.isAlive(publicKey)
})
return keys
}
// GetMesh: Converts the document into a mesh network
func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) { func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
return automerge.As[*MeshCrdt](c.doc.Root()) changes, err := c.doc.Changes(c.lastCacheHash)
if err != nil {
return nil, err
}
if c.cache == nil || len(changes) > 0 {
c.lastCacheHash = c.LastHash
cache, err := automerge.As[*MeshCrdt](c.doc.Root())
if err != nil {
return nil, err
}
c.cache = cache
}
return c.cache, nil
} }
// GetMeshId returns the meshid of the mesh // GetMeshId: returns the meshid of the mesh
func (c *CrdtMeshManager) GetMeshId() string { func (c *CrdtMeshManager) GetMeshId() string {
return c.MeshId return c.MeshId
} }
@ -66,29 +150,42 @@ func (c *CrdtMeshManager) Load(bytes []byte) error {
return nil return nil
} }
// NewCrdtNodeManagerParams: params to instantiate a new automerge
// datastore
type NewCrdtNodeMangerParams struct { type NewCrdtNodeMangerParams struct {
MeshId string MeshId string
DevName string DevName string
Port int Port int
Conf conf.WgMeshConfiguration Conf *conf.WgConfiguration
Client *wgctrl.Client Client *wgctrl.Client
} }
// NewCrdtNodeManager: Create a new crdt node manager // NewCrdtNodeManager: Create a new automerge crdt data store
func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, error) { func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, error) {
var manager CrdtMeshManager var manager CrdtMeshManager
manager.MeshId = params.MeshId manager.MeshId = params.MeshId
manager.doc = automerge.New() manager.doc = automerge.New()
manager.IfName = params.DevName manager.IfName = params.DevName
manager.Client = params.Client manager.Client = params.Client
manager.conf = &params.Conf manager.conf = params.Conf
manager.cache = nil
return &manager, nil return &manager, nil
} }
// GetNode: returns a mesh node crdt.Close releases resources used by a Client. // NodeExists: returns true if the node exists other returns false
func (m *CrdtMeshManager) GetNode(endpoint string) (*MeshNodeCrdt, error) { func (m *CrdtMeshManager) NodeExists(key string) bool {
node, err := m.doc.Path("nodes").Map().Get(key)
return node.Kind() == automerge.KindMap && err == nil
}
// GetNode: gets a node from the mesh network.
func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
node, err := m.doc.Path("nodes").Map().Get(endpoint) node, err := m.doc.Path("nodes").Map().Get(endpoint)
if node.Kind() != automerge.KindMap {
return nil, fmt.Errorf("getnode: node is not a map")
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -102,10 +199,12 @@ func (m *CrdtMeshManager) GetNode(endpoint string) (*MeshNodeCrdt, error) {
return meshNode, nil return meshNode, nil
} }
// Length: returns the number of nodes in the store
func (m *CrdtMeshManager) Length() int { func (m *CrdtMeshManager) Length() int {
return m.doc.Path("nodes").Map().Len() return m.doc.Path("nodes").Map().Len()
} }
// GetDevice: get the underlying WireGuard device
func (m *CrdtMeshManager) GetDevice() (*wgtypes.Device, error) { func (m *CrdtMeshManager) GetDevice() (*wgtypes.Device, error) {
dev, err := m.Client.Device(m.IfName) dev, err := m.Client.Device(m.IfName)
@ -116,7 +215,7 @@ func (m *CrdtMeshManager) GetDevice() (*wgtypes.Device, error) {
return dev, nil return dev, nil
} }
// HasChanges returns true if we have changes since the last time we synced // HasChanges: returns true if there are changes since last time synchronised
func (m *CrdtMeshManager) HasChanges() bool { func (m *CrdtMeshManager) HasChanges() bool {
changes, err := m.doc.Changes(m.LastHash) changes, err := m.doc.Changes(m.LastHash)
@ -130,6 +229,7 @@ func (m *CrdtMeshManager) HasChanges() bool {
return len(changes) > 0 return len(changes) > 0
} }
// SaveChanges: save changes to the datastore
func (m *CrdtMeshManager) SaveChanges() { func (m *CrdtMeshManager) SaveChanges() {
hashes := m.doc.Heads() hashes := m.doc.Heads()
hash := hashes[len(hashes)-1] hash := hashes[len(hashes)-1]
@ -138,6 +238,7 @@ func (m *CrdtMeshManager) SaveChanges() {
m.LastHash = hash m.LastHash = hash
} }
// UpdateTimeStamp: updates the timestamp of the document
func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error { func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId) node, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -158,6 +259,7 @@ func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error {
return err return err
} }
// SetDescription: set the description of the given node
func (m *CrdtMeshManager) SetDescription(nodeId string, description string) error { func (m *CrdtMeshManager) SetDescription(nodeId string, description string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId) node, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -178,8 +280,78 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro
return err return err
} }
// SetAlias: set the alias of the given node
func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
return err
}
if node.Kind() != automerge.KindMap {
return fmt.Errorf("%s does not exist", nodeId)
}
err = node.Map().Set("alias", alias)
if err == nil {
logging.Log.WriteInfof("Updated Alias for %s to %s", nodeId, alias)
}
return err
}
// AddService: add a service to the given node
func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return fmt.Errorf("AddService: node %s does not exist", nodeId)
}
service, err := node.Map().Get("services")
if err != nil {
return err
}
if service.Kind() != automerge.KindMap {
return fmt.Errorf("AddService: services property does not exist in node")
}
err = service.Map().Set(key, value)
return err
}
// RemoveService: remove a service from a node
func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return fmt.Errorf("RemoveService: node %s does not exist", nodeId)
}
service, err := node.Map().Get("services")
if err != nil {
return err
}
if service.Kind() != automerge.KindMap {
return fmt.Errorf("services property does not exist")
}
err = service.Map().Delete(key)
if err != nil {
return fmt.Errorf("service %s does not exist", key)
}
return nil
}
// AddRoutes: adds routes to the specific nodeId // AddRoutes: adds routes to the specific nodeId
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error { func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
logging.Log.WriteInfof("Adding route to %s", nodeId) logging.Log.WriteInfof("Adding route to %s", nodeId)
@ -198,7 +370,32 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
} }
for _, route := range routes { for _, route := range routes {
err = routeMap.Map().Set(route, struct{}{}) prevRoute, err := routeMap.Map().Get(route.GetDestination().String())
if prevRoute.Kind() == automerge.KindVoid && err != nil {
path, err := prevRoute.Map().Get("path")
if err != nil {
return err
}
if path.Kind() != automerge.KindList {
return fmt.Errorf("path is not a list")
}
pathStr, err := automerge.As[[]string](path)
if err != nil {
return err
}
slices.Equal(route.GetPath(), pathStr)
}
err = routeMap.Map().Set(route.GetDestination().String(), Route{
Destination: route.GetDestination().String(),
Path: route.GetPath(),
})
if err != nil { if err != nil {
return err return err
@ -207,8 +404,86 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
return nil return nil
} }
// DeleteRoutes deletes the specified routes // getRoutes: get the routes that the given node is directly advertising
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error { func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
return nil, err
}
if nodeVal.Kind() != automerge.KindMap {
return nil, fmt.Errorf("node does not exist")
}
routeMap, err := nodeVal.Map().Get("routes")
if err != nil {
return nil, err
}
if routeMap.Kind() != automerge.KindMap {
return nil, fmt.Errorf("node %s is not a map", nodeId)
}
routes, err := automerge.As[map[string]Route](routeMap)
return lib.MapValues(routes), err
}
// GetRoutes: get all the routes that the node can see. The routes that the node
// can say may not be direct but cann also be indirect
func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, error) {
node, err := m.GetNode(targetNode)
if err != nil {
return nil, err
}
routes := make(map[string]mesh.Route)
// Add routes that the node directly has
for _, route := range node.GetRoutes() {
routes[route.GetDestination().String()] = route
}
// Work out the other routes in the mesh
for _, node := range m.GetPeers() {
nodeRoutes, err := m.getRoutes(node)
if err != nil {
return nil, err
}
for _, route := range nodeRoutes {
otherRoute, ok := routes[route.GetDestination().String()]
hopCount := route.GetHopCount()
if node != targetNode {
hopCount += 1
}
if !ok || route.GetHopCount()+1 < otherRoute.GetHopCount() {
routes[route.GetDestination().String()] = &Route{
Destination: route.GetDestination().String(),
Path: append(route.Path, m.GetMeshId()),
}
}
}
}
return routes, nil
}
// RemoveNode: removes a node from the datastore
func (m *CrdtMeshManager) RemoveNode(nodeId string) error {
err := m.doc.Path("nodes").Map().Delete(nodeId)
return err
}
// RemoveRoutes: withdraw all the routes the nodeID is advertising
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil { if err != nil {
@ -226,116 +501,114 @@ func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
} }
for _, route := range routes { for _, route := range routes {
err = routeMap.Map().Delete(route) err = routeMap.Map().Delete(route.GetDestination().String())
} }
return err return err
} }
// GetConfiguration: gets the configuration for this mesh network
func (m *CrdtMeshManager) GetConfiguration() *conf.WgConfiguration {
return m.conf
}
// Mark: mark the node as locally dead
func (m *CrdtMeshManager) Mark(nodeId string) {
}
// GetSyncer: get the bi-directionally syncer to synchronise the document
func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer { func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer {
return NewAutomergeSync(m) return NewAutomergeSync(m)
} }
func (m *CrdtMeshManager) Prune(pruneTime int) error { // Prune: prune all dead nodes
nodes, err := m.doc.Path("nodes").Get() func (m *CrdtMeshManager) Prune() error {
if err != nil {
return err
}
if nodes.Kind() != automerge.KindMap {
return errors.New("node must be a map")
}
values, err := nodes.Map().Values()
if err != nil {
return err
}
deletionNodes := make([]string, 0)
for nodeId, node := range values {
if node.Kind() != automerge.KindMap {
return errors.New("node must be a map")
}
nodeMap := node.Map()
timeStamp, err := nodeMap.Get("timestamp")
if err != nil {
return err
}
if timeStamp.Kind() != automerge.KindInt64 {
return errors.New("timestamp is not int64")
}
timeValue := timeStamp.Int64()
nowValue := time.Now().Unix()
if nowValue-timeValue >= int64(pruneTime) {
deletionNodes = append(deletionNodes, nodeId)
}
}
for _, node := range deletionNodes {
logging.Log.WriteInfof("Pruning %s", node)
nodes.Map().Delete(node)
}
return nil return nil
} }
// Compare: compare two mesh node for equality
func (m1 *MeshNodeCrdt) Compare(m2 *MeshNodeCrdt) int { func (m1 *MeshNodeCrdt) Compare(m2 *MeshNodeCrdt) int {
return strings.Compare(m1.PublicKey, m2.PublicKey) return strings.Compare(m1.PublicKey, m2.PublicKey)
} }
// GetHostEndpoint: get the ctrl endpoint of the host
func (m *MeshNodeCrdt) GetHostEndpoint() string { func (m *MeshNodeCrdt) GetHostEndpoint() string {
return m.HostEndpoint return m.HostEndpoint
} }
// GetPublicKey: get the public key of the node
func (m *MeshNodeCrdt) GetPublicKey() (wgtypes.Key, error) { func (m *MeshNodeCrdt) GetPublicKey() (wgtypes.Key, error) {
return wgtypes.ParseKey(m.PublicKey) return wgtypes.ParseKey(m.PublicKey)
} }
// GetWgEndpoint: get the outer WireGuard endpoint
func (m *MeshNodeCrdt) GetWgEndpoint() string { func (m *MeshNodeCrdt) GetWgEndpoint() string {
return m.WgEndpoint return m.WgEndpoint
} }
// GetWgHost: get the WireGuard IP address of the host
func (m *MeshNodeCrdt) GetWgHost() *net.IPNet { func (m *MeshNodeCrdt) GetWgHost() *net.IPNet {
_, ipnet, err := net.ParseCIDR(m.WgHost) _, ipnet, err := net.ParseCIDR(m.WgHost)
if err != nil { if err != nil {
logging.Log.WriteErrorf("Cannot parse WgHost %s", err.Error())
return nil return nil
} }
return ipnet return ipnet
} }
// GetTimeStamp: get timestamp if when the node was last updated
func (m *MeshNodeCrdt) GetTimeStamp() int64 { func (m *MeshNodeCrdt) GetTimeStamp() int64 {
return m.Timestamp return m.Timestamp
} }
func (m *MeshNodeCrdt) GetRoutes() []string { // GetRoutes: get all the routes advertised by the node
return lib.MapKeys(m.Routes) func (m *MeshNodeCrdt) GetRoutes() []mesh.Route {
return lib.Map(lib.MapValues(m.Routes), func(r Route) mesh.Route {
return &Route{
Destination: r.Destination,
Path: r.Path,
}
})
} }
// GetDescription: get the description of the node
func (m *MeshNodeCrdt) GetDescription() string { func (m *MeshNodeCrdt) GetDescription() string {
return m.Description return m.Description
} }
// GetIdentifier: get the iderntifier section of the ipv6 address
func (m *MeshNodeCrdt) GetIdentifier() string { func (m *MeshNodeCrdt) GetIdentifier() string {
ipv6 := m.WgHost[:len(m.WgHost)-4] ipv6 := m.WgHost[:len(m.WgHost)-4]
constituents := strings.Split(ipv6, ":") constituents := strings.Split(ipv6, ":")
logging.Log.WriteInfof(ipv6)
constituents = constituents[4:] constituents = constituents[4:]
return strings.Join(constituents, ":") return strings.Join(constituents, ":")
} }
// GetAlias: get the alias of the node
func (m *MeshNodeCrdt) GetAlias() string {
return m.Alias
}
// GetServices: get all the services the node is advertising
func (m *MeshNodeCrdt) GetServices() map[string]string {
services := make(map[string]string)
for key, service := range m.Services {
services[key] = service
}
return services
}
// GetType refers to the type of the node. Peer means that the node is globally accessible
// Client means the node is only accessible through another peer
func (n *MeshNodeCrdt) GetType() conf.NodeType {
return conf.NodeType(n.Type)
}
// GetNodes: get all the nodes in the network
func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode { func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
nodes := make(map[string]mesh.MeshNode) nodes := make(map[string]mesh.MeshNode)
@ -348,8 +621,27 @@ func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
Timestamp: node.Timestamp, Timestamp: node.Timestamp,
Routes: node.Routes, Routes: node.Routes,
Description: node.Description, Description: node.Description,
Alias: node.Alias,
Services: node.GetServices(),
Type: node.Type,
} }
} }
return nodes return nodes
} }
// GetDestination: get destination of the route
func (r *Route) GetDestination() *net.IPNet {
_, ipnet, _ := net.ParseCIDR(r.Destination)
return ipnet
}
// GetHopCount: get the number of hops to the destination
func (r *Route) GetHopCount() int {
return len(r.Path)
}
// GetPath: get the total path which includes the number of hops
func (r *Route) GetPath() []string {
return r.Path
}

View File

@ -1,15 +1,24 @@
package crdt // automerge: automerge is a CRDT library. Defines a CRDT
// datastore and methods to resolve conflicts
package automerge
import ( import (
"github.com/automerge/automerge-go" "github.com/automerge/automerge-go"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
) )
// AutomergeSync: defines a synchroniser to bi-directionally synchronise the
// two states
type AutomergeSync struct { type AutomergeSync struct {
state *automerge.SyncState // state: the automerge sync state to use
state *automerge.SyncState
// manager: the corresponding data store that we are merging
manager *CrdtMeshManager manager *CrdtMeshManager
} }
// GenerateMessage: geenrate a new automerge message to synchronise
// returns a byte of the message and a boolean of whether or not there
// are more messages in the sequence
func (a *AutomergeSync) GenerateMessage() ([]byte, bool) { func (a *AutomergeSync) GenerateMessage() ([]byte, bool) {
msg, valid := a.state.GenerateMessage() msg, valid := a.state.GenerateMessage()
@ -20,6 +29,8 @@ func (a *AutomergeSync) GenerateMessage() ([]byte, bool) {
return msg.Bytes(), true return msg.Bytes(), true
} }
// RecvMessage: receive an automerge message to merge in the datastore
// returns an error if unsuccessful
func (a *AutomergeSync) RecvMessage(msg []byte) error { func (a *AutomergeSync) RecvMessage(msg []byte) error {
_, err := a.state.ReceiveMessage(msg) _, err := a.state.ReceiveMessage(msg)
@ -30,11 +41,13 @@ func (a *AutomergeSync) RecvMessage(msg []byte) error {
return nil return nil
} }
// Complete: complete the synchronisation process
func (a *AutomergeSync) Complete() { func (a *AutomergeSync) Complete() {
logging.Log.WriteInfof("Sync Completed") logging.Log.WriteInfof("sync completed")
a.manager.SaveChanges() a.manager.SaveChanges()
} }
// NewAutomergeSync: instantiates a new automerge syncer
func NewAutomergeSync(manager *CrdtMeshManager) *AutomergeSync { func NewAutomergeSync(manager *CrdtMeshManager) *AutomergeSync {
return &AutomergeSync{ return &AutomergeSync{
state: automerge.NewSyncState(manager.doc), state: automerge.NewSyncState(manager.doc),

View File

@ -1,14 +1,14 @@
package crdt package automerge
import ( import (
"slices" "net"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@ -22,7 +22,7 @@ func setUpTests() *TestParams {
DevName: "wg0", DevName: "wg0",
Port: 5000, Port: 5000,
Client: nil, Client: nil,
Conf: conf.WgMeshConfiguration{}, Conf: &conf.WgConfiguration{},
}) })
return &TestParams{ return &TestParams{
@ -31,22 +31,26 @@ func setUpTests() *TestParams {
} }
func getTestNode() mesh.MeshNode { func getTestNode() mesh.MeshNode {
pubKey, _ := wgtypes.GeneratePrivateKey()
return &MeshNodeCrdt{ return &MeshNodeCrdt{
HostEndpoint: "public-endpoint:8080", HostEndpoint: "public-endpoint:8080",
WgEndpoint: "public-endpoint:21906", WgEndpoint: "public-endpoint:21906",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d218/128", WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d218/128",
PublicKey: "AAAAAAAAAAAA", PublicKey: pubKey.String(),
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
Description: "A node that we are adding", Description: "A node that we are adding",
} }
} }
func getTestNode2() mesh.MeshNode { func getTestNode2() mesh.MeshNode {
pubKey, _ := wgtypes.GeneratePrivateKey()
return &MeshNodeCrdt{ return &MeshNodeCrdt{
HostEndpoint: "public-endpoint:8081", HostEndpoint: "public-endpoint:8081",
WgEndpoint: "public-endpoint:21907", WgEndpoint: "public-endpoint:21907",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d219/128", WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d219/128",
PublicKey: "BBBBBBBBB", PublicKey: pubKey.String(),
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
Description: "A node that we are adding", Description: "A node that we are adding",
} }
@ -54,9 +58,11 @@ func getTestNode2() mesh.MeshNode {
func TestAddNodeNodeExists(t *testing.T) { func TestAddNodeNodeExists(t *testing.T) {
testParams := setUpTests() testParams := setUpTests()
testParams.manager.AddNode(getTestNode()) node := getTestNode()
testParams.manager.AddNode(node)
node, err := testParams.manager.GetNode("public-endpoint:8080") pubKey, _ := node.GetPublicKey()
node, err := testParams.manager.GetNode(pubKey.String())
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -70,25 +76,27 @@ func TestAddNodeNodeExists(t *testing.T) {
func TestAddNodeAddRoute(t *testing.T) { func TestAddNodeAddRoute(t *testing.T) {
testParams := setUpTests() testParams := setUpTests()
testNode := getTestNode() testNode := getTestNode()
testParams.manager.AddNode(testNode) pubKey, _ := testNode.GetPublicKey()
testParams.manager.AddRoutes(testNode.GetHostEndpoint(), "fd:1c64:1d00::/48")
updatedNode, err := testParams.manager.GetNode(testNode.GetHostEndpoint()) _, destination, _ := net.ParseCIDR("fd:1c64:1d00::/48")
testParams.manager.AddNode(testNode)
testParams.manager.AddRoutes(pubKey.String(), &mesh.RouteStub{
Destination: destination,
Path: make([]string, 0),
})
updatedNode, err := testParams.manager.GetNode(pubKey.String())
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if updatedNode == nil { if updatedNode == nil {
t.Fatalf(`Node does not exist in the mesh`) t.Fatalf(`node does not exist in the mesh`)
} }
routes := updatedNode.GetRoutes() routes := updatedNode.GetRoutes()
if !slices.Contains(routes, "fd:1c64:1d00::/48") {
t.Fatal("Route node not added")
}
if len(routes) != 1 { if len(routes) != 1 {
t.Fatal(`Route length mismatch`) t.Fatal(`Route length mismatch`)
} }
@ -253,7 +261,9 @@ func TestUpdateTimeStampNodeExists(t *testing.T) {
node := getTestNode() node := getTestNode()
testParams.manager.AddNode(node) testParams.manager.AddNode(node)
err := testParams.manager.UpdateTimeStamp(node.GetHostEndpoint()) pubKey, _ := node.GetPublicKey()
err := testParams.manager.UpdateTimeStamp(pubKey.String())
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -282,7 +292,12 @@ func TestSetDescriptionNodeExists(t *testing.T) {
func TestAddRoutesNodeDoesNotExist(t *testing.T) { func TestAddRoutesNodeDoesNotExist(t *testing.T) {
testParams := setUpTests() testParams := setUpTests()
err := testParams.manager.AddRoutes("AAAAA", "fd:1c64:1d00::/48") _, destination, _ := net.ParseCIDR("fd:1c64:1d00::/48")
err := testParams.manager.AddRoutes("AAAAA", &mesh.RouteStub{
Destination: destination,
Path: make([]string, 0),
})
if err == nil { if err == nil {
t.Error(err) t.Error(err)
@ -293,16 +308,11 @@ func TestCompareComparesByPublicKey(t *testing.T) {
node := getTestNode().(*MeshNodeCrdt) node := getTestNode().(*MeshNodeCrdt)
node2 := getTestNode2().(*MeshNodeCrdt) node2 := getTestNode2().(*MeshNodeCrdt)
if node.Compare(node2) != -1 { pubKey1, _ := node.GetPublicKey()
t.Fatalf(`node is alphabetically before node2`) pubKey2, _ := node2.GetPublicKey()
}
if node2.Compare(node) != 1 { if node.Compare(node2) != strings.Compare(pubKey1.String(), pubKey2.String()) {
t.Fatalf(`node is alphabetical;y before node2`) t.Fatalf(`compare failed`)
}
if node.Compare(node) != 0 {
t.Fatalf(`node is equal to node`)
} }
} }

View File

@ -1,54 +1,79 @@
package crdt package automerge
import ( import (
"fmt" "fmt"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
) )
// CrdtProviderFactory: abstracts the instantiation of an automerge
// datastore
type CrdtProviderFactory struct{} type CrdtProviderFactory struct{}
// CreateMesh: create a new mesh datastore
func (f *CrdtProviderFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) { func (f *CrdtProviderFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) {
return NewCrdtNodeManager(&NewCrdtNodeMangerParams{ return NewCrdtNodeManager(&NewCrdtNodeMangerParams{
MeshId: params.MeshId, MeshId: params.MeshId,
DevName: params.DevName, DevName: params.DevName,
Conf: *params.Conf, Conf: params.Conf,
Client: params.Client, Client: params.Client,
}) })
} }
// MeshNodeFactory: abstracts the instnatiation of a node
type MeshNodeFactory struct { type MeshNodeFactory struct {
Config conf.WgMeshConfiguration Config conf.DaemonConfiguration
} }
// Build builds the mesh node that represents the host machine to add // Build: builds the mesh node that represents the host machine to add
// to the mesh // to the mesh
func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode { func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode {
hostName := f.getAddress(params) hostName := f.getAddress(params)
grpcEndpoint := fmt.Sprintf("%s:%d", hostName, f.Config.GrpcPort)
if *params.MeshConfig.Role == conf.CLIENT_ROLE {
grpcEndpoint = "-"
}
return &MeshNodeCrdt{ return &MeshNodeCrdt{
HostEndpoint: fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort), HostEndpoint: grpcEndpoint,
PublicKey: params.PublicKey.String(), PublicKey: params.PublicKey.String(),
WgEndpoint: fmt.Sprintf("%s:%d", hostName, params.WgPort), WgEndpoint: fmt.Sprintf("%s:%d", hostName, params.WgPort),
WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()), WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()),
// Always set the routes as empty. // Always set the routes as empty.
// Routes handled by external component // Routes handled by external component
Routes: map[string]interface{}{}, Routes: make(map[string]Route),
Description: "",
Alias: "",
Type: string(*params.MeshConfig.Role),
} }
} }
// getAddress returns the routable address of the machine. // getAddress: returns the routable address of the machine.
func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string { func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string {
var hostName string = "" var hostName string = ""
if params.Endpoint != "" { if params.Endpoint != "" {
hostName = params.Endpoint hostName = params.Endpoint
} else if len(f.Config.Endpoint) != 0 { } else if len(*params.MeshConfig.Endpoint) != 0 {
hostName = f.Config.Endpoint hostName = *params.MeshConfig.Endpoint
} else { } else {
hostName = lib.GetOutboundIP().String() ipFunc := lib.GetPublicIP
if *params.MeshConfig.IPDiscovery == conf.OUTGOING_IP_DISCOVERY {
ipFunc = lib.GetOutboundIP
}
ip, err := ipFunc()
if err != nil {
return ""
}
hostName = ip.String()
} }
return hostName return hostName

View File

@ -1,14 +1,23 @@
package crdt package automerge
// Route: Represents a CRDT of the given route
type Route struct {
Destination string `automerge:"destination"`
Path []string `automerge:"path"`
}
// MeshNodeCrdt: Represents a CRDT for a mesh nodes // MeshNodeCrdt: Represents a CRDT for a mesh nodes
type MeshNodeCrdt struct { type MeshNodeCrdt struct {
HostEndpoint string `automerge:"hostEndpoint"` HostEndpoint string `automerge:"hostEndpoint"`
WgEndpoint string `automerge:"wgEndpoint"` WgEndpoint string `automerge:"wgEndpoint"`
PublicKey string `automerge:"publicKey"` PublicKey string `automerge:"publicKey"`
WgHost string `automerge:"wgHost"` WgHost string `automerge:"wgHost"`
Timestamp int64 `automerge:"timestamp"` Timestamp int64 `automerge:"timestamp"`
Routes map[string]interface{} `automerge:"routes"` Routes map[string]Route `automerge:"routes"`
Description string `automerge:"description"` Alias string `automerge:"alias"`
Description string `automerge:"description"`
Services map[string]string `automerge:"services"`
Type string `automerge:"type"`
} }
// MeshCrdt: Represents the mesh network as a whole // MeshCrdt: Represents the mesh network as a whole

36
pkg/cmd/cmd.go Normal file
View File

@ -0,0 +1,36 @@
// cmd is a package for running commands in the different operating systems implementations
package cmd
import (
"os/exec"
"strings"
)
// CmdRunner: run cmd commands when instantiating a network
type CmdRunner interface {
RunCommands(commands ...string) error
}
// UnixCmdRunner: Run UNIX commands
type UnixCmdRunner struct{}
// RunCommand: runs the unix command. It splits the command into fields
// and then runs the command accordingly
func RunCommand(cmd string) error {
args := strings.Fields(cmd)
c := exec.Command(args[0], args[1:]...)
return c.Run()
}
// RunCommands: run a series of commands
func (l *UnixCmdRunner) RunCommands(commands ...string) error {
for _, cmd := range commands {
err := RunCommand(cmd)
if err != nil {
return err
}
}
return nil
}

View File

@ -4,152 +4,210 @@ package conf
import ( import (
"os" "os"
logging "github.com/tim-beatham/wgmesh/pkg/log" "github.com/go-playground/validator/v10"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
type WgMeshConfigurationError struct { // NodeType types of the node either peer or client
msg string type NodeType string
const (
PEER_ROLE NodeType = "peer"
CLIENT_ROLE NodeType = "client"
)
// IPDiscovery: what IPDiscovery service to use
type IPDiscovery string
const (
// Public IP use an IP service to discover your IP
PUBLIC_IP_DISCOVERY IPDiscovery = "public"
// Outgonig: Use your labelled packet IP
OUTGOING_IP_DISCOVERY IPDiscovery = "outgoing"
)
// Loglevel: what log level to use either error info or warning
type LogLevel string
const (
ERROR LogLevel = "error"
WARNING LogLevel = "warning"
INFO LogLevel = "info"
)
// WgConfiguration contains per-mesh WireGuard configuration. Contains poitner types only so we can
// tell if the attribute is set
type WgConfiguration struct {
// IPDIscovery: how to discover your IP if not specified. Use your outgoing IP or use a public
// service for IPDiscoverability
IPDiscovery *IPDiscovery `yaml:"ipDiscovery" validate:"required,eq=public|eq=outgoing"`
// AdvertiseRoutes: specifies whether the node can act as a router routing packets between meshes
AdvertiseRoutes *bool `yaml:"advertiseRoute" validate:"required"`
// AdvertiseDefaultRoute: specifies whether or not this route should advertise a default route
// for all nodes to route their packets to
AdvertiseDefaultRoute *bool `yaml:"advertiseDefaults" validate:"required"`
// Endpoint contains what value should be set as the public endpoint of this node
Endpoint *string `yaml:"publicEndpoint"`
// Role specifies whether or not the user is globally accessible.
// If the user is globaly accessible they specify themselves as a client.
Role *NodeType `yaml:"role" validate:"required,eq=client|eq=peer"`
// KeepAliveWg configures the implementation so that we send keep alive packets to peers.
KeepAliveWg *int `yaml:"keepAliveWg" validate:"omitempty,gte=0"`
// PreUp are WireGuard commands to run before adding the WG interface
PreUp []string `yaml:"preUp"`
// PostUp are WireGuard commands to run after adding the WG interface
PostUp []string `yaml:"postUp"`
// PreDown are WireGuard commands to run prior to removing the WG interface
PreDown []string `yaml:"preDown"`
// PostDown are WireGuard command to run after removing the WG interface
PostDown []string `yaml:"postDown"`
} }
func (m *WgMeshConfigurationError) Error() string { type DaemonConfiguration struct {
return m.msg
}
type WgMeshConfiguration struct {
// CertificatePath is the path to the certificate to use in mTLS // CertificatePath is the path to the certificate to use in mTLS
CertificatePath string `yaml:"certificatePath"` CertificatePath string `yaml:"certificatePath" validate:"required"`
// PrivateKeypath is the path to the clients private key in mTLS // PrivateKeypath is the path to the clients private key in mTLS
PrivateKeyPath string `yaml:"privateKeyPath"` PrivateKeyPath string `yaml:"privateKeyPath" validate:"required"`
// CaCeritifcatePath path to the certificate of the trust certificate authority // CaCeritifcatePath path to the certificate of the trust certificate authority
CaCertificatePath string `yaml:"caCertificatePath"` CaCertificatePath string `yaml:"caCertificatePath" validate:"required"`
// SkipCertVerification specify to skip certificate verification. Should only be used // SkipCertVerification specify to skip certificate verification. Should only be used
// in test environments // in test environments
SkipCertVerification bool `yaml:"skipCertVerification"` SkipCertVerification bool `yaml:"skipCertVerification"`
// Port to run the GrpcServer on // Port to run the GrpcServer on
GrpcPort string `yaml:"gRPCPort"` GrpcPort int `yaml:"gRPCPort" validate:"required"`
// AdvertiseRoutes advertises other meshes if the node is in multiple meshes // Timeout number of seconds without response that a node is considered unreachable by gRPC
AdvertiseRoutes bool `yaml:"advertiseRoutes"` Timeout int `yaml:"timeout" validate:"required,gte=1"`
// Endpoint is the IP in which this computer is publicly reachable. // Profile whether or not to include a http server that profiles the code
// usecase is when the node has multiple IP addresses Profile bool `yaml:"profile"`
Endpoint string `yaml:"publicEndpoint"` // StubWg whether or not to stub the WireGuard types
// ClusterSize size of the cluster to split on StubWg bool `yaml:"stubWg"`
ClusterSize int `yaml:"clusterSize"` // SyncInterval specifies how long the minimum time should be between synchronisation
// SyncRate number of times per second to perform a sync SyncInterval int `yaml:"syncInterval" validate:"required,gte=1"`
SyncRate float64 `yaml:"syncRate"` // PullInterval specifies the interval between checking for configuration changes
// InterClusterChance proability of inter-cluster communication in a sync round PullInterval int `yaml:"pullInterval" validate:"gte=0"`
InterClusterChance float64 `yaml:"interClusterChance"` // Heartbeat: number of seconds before the leader of the mesh sends an update to
// BranchRate number of nodes to randomly communicate with // send to every member in the mesh
BranchRate int `yaml:"branchRate"` Heartbeat int `yaml:"heartbeatInterval" validate:"required,gte=1"`
// InfectionCount number of times we sync before we can no longer catch the udpate // ClusterSize specifies how many neighbours you should synchronise with per round
InfectionCount int `yaml:"infectionCount"` ClusterSize int `yaml:"clusterSize" validate:"gte=1"`
// KeepAliveTime number of seconds before we update node indicating that we are still alive // InterClusterChance specifies the probabilityof inter-cluster communication in a sync round
KeepAliveTime int `yaml:"keepAliveTime"` InterClusterChance float64 `yaml:"interClusterChance" validate:"gt=0"`
// Timeout number of seconds before we consider the node as dead // Branch specifies the number of nodes to synchronise with when a node has
Timeout int `yaml:"timeout"` // new changes to send to the mesh
// PruneTime number of seconds before we consider the 'node' as dead Branch int `yaml:"branch" validate:"required,gte=1"`
PruneTime int `yaml:"pruneTime"` // InfectionCount: number of time to sync before an update can no longer be 'caught'
InfectionCount int `yaml:"infectionCount" validate:"required,gte=1"`
// BaseConfiguration base WireGuard configuration to use, this is used when none is provided
BaseConfiguration WgConfiguration `yaml:"baseConfiguration" validate:"required"`
// LogLevel specifies the log level to output, defaults is warning
LogLevel LogLevel `yaml:"logLevel" validate:"eq=info|eq=warning|eq=error"`
} }
func ValidateConfiguration(c *WgMeshConfiguration) error { // ValdiateMeshConfiguration: validates the mesh configuration
if len(c.CertificatePath) == 0 { func ValidateMeshConfiguration(conf *WgConfiguration) error {
return &WgMeshConfigurationError{ validate := validator.New(validator.WithRequiredStructEnabled())
msg: "A public certificate must be specified for mTLS", err := validate.Struct(conf)
}
if conf.PostDown == nil {
conf.PostDown = make([]string, 0)
} }
if len(c.PrivateKeyPath) == 0 { if conf.PostUp == nil {
return &WgMeshConfigurationError{ conf.PostUp = make([]string, 0)
msg: "A private key must be specified for mTLS",
}
} }
if len(c.CaCertificatePath) == 0 { if conf.PreDown == nil {
return &WgMeshConfigurationError{ conf.PreDown = make([]string, 0)
msg: "A ca certificate must be specified for mTLS",
}
} }
if len(c.GrpcPort) == 0 { if conf.PreUp == nil {
return &WgMeshConfigurationError{ conf.PreUp = make([]string, 0)
msg: "A grpc port must be specified",
}
} }
if c.ClusterSize <= 0 { return err
return &WgMeshConfigurationError{
msg: "A cluster size must not be 0",
}
}
if c.SyncRate <= 0 {
return &WgMeshConfigurationError{
msg: "SyncRate cannot be negative",
}
}
if c.BranchRate <= 0 {
return &WgMeshConfigurationError{
msg: "Branch rate cannot be negative",
}
}
if c.InfectionCount <= 0 {
return &WgMeshConfigurationError{
msg: "Infection count cannot be less than 1",
}
}
if c.KeepAliveTime <= 0 {
return &WgMeshConfigurationError{
msg: "KeepAliveRate cannot be less than negative",
}
}
if c.InterClusterChance <= 0 {
return &WgMeshConfigurationError{
msg: "Intercluster chance cannot be less than 0",
}
}
if c.Timeout < 1 {
return &WgMeshConfigurationError{
msg: "Timeout should be greater than or equal to 1",
}
}
if c.PruneTime <= 1 {
return &WgMeshConfigurationError{
msg: "Prune time cannot be <= 1",
}
}
if c.KeepAliveTime <= 1 {
return &WgMeshConfigurationError{
msg: "Prune time cannot be less than keep alive time",
}
}
return nil
} }
// ParseConfiguration parses the mesh configuration // ValidateDaemonConfiguration: validates the dameon configuration that is used.
func ParseConfiguration(filePath string) (*WgMeshConfiguration, error) { func ValidateDaemonConfiguration(conf *DaemonConfiguration) error {
var conf WgMeshConfiguration if conf.BaseConfiguration.KeepAliveWg == nil {
var keepAlive int = 0
conf.BaseConfiguration.KeepAliveWg = &keepAlive
}
if conf.LogLevel == "" {
conf.LogLevel = WARNING
}
validate := validator.New(validator.WithRequiredStructEnabled())
err := validate.Struct(conf)
return err
}
// ParseDaemonConfiguration parses the mesh configuration and validates the configuration
func ParseDaemonConfiguration(filePath string) (*DaemonConfiguration, error) {
var conf DaemonConfiguration
yamlBytes, err := os.ReadFile(filePath) yamlBytes, err := os.ReadFile(filePath)
if err != nil { if err != nil {
logging.Log.WriteErrorf("Read file error: %s\n", err.Error())
return nil, err return nil, err
} }
err = yaml.Unmarshal(yamlBytes, &conf) err = yaml.Unmarshal(yamlBytes, &conf)
if err != nil { if err != nil {
logging.Log.WriteErrorf("Unmarshal error: %s\n", err.Error())
return nil, err return nil, err
} }
return &conf, ValidateConfiguration(&conf) return &conf, ValidateDaemonConfiguration(&conf)
}
// MergemeshConfiguration: merges the configuration in precedence where the last
// element in the list takes the most and the first takes the least
func MergeMeshConfiguration(cfgs ...WgConfiguration) (WgConfiguration, error) {
var result WgConfiguration
for _, cfg := range cfgs {
if cfg.AdvertiseDefaultRoute != nil {
result.AdvertiseDefaultRoute = cfg.AdvertiseDefaultRoute
}
if cfg.AdvertiseRoutes != nil {
result.AdvertiseRoutes = cfg.AdvertiseRoutes
}
if cfg.Endpoint != nil {
result.Endpoint = cfg.Endpoint
}
if cfg.IPDiscovery != nil {
result.IPDiscovery = cfg.IPDiscovery
}
if cfg.KeepAliveWg != nil {
result.KeepAliveWg = cfg.KeepAliveWg
}
if cfg.PostDown != nil {
result.PostDown = cfg.PostDown
}
if cfg.PostUp != nil {
result.PostUp = cfg.PostUp
}
if cfg.PreDown != nil {
result.PreDown = cfg.PreDown
}
if cfg.PreUp != nil {
result.PreUp = cfg.PreUp
}
if cfg.Role != nil {
result.Role = cfg.Role
}
}
return result, ValidateMeshConfiguration(&result)
} }

View File

@ -1,24 +1,41 @@
package conf package conf
import "testing" import (
"testing"
)
func getExampleConfiguration() *WgMeshConfiguration { func getExampleConfiguration() *DaemonConfiguration {
return &WgMeshConfiguration{ discovery := PUBLIC_IP_DISCOVERY
CertificatePath: "./cert/cert.pem", advertiseRoutes := false
PrivateKeyPath: "./cert/key.pem", advertiseDefaultRoute := false
CaCertificatePath: "./cert/ca.pems", endpoint := "abc.com:123"
nodeType := CLIENT_ROLE
keepAliveWg := 0
return &DaemonConfiguration{
CertificatePath: "../../../cert/cert.pem",
PrivateKeyPath: "../../../cert/priv.pem",
CaCertificatePath: "../../../cert/cacert.pem",
SkipCertVerification: true, SkipCertVerification: true,
GrpcPort: "8080", GrpcPort: 25,
AdvertiseRoutes: true, Timeout: 5,
Endpoint: "localhost", Profile: false,
ClusterSize: 1, StubWg: false,
SyncRate: 1, SyncInterval: 2,
InterClusterChance: 0.1, Heartbeat: 2,
BranchRate: 2, ClusterSize: 64,
KeepAliveTime: 4, InterClusterChance: 0.15,
InfectionCount: 1, Branch: 3,
Timeout: 2, PullInterval: 0,
PruneTime: 20, InfectionCount: 2,
BaseConfiguration: WgConfiguration{
IPDiscovery: &discovery,
AdvertiseRoutes: &advertiseRoutes,
AdvertiseDefaultRoute: &advertiseDefaultRoute,
Endpoint: &endpoint,
Role: &nodeType,
KeepAliveWg: &keepAliveWg,
},
} }
} }
@ -26,7 +43,7 @@ func TestConfigurationCertificatePathEmpty(t *testing.T) {
conf := getExampleConfiguration() conf := getExampleConfiguration()
conf.CertificatePath = "" conf.CertificatePath = ""
err := ValidateConfiguration(conf) err := ValidateDaemonConfiguration(conf)
if err == nil { if err == nil {
t.Fatal(`error should be thrown`) t.Fatal(`error should be thrown`)
@ -37,7 +54,7 @@ func TestConfigurationPrivateKeyPathEmpty(t *testing.T) {
conf := getExampleConfiguration() conf := getExampleConfiguration()
conf.PrivateKeyPath = "" conf.PrivateKeyPath = ""
err := ValidateConfiguration(conf) err := ValidateDaemonConfiguration(conf)
if err == nil { if err == nil {
t.Fatal(`error should be thrown`) t.Fatal(`error should be thrown`)
@ -48,7 +65,7 @@ func TestConfigurationCaCertificatePathEmpty(t *testing.T) {
conf := getExampleConfiguration() conf := getExampleConfiguration()
conf.CaCertificatePath = "" conf.CaCertificatePath = ""
err := ValidateConfiguration(conf) err := ValidateDaemonConfiguration(conf)
if err == nil { if err == nil {
t.Fatal(`error should be thrown`) t.Fatal(`error should be thrown`)
@ -57,9 +74,110 @@ func TestConfigurationCaCertificatePathEmpty(t *testing.T) {
func TestConfigurationGrpcPortEmpty(t *testing.T) { func TestConfigurationGrpcPortEmpty(t *testing.T) {
conf := getExampleConfiguration() conf := getExampleConfiguration()
conf.GrpcPort = "" conf.GrpcPort = 0
err := ValidateConfiguration(conf) err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestIPDiscoveryNotSet(t *testing.T) {
conf := getExampleConfiguration()
ipDiscovery := IPDiscovery("djdsjdskd")
conf.BaseConfiguration.IPDiscovery = &ipDiscovery
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestAdvertiseRoutesNotSet(t *testing.T) {
conf := getExampleConfiguration()
conf.BaseConfiguration.AdvertiseRoutes = nil
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestAdvertiseDefaultRouteNotSet(t *testing.T) {
conf := getExampleConfiguration()
conf.BaseConfiguration.AdvertiseDefaultRoute = nil
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestKeepAliveWgNegative(t *testing.T) {
conf := getExampleConfiguration()
keepAliveWg := -1
conf.BaseConfiguration.KeepAliveWg = &keepAliveWg
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestRoleTypeNotValid(t *testing.T) {
conf := getExampleConfiguration()
role := NodeType("bruhhh")
conf.BaseConfiguration.Role = &role
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestRoleTypeNotSpecified(t *testing.T) {
conf := getExampleConfiguration()
conf.BaseConfiguration.Role = nil
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`invalid role type`)
}
}
func TestBranchRateZero(t *testing.T) {
conf := getExampleConfiguration()
conf.Branch = 0
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestsyncTimeZero(t *testing.T) {
conf := getExampleConfiguration()
conf.SyncInterval = 0
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestKeepAliveTimeZero(t *testing.T) {
conf := getExampleConfiguration()
conf.Heartbeat = 0
err := ValidateDaemonConfiguration(conf)
if err == nil { if err == nil {
t.Fatal(`error should be thrown`) t.Fatal(`error should be thrown`)
@ -69,97 +187,51 @@ func TestConfigurationGrpcPortEmpty(t *testing.T) {
func TestClusterSizeZero(t *testing.T) { func TestClusterSizeZero(t *testing.T) {
conf := getExampleConfiguration() conf := getExampleConfiguration()
conf.ClusterSize = 0 conf.ClusterSize = 0
err := ValidateDaemonConfiguration(conf)
err := ValidateConfiguration(conf)
if err == nil { if err == nil {
t.Fatal(`error should be thrown`) t.Fatal(`error should be thrown`)
} }
} }
func SyncRateZero(t *testing.T) { func TestInterClusterChanceZero(t *testing.T) {
conf := getExampleConfiguration() conf := getExampleConfiguration()
conf.SyncRate = 0 conf.InterClusterChance = 0
err := ValidateConfiguration(conf) err := ValidateDaemonConfiguration(conf)
if err == nil { if err == nil {
t.Fatal(`error should be thrown`) t.Fatal(`error should be thrown`)
} }
} }
func BranchRateZero(t *testing.T) { func TestInfectionCountOne(t *testing.T) {
conf := getExampleConfiguration()
conf.BranchRate = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func InfectionCountZero(t *testing.T) {
conf := getExampleConfiguration() conf := getExampleConfiguration()
conf.InfectionCount = 0 conf.InfectionCount = 0
err := ValidateConfiguration(conf) err := ValidateDaemonConfiguration(conf)
if err == nil { if err == nil {
t.Fatal(`error should be thrown`) t.Fatal(`error should be thrown`)
} }
} }
func KeepAliveRateZero(t *testing.T) { func TestPullTimeNegative(t *testing.T) {
conf := getExampleConfiguration() conf := getExampleConfiguration()
conf.KeepAliveTime = 0 conf.PullInterval = -1
err := ValidateConfiguration(conf) err := ValidateDaemonConfiguration(conf)
if err == nil { if err == nil {
t.Fatal(`error should be thrown`) t.Fatal(`error should be thrown`)
} }
} }
func TestValidCOnfiguration(t *testing.T) { func TestValidConfiguration(t *testing.T) {
conf := getExampleConfiguration() conf := getExampleConfiguration()
err := ValidateDaemonConfiguration(conf)
err := ValidateConfiguration(conf)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
} }
func TestTimeout(t *testing.T) {
conf := getExampleConfiguration()
conf.Timeout = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestPruneTimeZero(t *testing.T) {
conf := getExampleConfiguration()
conf.PruneTime = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatalf(`Error should be thrown`)
}
}
func TestPruneTimeLessThanKeepAliveTime(t *testing.T) {
conf := getExampleConfiguration()
conf.PruneTime = 1
err := ValidateConfiguration(conf)
if err == nil {
t.Fatalf(`Error should be thrown`)
}
}

View File

@ -7,25 +7,30 @@ import (
"slices" "slices"
) )
// ConnCluster splits nodes into clusters where nodes in a cluster communicate // ConnCluster: splits nodes into clusters where nodes in a cluster communicate
// frequently and nodes outside of a cluster communicate infrequently // frequently and nodes outside of a cluster communicate infrequently
type ConnCluster interface { type ConnCluster interface {
// Getneighbours: get neighbours of the cluster the node is in
GetNeighbours(global []string, selfId string) []string GetNeighbours(global []string, selfId string) []string
// GetInterCluster: get the cluster to communicate with
GetInterCluster(global []string, selfId string) string GetInterCluster(global []string, selfId string) string
} }
// ConnnClusterImpl: implementation of the connection cluster
type ConnClusterImpl struct { type ConnClusterImpl struct {
clusterSize int clusterSize int
} }
// perform binary search to attain a size of a group
func binarySearch(global []string, selfId string, groupSize int) (int, int) { func binarySearch(global []string, selfId string, groupSize int) (int, int) {
slices.Sort(global) slices.Sort(global)
lower := 0 lower := 0
higher := len(global) - 1 higher := len(global) - 1
mid := (lower + higher) / 2
for (higher+1)-lower > groupSize { for (higher+1)-lower > groupSize {
mid := (lower + higher) / 2
if global[mid] < selfId { if global[mid] < selfId {
lower = mid + 1 lower = mid + 1
} else if global[mid] > selfId { } else if global[mid] > selfId {
@ -33,14 +38,12 @@ func binarySearch(global []string, selfId string, groupSize int) (int, int) {
} else { } else {
break break
} }
mid = (lower + higher) / 2
} }
return lower, int(math.Min(float64(lower+groupSize), float64(len(global)))) return lower, int(math.Min(float64(lower+groupSize), float64(len(global))))
} }
// GetNeighbours return the neighbours 'nearest' to you. In this implementation the // GetNeighbours: return the neighbours 'nearest' to you. In this implementation the
// neighbours aren't actually the ones nearest to you but just the ones nearest // neighbours aren't actually the ones nearest to you but just the ones nearest
// to you alphabetically. Perform binary search to get the total group // to you alphabetically. Perform binary search to get the total group
func (i *ConnClusterImpl) GetNeighbours(global []string, selfId string) []string { func (i *ConnClusterImpl) GetNeighbours(global []string, selfId string) []string {
@ -51,19 +54,22 @@ func (i *ConnClusterImpl) GetNeighbours(global []string, selfId string) []string
return global[lower:higher] return global[lower:higher]
} }
// GetInterCluster get nodes not in your cluster. Every round there is a given chance // GetInterCluster: get nodes not in your cluster. Every round there is a given chance
// you will communicate with a random node that is not in your cluster. // you will communicate with a random node that is not in your cluster.
func (i *ConnClusterImpl) GetInterCluster(global []string, selfId string) string { func (i *ConnClusterImpl) GetInterCluster(global []string, selfId string) string {
// Doesn't matter if not in it. Get index of where the node 'should' be // Doesn't matter if not in it. Get index of where the node 'should' be
slices.Sort(global)
index, _ := binarySearch(global, selfId, 1) index, _ := binarySearch(global, selfId, 1)
numClusters := math.Ceil(float64(len(global)) / float64(i.clusterSize))
randomCluster := rand.Intn(int(numClusters)-1) + 1 randomCluster := rand.Intn(2) + 1
neighbourIndex := (index + randomCluster) % len(global) // cluster is considered a heap
neighbourIndex := (2*index + (randomCluster * i.clusterSize)) % len(global)
return global[neighbourIndex] return global[neighbourIndex]
} }
// NewConnCluster: instantiate a new connection cluster of a given group size.
func NewConnCluster(clusterSize int) (ConnCluster, error) { func NewConnCluster(clusterSize int) (ConnCluster, error) {
log2Cluster := math.Log2(float64(clusterSize)) log2Cluster := math.Log2(float64(clusterSize))

View File

@ -6,7 +6,7 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
) )
@ -18,6 +18,7 @@ type PeerConnection interface {
GetClient() (*grpc.ClientConn, error) GetClient() (*grpc.ClientConn, error)
} }
// PeerConenctionFactory: create a new connection to a peer
type PeerConnectionFactory = func(clientConfig *tls.Config, server string) (PeerConnection, error) type PeerConnectionFactory = func(clientConfig *tls.Config, server string) (PeerConnection, error)
// WgCtrlConnection implements PeerConnection. // WgCtrlConnection implements PeerConnection.

View File

@ -7,7 +7,7 @@ import (
"os" "os"
"sync" "sync"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
) )
// ConnectionManager defines an interface for maintaining peer connections // ConnectionManager defines an interface for maintaining peer connections
@ -19,9 +19,11 @@ type ConnectionManager interface {
// If the endpoint does not exist then add the connection. Returns an error // If the endpoint does not exist then add the connection. Returns an error
// if something went wrong // if something went wrong
GetConnection(endPoint string) (PeerConnection, error) GetConnection(endPoint string) (PeerConnection, error)
// HasConnections returns true if a client has already registered at the givne // HasConnections returns true if a peer has already registered at the given
// endpoint or false otherwise. // endpoint or false otherwise.
HasConnection(endPoint string) bool HasConnection(endPoint string) bool
// Removes a connection if it exists
RemoveConnection(endPoint string) error
// Goes through all the connections and closes eachone // Goes through all the connections and closes eachone
Close() error Close() error
} }
@ -32,7 +34,6 @@ type ConnectionManagerImpl struct {
// clientConnections maps an endpoint to a connection // clientConnections maps an endpoint to a connection
conLoc sync.RWMutex conLoc sync.RWMutex
clientConnections map[string]PeerConnection clientConnections map[string]PeerConnection
serverConfig *tls.Config
clientConfig *tls.Config clientConfig *tls.Config
connFactory PeerConnectionFactory connFactory PeerConnectionFactory
} }
@ -61,37 +62,25 @@ func NewConnectionManager(params *NewConnectionManagerParams) (ConnectionManager
return nil, err return nil, err
} }
serverAuth := tls.RequireAndVerifyClientCert
if params.SkipCertVerification {
serverAuth = tls.RequireAnyClientCert
}
certPool := x509.NewCertPool() certPool := x509.NewCertPool()
if !params.SkipCertVerification { if params.CaCert == "" {
return nil, errors.New("CA Cert is not specified")
if params.CaCert == "" {
return nil, errors.New("CA Cert is not specified")
}
caCert, err := os.ReadFile(params.CaCert)
if err != nil {
return nil, err
}
certPool.AppendCertsFromPEM(caCert)
} }
serverConfig := &tls.Config{ caCert, err := os.ReadFile(params.CaCert)
ClientAuth: serverAuth,
Certificates: []tls.Certificate{cert}, if err != nil {
return nil, err
}
if ok := certPool.AppendCertsFromPEM(caCert); !ok {
return nil, errors.New("could not parse PEM")
} }
clientConfig := &tls.Config{ clientConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: params.SkipCertVerification, InsecureSkipVerify: params.SkipCertVerification,
Certificates: []tls.Certificate{cert},
RootCAs: certPool, RootCAs: certPool,
} }
@ -99,7 +88,6 @@ func NewConnectionManager(params *NewConnectionManagerParams) (ConnectionManager
connMgr := ConnectionManagerImpl{ connMgr := ConnectionManagerImpl{
sync.RWMutex{}, sync.RWMutex{},
connections, connections,
serverConfig,
clientConfig, clientConfig,
params.ConnFactory, params.ConnFactory,
} }
@ -150,6 +138,15 @@ func (m *ConnectionManagerImpl) HasConnection(endPoint string) bool {
return exists return exists
} }
// RemoveConnection removes the given connection if it exists
func (m *ConnectionManagerImpl) RemoveConnection(endPoint string) error {
m.conLoc.Lock()
err := m.clientConnections[endPoint].Close()
delete(m.clientConnections, endPoint)
m.conLoc.Unlock()
return err
}
func (m *ConnectionManagerImpl) Close() error { func (m *ConnectionManagerImpl) Close() error {
for _, conn := range m.clientConnections { for _, conn := range m.clientConnections {
if err := conn.Close(); err != nil { if err := conn.Close(); err != nil {

View File

@ -53,13 +53,13 @@ func TestNewConnectionManagerCACertDoesNotExistAndVerify(t *testing.T) {
func TestNewConnectionManagerCACertDoesNotExistAndNotVerify(t *testing.T) { func TestNewConnectionManagerCACertDoesNotExistAndNotVerify(t *testing.T) {
params := getConnectionManagerParams() params := getConnectionManagerParams()
params.CaCert = "" params.CaCert = "./cert/sdjsdjsdjk.pem"
params.SkipCertVerification = true params.SkipCertVerification = true
_, err := NewConnectionManager(params) _, err := NewConnectionManager(params)
if err != nil { if err == nil {
t.Fatal(`an error should not be thrown`) t.Fatalf(`an error should be thrown`)
} }
} }

View File

@ -2,32 +2,34 @@ package conn
import ( import (
"crypto/tls" "crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net" "net"
"os"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/smegmesh/pkg/rpc"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
) )
// ConnectionServer manages gRPC server peer connections // ConnectionServer manages gRPC server peer connections
type ConnectionServer struct { type ConnectionServer struct {
// tlsConfiguration of the server
serverConfig *tls.Config
// server an instance of the grpc server // server an instance of the grpc server
server *grpc.Server // the authentication service to authenticate nodes server *grpc.Server
// the ctrl service to manage node // the ctrl service to manage node
ctrlProvider rpc.MeshCtrlServerServer ctrlProvider rpc.MeshCtrlServerServer
// the sync service to synchronise nodes // the sync service to synchronise nodes
syncProvider rpc.SyncServiceServer syncProvider rpc.SyncServiceServer
Conf *conf.WgMeshConfiguration Conf *conf.DaemonConfiguration
listener net.Listener listener net.Listener
} }
// NewConnectionServerParams contains params for creating a new connection server // NewConnectionServerParams contains params for creating a new connection server
type NewConnectionServerParams struct { type NewConnectionServerParams struct {
Conf *conf.WgMeshConfiguration Conf *conf.DaemonConfiguration
CtrlProvider rpc.MeshCtrlServerServer CtrlProvider rpc.MeshCtrlServerServer
SyncProvider rpc.SyncServiceServer SyncProvider rpc.SyncServiceServer
} }
@ -47,9 +49,26 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
serverAuth = tls.RequireAnyClientCert serverAuth = tls.RequireAnyClientCert
} }
certPool := x509.NewCertPool()
if params.Conf.CaCertificatePath == "" {
return nil, errors.New("CA Cert is not specified")
}
caCert, err := os.ReadFile(params.Conf.CaCertificatePath)
if err != nil {
return nil, err
}
if ok := certPool.AppendCertsFromPEM(caCert); !ok {
return nil, errors.New("could not parse PEM")
}
serverConfig := &tls.Config{ serverConfig := &tls.Config{
ClientAuth: serverAuth, ClientAuth: serverAuth,
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
ClientCAs: certPool,
} }
server := grpc.NewServer( server := grpc.NewServer(
@ -60,7 +79,6 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
syncProvider := params.SyncProvider syncProvider := params.SyncProvider
connServer := ConnectionServer{ connServer := ConnectionServer{
serverConfig: serverConfig,
server: server, server: server,
ctrlProvider: ctrlProvider, ctrlProvider: ctrlProvider,
syncProvider: syncProvider, syncProvider: syncProvider,
@ -73,13 +91,12 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
// Listen for incoming requests. Returns an error if something went wrong. // Listen for incoming requests. Returns an error if something went wrong.
func (s *ConnectionServer) Listen() error { func (s *ConnectionServer) Listen() error {
rpc.RegisterMeshCtrlServerServer(s.server, s.ctrlProvider) rpc.RegisterMeshCtrlServerServer(s.server, s.ctrlProvider)
rpc.RegisterSyncServiceServer(s.server, s.syncProvider) rpc.RegisterSyncServiceServer(s.server, s.syncProvider)
lis, err := net.Listen("tcp", ":"+s.Conf.GrpcPort) lis, err := net.Listen("tcp", fmt.Sprintf(":%d", s.Conf.GrpcPort))
s.listener = lis s.listener = lis
logging.Log.WriteInfof("GRPC listening on %s\n", s.Conf.GrpcPort) logging.Log.WriteInfof("GRPC listening on %d\n", s.Conf.GrpcPort)
if err != nil { if err != nil {
logging.Log.WriteErrorf(err.Error()) logging.Log.WriteErrorf(err.Error())

View File

@ -16,6 +16,11 @@ func (s *ConnectionManagerStub) AddConnection(endPoint string) (PeerConnection,
return mock, nil return mock, nil
} }
func (s *ConnectionManagerStub) RemoveConnection(endPoint string) error {
delete(s.Endpoints, endPoint)
return nil
}
func (s *ConnectionManagerStub) GetConnection(endPoint string) (PeerConnection, error) { func (s *ConnectionManagerStub) GetConnection(endPoint string) (PeerConnection, error) {
endpoint, ok := s.Endpoints[endPoint] endpoint, ok := s.Endpoints[endPoint]

View File

@ -1,84 +0,0 @@
package conn
import (
"errors"
"slices"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
// ConnectionWindow maintains a sliding window of connections between users
type ConnectionWindow interface {
// GetWindow is a list of connections to choose from
GetWindow() []string
// SlideConnection removes a node from the window and adds a random node
// not already in the window. connList represents the list of possible
// connections to choose from
SlideConnection(connList []string) error
// PushConneciton is used when connection list less than window size.
PutConnection(conn []string) error
// IsFull returns true if the window is full. In which case we must slide the window
IsFull() bool
}
type ConnectionWindowImpl struct {
window []string
windowSize int
}
// GetWindow gets the current list of active connections in
// the window
func (c *ConnectionWindowImpl) GetWindow() []string {
return c.window
}
// SlideConnection slides the connection window by one shuffling items
// in the windows
func (c *ConnectionWindowImpl) SlideConnection(connList []string) error {
// If the number of peer connections is less than the length of the window
// then exit early. Can't slide the window it should contain all nodes!
if len(c.window) < c.windowSize {
return nil
}
filter := func(node string) bool {
return !slices.Contains(c.window, node)
}
pool := lib.Filter(connList, filter)
newNode := lib.RandomSubsetOfLength(pool, 1)
if len(newNode) == 0 {
return errors.New("could not slide window")
}
for i := len(c.window) - 1; i >= 1; i-- {
c.window[i] = c.window[i-1]
}
c.window[0] = newNode[0]
return nil
}
// PutConnection put random connections in the connection
func (c *ConnectionWindowImpl) PutConnection(connList []string) error {
if len(c.window) >= c.windowSize {
return errors.New("cannot place connection. Window full need to slide")
}
c.window = lib.RandomSubsetOfLength(connList, c.windowSize)
return nil
}
func (c *ConnectionWindowImpl) IsFull() bool {
return len(c.window) >= c.windowSize
}
func NewConnectionWindow(windowLength int) ConnectionWindow {
window := &ConnectionWindowImpl{
window: make([]string, 0),
windowSize: windowLength,
}
return window
}

529
pkg/crdt/datastore.go Normal file
View File

@ -0,0 +1,529 @@
package crdt
import (
"bytes"
"encoding/gob"
"fmt"
"net"
"slices"
"strings"
"time"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// Route: represents a route within the data store
type Route struct {
// Destination the route is advertising
Destination string
// Path to the destination
Path []string
}
// GetDestination implements mesh.Route.
func (r *Route) GetDestination() *net.IPNet {
_, ipnet, _ := net.ParseCIDR(r.Destination)
return ipnet
}
// GetHopCount implements mesh.Route.
func (r *Route) GetHopCount() int {
return len(r.Path)
}
// GetPath implements mesh.Route.
func (r *Route) GetPath() []string {
return r.Path
}
type MeshNode struct {
HostEndpoint string
WgEndpoint string
PublicKey string
WgHost string
Timestamp int64
Routes map[string]Route
Alias string
Description string
Services map[string]string
Type string
Tombstone bool
}
// Mark: marks the node is unreachable. This is not broadcast on
// syncrhonisation
func (m *TwoPhaseStoreMeshManager) Mark(nodeId string) {
m.store.Mark(nodeId)
}
// GetHostEndpoint: gets the gRPC endpoint of the node
func (n *MeshNode) GetHostEndpoint() string {
return n.HostEndpoint
}
// GetPublicKey: gets the public key of the node
func (n *MeshNode) GetPublicKey() (wgtypes.Key, error) {
return wgtypes.ParseKey(n.PublicKey)
}
// GetWgEndpoint(): get IP and port of the wireguard endpoint
func (n *MeshNode) GetWgEndpoint() string {
return n.WgEndpoint
}
// GetWgHost: get the IP address of the WireGuard node
func (n *MeshNode) GetWgHost() *net.IPNet {
_, ipnet, _ := net.ParseCIDR(n.WgHost)
return ipnet
}
// GetTimestamp: get the UNIX time stamp of the ndoe
func (n *MeshNode) GetTimeStamp() int64 {
return n.Timestamp
}
// GetRoutes: returns the routes that the nodes provides
func (n *MeshNode) GetRoutes() []mesh.Route {
routes := make([]mesh.Route, len(n.Routes))
for index, route := range lib.MapValues(n.Routes) {
routes[index] = &Route{
Destination: route.Destination,
Path: route.Path,
}
}
return routes
}
// GetIdentifier: returns the identifier of the node
func (m *MeshNode) GetIdentifier() string {
ipv6 := m.WgHost[:len(m.WgHost)-4]
constituents := strings.Split(ipv6, ":")
constituents = constituents[4:]
return strings.Join(constituents, ":")
}
// GetDescription: returns the description for this node
func (n *MeshNode) GetDescription() string {
return n.Description
}
// GetAlias: associates the node with an alias. Potentially used
// for DNS and so forth.
func (n *MeshNode) GetAlias() string {
return n.Alias
}
// GetServices: returns a list of services offered by the node
func (n *MeshNode) GetServices() map[string]string {
return n.Services
}
func (n *MeshNode) GetType() conf.NodeType {
return conf.NodeType(n.Type)
}
type MeshSnapshot struct {
Nodes map[string]MeshNode
}
// GetNodes() returns the nodes in the mesh
func (m *MeshSnapshot) GetNodes() map[string]mesh.MeshNode {
newMap := make(map[string]mesh.MeshNode)
for key, value := range m.Nodes {
newMap[key] = &MeshNode{
HostEndpoint: value.HostEndpoint,
PublicKey: value.PublicKey,
WgHost: value.WgHost,
WgEndpoint: value.WgEndpoint,
Timestamp: value.Timestamp,
Routes: value.Routes,
Alias: value.Alias,
Description: value.Description,
Services: value.Services,
Type: value.Type,
}
}
return newMap
}
type TwoPhaseStoreMeshManager struct {
MeshId string
IfName string
Client *wgctrl.Client
LastClock uint64
Conf *conf.WgConfiguration
DaemonConf *conf.DaemonConfiguration
store *TwoPhaseMap[string, MeshNode]
}
// AddNode() adds a node to the mesh
func (m *TwoPhaseStoreMeshManager) AddNode(node mesh.MeshNode) {
crdt, ok := node.(*MeshNode)
if !ok {
panic("node must be of type mesh node")
}
crdt.Routes = make(map[string]Route)
crdt.Services = make(map[string]string)
crdt.Timestamp = time.Now().Unix()
m.store.Put(crdt.PublicKey, *crdt)
}
// GetMesh() returns a snapshot of the mesh provided by the mesh provider.
func (m *TwoPhaseStoreMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
nodes := m.store.AsList()
snapshot := make(map[string]MeshNode)
for _, node := range nodes {
snapshot[node.PublicKey] = node
}
return &MeshSnapshot{
Nodes: snapshot,
}, nil
}
// GetMeshId() returns the ID of the mesh network
func (m *TwoPhaseStoreMeshManager) GetMeshId() string {
return m.MeshId
}
// Save() saves the mesh network
func (m *TwoPhaseStoreMeshManager) Save() []byte {
snapshot := m.store.Snapshot()
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
err := enc.Encode(*snapshot)
if err != nil {
logging.Log.WriteInfof(err.Error())
}
return buf.Bytes()
}
// Load() loads a mesh network
func (m *TwoPhaseStoreMeshManager) Load(bs []byte) error {
buf := bytes.NewBuffer(bs)
dec := gob.NewDecoder(buf)
var snapshot TwoPhaseMapSnapshot[string, MeshNode]
err := dec.Decode(&snapshot)
m.store.Merge(snapshot)
return err
}
// GetDevice() get the device corresponding with the mesh
func (m *TwoPhaseStoreMeshManager) GetDevice() (*wgtypes.Device, error) {
dev, err := m.Client.Device(m.IfName)
if err != nil {
return nil, err
}
return dev, nil
}
// HasChanges returns true if we have changes since last time we synced
func (m *TwoPhaseStoreMeshManager) HasChanges() bool {
clockValue := m.store.GetHash()
return clockValue != m.LastClock
}
// Record that we have changes and save the corresponding changes
func (m *TwoPhaseStoreMeshManager) SaveChanges() {
clockValue := m.store.GetHash()
m.LastClock = clockValue
}
// UpdateTimeStamp: update the timestamp of the given node, causes a configuration refresh if the node
// is the leader causing all nodes to update their vector clocks
func (m *TwoPhaseStoreMeshManager) UpdateTimeStamp(nodeId string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
// Sort nodes by their public key
peers := m.GetPeers()
slices.Sort(peers)
if len(peers) == 0 {
return nil
}
peerToUpdate := peers[0]
if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.DaemonConf.Heartbeat) {
m.store.Mark(peerToUpdate)
if len(peers) < 2 {
return nil
}
peerToUpdate = peers[1]
}
if peerToUpdate != nodeId {
return nil
}
// Refresh causing node to update it's time stamp
node := m.store.Get(nodeId)
node.Timestamp = time.Now().Unix()
m.store.Put(nodeId, node)
return nil
}
// AddRoutes: adds routes to the given node
func (m *TwoPhaseStoreMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
if len(routes) == 0 {
return nil
}
node := m.store.Get(nodeId)
changes := false
for _, route := range routes {
prevRoute, ok := node.Routes[route.GetDestination().String()]
if !ok || route.GetHopCount() < prevRoute.GetHopCount() {
changes = true
node.Routes[route.GetDestination().String()] = Route{
Destination: route.GetDestination().String(),
Path: route.GetPath(),
}
}
}
// Only add nodes on changes. Otherwise the node will advertise new
// information whenever they get new routes
if changes {
m.store.Put(nodeId, node)
}
return nil
}
// RemoveRoute: deletes the routes from the given node
func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
if len(routes) == 0 {
return nil
}
node := m.store.Get(nodeId)
changes := false
for _, route := range routes {
changes = true
logging.Log.WriteInfof("deleting: %s", route.GetDestination().String())
delete(node.Routes, route.GetDestination().String())
}
if changes {
m.store.Put(nodeId, node)
}
return nil
}
// GetSyncer: returns the bi-directionally synchroniser to merge documents
func (m *TwoPhaseStoreMeshManager) GetSyncer() mesh.MeshSyncer {
return NewTwoPhaseSyncer(m)
}
// GetNode: get a particular not within the mesh network
func (m *TwoPhaseStoreMeshManager) GetNode(nodeId string) (mesh.MeshNode, error) {
if !m.store.Contains(nodeId) {
return nil, fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
node := m.store.Get(nodeId)
return &node, nil
}
// NodeExists: returns true if a particular node exists false otherwise
func (m *TwoPhaseStoreMeshManager) NodeExists(nodeId string) bool {
return m.store.Contains(nodeId)
}
// SetDescription: sets the description of this automerge data type
func (m *TwoPhaseStoreMeshManager) SetDescription(nodeId string, description string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
node := m.store.Get(nodeId)
node.Description = description
m.store.Put(nodeId, node)
return nil
}
// SetAlias: set the alias of the given node
func (m *TwoPhaseStoreMeshManager) SetAlias(nodeId string, alias string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
node := m.store.Get(nodeId)
node.Alias = alias
m.store.Put(nodeId, node)
return nil
}
// AddService: adds a service to the given node
func (m *TwoPhaseStoreMeshManager) AddService(nodeId string, key string, value string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
node := m.store.Get(nodeId)
node.Services[key] = value
m.store.Put(nodeId, node)
return nil
}
// RemoveService: removes the service form a node, throws an error if the service does not exist
func (m *TwoPhaseStoreMeshManager) RemoveService(nodeId string, key string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
node := m.store.Get(nodeId)
if _, ok := node.Services[key]; !ok {
return fmt.Errorf("datastore: node does not contain service %s", key)
}
delete(node.Services, key)
m.store.Put(nodeId, node)
return nil
}
// Prune: prunes all nodes that have not updated their vector clock in a given amount
// of time
func (m *TwoPhaseStoreMeshManager) Prune() error {
m.store.Prune()
return nil
}
// GetPeers: get a list of contactable peers
func (m *TwoPhaseStoreMeshManager) GetPeers() []string {
nodes := m.store.AsList()
nodes = lib.Filter(nodes, func(mn MeshNode) bool {
if mn.Type != string(conf.PEER_ROLE) {
return false
}
// If the node is marked as unreachable don't consider it a peer.
// this help to optimize convergence time for unreachable nodes.
// However advertising it to other nodes could result in flapping.
if m.store.IsMarked(mn.PublicKey) {
return false
}
return true
})
return lib.Map(nodes, func(mn MeshNode) string {
return mn.PublicKey
})
}
// getRoutes: get all routes the target node is advertising
func (m *TwoPhaseStoreMeshManager) getRoutes(targetNode string) (map[string]Route, error) {
if !m.store.Contains(targetNode) {
return nil, fmt.Errorf("getRoute: cannot get route %s does not exist", targetNode)
}
node := m.store.Get(targetNode)
return node.Routes, nil
}
// GetRoutes: Get all unique routes the target node is advertising.
// on conflicts the route with the least hop count is chosen
func (m *TwoPhaseStoreMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, error) {
node, err := m.GetNode(targetNode)
if err != nil {
return nil, err
}
routes := make(map[string]mesh.Route)
// Add routes that the node directly has
for _, route := range node.GetRoutes() {
routes[route.GetDestination().String()] = route
}
// Work out the other routes in the mesh
for _, node := range m.GetPeers() {
nodeRoutes, err := m.getRoutes(node)
if err != nil {
return nil, err
}
for _, route := range nodeRoutes {
otherRoute, ok := routes[route.GetDestination().String()]
hopCount := route.GetHopCount()
if node != targetNode {
hopCount += 1
}
if !ok || route.GetHopCount()+1 < otherRoute.GetHopCount() {
routes[route.GetDestination().String()] = &Route{
Destination: route.GetDestination().String(),
Path: append(route.GetPath(), m.GetMeshId()),
}
}
}
}
return routes, nil
}
// RemoveNode: remove the node from the mesh
func (m *TwoPhaseStoreMeshManager) RemoveNode(nodeId string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
m.store.Remove(nodeId)
return nil
}
// GetConfiguration gets the WireGuard configuration to use for this
// network
func (m *TwoPhaseStoreMeshManager) GetConfiguration() *conf.WgConfiguration {
return m.Conf
}

440
pkg/crdt/datastore_test.go Normal file
View File

@ -0,0 +1,440 @@
package crdt
import (
"net"
"slices"
"testing"
"time"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type TestParams struct {
manager mesh.MeshProvider
publicKey *wgtypes.Key
}
func setUpTests() *TestParams {
advertiseRoutes := false
advertiseDefaultRoute := false
role := conf.PEER_ROLE
discovery := conf.OUTGOING_IP_DISCOVERY
factory := &TwoPhaseMapFactory{
Config: &conf.DaemonConfiguration{
CertificatePath: "/somecertificatepath",
PrivateKeyPath: "/someprivatekeypath",
CaCertificatePath: "/somecacertificatepath",
SkipCertVerification: true,
GrpcPort: 0,
Timeout: 20,
Profile: false,
SyncInterval: 2,
Heartbeat: 10,
ClusterSize: 32,
InterClusterChance: 0.15,
Branch: 3,
InfectionCount: 3,
BaseConfiguration: conf.WgConfiguration{
IPDiscovery: &discovery,
AdvertiseRoutes: &advertiseRoutes,
AdvertiseDefaultRoute: &advertiseDefaultRoute,
Role: &role,
},
},
}
key, _ := wgtypes.GeneratePrivateKey()
mesh, _ := factory.CreateMesh(&mesh.MeshProviderFactoryParams{
DevName: "bob",
MeshId: "meshid123",
Client: nil,
Conf: &factory.Config.BaseConfiguration,
DaemonConf: factory.Config,
NodeID: "bob",
})
publicKey := key.PublicKey()
return &TestParams{
manager: mesh,
publicKey: &publicKey,
}
}
func getOurNode(testParams *TestParams) *MeshNode {
return &MeshNode{
HostEndpoint: "public-endpoint:8080",
WgEndpoint: "public-endpoint:21906",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d218/128",
PublicKey: testParams.publicKey.String(),
Timestamp: time.Now().Unix(),
Description: "A node that we are adding",
Type: "peer",
}
}
func getRandomNode() *MeshNode {
key, _ := wgtypes.GeneratePrivateKey()
publicKey := key.PublicKey()
return &MeshNode{
HostEndpoint: "public-endpoint:8081",
WgEndpoint: "public-endpoint:21907",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d234/128",
PublicKey: publicKey.String(),
Timestamp: time.Now().Unix(),
Description: "A node that we are adding",
Type: "peer",
}
}
func TestAddNodeAddsTheNodesToTheStore(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
if !testParams.manager.NodeExists(testParams.publicKey.String()) {
t.Fatalf(`node %s should have been added to the mesh network`, testParams.publicKey.String())
}
}
func TestAddNodeNodeAlreadyExistsReplacesTheNode(t *testing.T) {
TestAddNodeAddsTheNodesToTheStore(t)
TestAddNodeAddsTheNodesToTheStore(t)
}
func TestSaveThenLoad(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
bytes := testParams.manager.Save()
if err := testParams.manager.Load(bytes); err != nil {
t.Fatalf(`error caused by loading datastore: %s`, err.Error())
}
}
func TestHasChangesReturnsTrueWhenThereAreChangesInTheMesh(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
if !testParams.manager.HasChanges() {
t.Fatalf(`mesh has change but HasChanges returned false`)
}
testParams.manager.SetDescription(testParams.publicKey.String(), "Bob marley")
if !testParams.manager.HasChanges() {
t.Fatalf(`mesh has change but HasChanges returned false`)
}
testParams.manager.SaveChanges()
}
func TestHasChangesWhenThereAreNoChangesInTheMeshReturnsFalse(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.SaveChanges()
if testParams.manager.HasChanges() {
t.Fatalf(`mesh has no changes but HasChanges was true`)
}
testParams.manager.SetDescription(testParams.publicKey.String(), "Bob marley")
testParams.manager.SaveChanges()
if testParams.manager.HasChanges() {
t.Fatalf(`mesh has no changes but HasChanges was true`)
}
}
func TestUpdateTimeStampUpdatesTheTimeStampOfTheGivenNodeIfItIsTheLeader(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
before, _ := testParams.manager.GetNode(testParams.publicKey.String())
time.Sleep(1 * time.Second)
testParams.manager.UpdateTimeStamp(testParams.publicKey.String())
after, _ := testParams.manager.GetNode(testParams.publicKey.String())
if before.GetTimeStamp() >= after.GetTimeStamp() {
t.Fatalf(`before should not be after after`)
}
}
func TestUpdateTimeStampUpdatesTheTimeStampOfTheGivenNodeIfItIsNotLeader(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
newNode := getRandomNode()
newNode.PublicKey = "aaaaaaaaaa"
testParams.manager.AddNode(newNode)
before, _ := testParams.manager.GetNode(testParams.publicKey.String())
time.Sleep(1 * time.Second)
after, _ := testParams.manager.GetNode(testParams.publicKey.String())
if before.GetTimeStamp() != after.GetTimeStamp() {
t.Fatalf(`before and after should be the same`)
}
}
func TestAddRoutesAddsARouteToTheGivenMesh(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
_, destination, _ := net.ParseCIDR("0353:1da7:7f33:acc0:7a3f:6e55:912b:bc1f/64")
testParams.manager.AddRoutes(testParams.publicKey.String(), &mesh.RouteStub{
Destination: destination,
Path: make([]string, 0),
})
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
containsDestination := lib.Contains(node.GetRoutes(), func(r mesh.Route) bool {
return r.GetDestination().Contains(destination.IP)
})
if !containsDestination {
t.Fatalf(`route has not been added to the node`)
}
}
func TestRemoveRoutesWithdrawsRoutesFromTheMesh(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
_, destination, _ := net.ParseCIDR("0353:1da7:7f33:acc0:7a3f:6e55:912b:bc1f/64")
route := &mesh.RouteStub{
Destination: destination,
Path: make([]string, 0),
}
testParams.manager.AddRoutes(testParams.publicKey.String(), route)
testParams.manager.RemoveRoutes(testParams.publicKey.String(), route)
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
containsDestination := lib.Contains(node.GetRoutes(), func(r mesh.Route) bool {
return r.GetDestination().Contains(destination.IP)
})
if containsDestination {
t.Fatalf(`route has not been removed from the node`)
}
}
func TestGetNodeGetsTheNodeWhenItExists(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
if node == nil {
t.Fatalf(`node not found returned nil`)
}
}
func TestGetNodeReturnsNilWhenItDoesNotExist(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.RemoveNode(testParams.publicKey.String())
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
if node != nil {
t.Fatalf(`node found but should be nil`)
}
}
func TestNodeExistsReturnsFalseWhenNotExists(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.RemoveNode(testParams.publicKey.String())
if testParams.manager.NodeExists(testParams.publicKey.String()) {
t.Fatalf(`nodeexists should be false`)
}
}
func TestSetDescriptionReturnsErrorWhenNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.SetDescription("djdjdj", "djdsjkd")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestSetDescriptionSetsTheDescription(t *testing.T) {
testParams := setUpTests()
descriptionToSet := "djdsjkd"
testParams.manager.AddNode(getOurNode(testParams))
err := testParams.manager.SetDescription(testParams.publicKey.String(), descriptionToSet)
if err != nil {
t.Fatalf(`error %s thrown`, err.Error())
}
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
description := node.GetDescription()
if description != descriptionToSet {
t.Fatalf(`description was %s should be %s`, description, descriptionToSet)
}
}
func TestAliasNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.SetAlias("djdjdj", "djdsjkd")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestSetAliasSetsAlias(t *testing.T) {
testParams := setUpTests()
aliasToSet := "djdsjkd"
testParams.manager.AddNode(getOurNode(testParams))
err := testParams.manager.SetAlias(testParams.publicKey.String(), aliasToSet)
if err != nil {
t.Fatalf(`error %s thrown`, err.Error())
}
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
alias := node.GetAlias()
if alias != aliasToSet {
t.Fatalf(`description was %s should be %s`, alias, aliasToSet)
}
}
func TestAddServiceNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.AddService("djdjdj", "djdsjkd", "sddsds")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestAddServiceNodeExists(t *testing.T) {
testParams := setUpTests()
service := "djdsjkd"
serviceValue := "dsdsds"
testParams.manager.AddNode(getOurNode(testParams))
err := testParams.manager.AddService(testParams.publicKey.String(), service, serviceValue)
if err != nil {
t.Fatalf(`error %s thrown`, err.Error())
}
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
services := node.GetServices()
if value, ok := services[service]; !ok || value != serviceValue {
t.Fatalf(`service not added to the data store`)
}
}
func TestRemoveServiceDoesNotExists(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.RemoveService("djdjdj", "dsdssd")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestRemoveServiceServiceDoesNotExist(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
if err := testParams.manager.RemoveService(testParams.publicKey.String(), "dhsdh"); err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestGetPeersReturnsAllPeersInTheMesh(t *testing.T) {
testParams := setUpTests()
peer1 := getRandomNode()
peer2 := getRandomNode()
client := getRandomNode()
client.Type = "client"
testParams.manager.AddNode(peer1)
testParams.manager.AddNode(peer2)
testParams.manager.AddNode(client)
peers := testParams.manager.GetPeers()
slices.Sort(peers)
if len(peers) != 2 {
t.Fatalf(`there should be two peers in the mesh`)
}
peer1Pub, _ := peer1.GetPublicKey()
if !slices.Contains(peers, peer1Pub.String()) {
t.Fatalf(`peer1 not in the list`)
}
peer2Pub, _ := peer2.GetPublicKey()
if !slices.Contains(peers, peer2Pub.String()) {
t.Fatalf(`peer2 not in the list`)
}
}
func TestRemoveNodeReturnsErrorIfNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.RemoveNode("dsjdssjk")
if err == nil {
t.Fatalf(`error should have returned`)
}
}

88
pkg/crdt/factory.go Normal file
View File

@ -0,0 +1,88 @@
package crdt
import (
"fmt"
"hash/fnv"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/mesh"
)
// TwoPhaseMapFactory: instantiate a new twophasemap
// datastore
type TwoPhaseMapFactory struct {
Config *conf.DaemonConfiguration
}
// CreateMesh: create a new mesh network
func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) {
return &TwoPhaseStoreMeshManager{
MeshId: params.MeshId,
IfName: params.DevName,
Client: params.Client,
Conf: params.Conf,
DaemonConf: params.DaemonConf,
store: NewTwoPhaseMap[string, MeshNode](params.NodeID, func(s string) uint64 {
h := fnv.New64a()
h.Write([]byte(s))
return h.Sum64()
}, uint64(3*f.Config.Heartbeat)),
}, nil
}
// MeshNodeFactory: create a new node in the mesh network
type MeshNodeFactory struct {
Config conf.DaemonConfiguration
}
// Build: build a new mesh network
func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode {
hostName := f.getAddress(params)
grpcEndpoint := fmt.Sprintf("%s:%d", hostName, f.Config.GrpcPort)
wgEndpoint := fmt.Sprintf("%s:%d", hostName, params.WgPort)
if *params.MeshConfig.Role == conf.CLIENT_ROLE {
grpcEndpoint = "-"
wgEndpoint = "-"
}
return &MeshNode{
HostEndpoint: grpcEndpoint,
PublicKey: params.PublicKey.String(),
WgEndpoint: wgEndpoint,
WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()),
Routes: make(map[string]Route),
Description: "",
Alias: "",
Type: string(*params.MeshConfig.Role),
}
}
// getAddress returns the routable address of the machine.
func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string {
var hostName string = ""
if params.Endpoint != "" {
hostName = params.Endpoint
} else if params.MeshConfig.Endpoint != nil && len(*params.MeshConfig.Endpoint) != 0 {
hostName = *params.MeshConfig.Endpoint
} else {
ipFunc := lib.GetPublicIP
if *params.MeshConfig.IPDiscovery == conf.OUTGOING_IP_DISCOVERY {
ipFunc = lib.GetOutboundIP
}
ip, err := ipFunc()
if err != nil {
return ""
}
hostName = ip.String()
}
return hostName
}

197
pkg/crdt/g_map.go Normal file
View File

@ -0,0 +1,197 @@
// crdt provides go implementations for crdts
package crdt
import (
"cmp"
"sync"
)
// Bucket: bucket represents a value in the grow only map
type Bucket[D any] struct {
Vector uint64
Contents D
Gravestone bool
}
// GMap is a set that can only grow in size
type GMap[K cmp.Ordered, D any] struct {
lock sync.RWMutex
contents map[uint64]Bucket[D]
clock *VectorClock[K]
}
// Put: put a new entry in the grow-only-map
func (g *GMap[K, D]) Put(key K, value D) {
g.lock.Lock()
clock := g.clock.IncrementClock()
g.contents[g.clock.hashFunc(key)] = Bucket[D]{
Vector: clock,
Contents: value,
}
g.lock.Unlock()
}
// Contains: returns whether or not the key is contained
// in the g-map
func (g *GMap[K, D]) Contains(key K) bool {
return g.contains(g.clock.hashFunc(key))
}
func (g *GMap[K, D]) contains(key uint64) bool {
g.lock.RLock()
_, ok := g.contents[key]
g.lock.RUnlock()
return ok
}
func (g *GMap[K, D]) put(key uint64, b Bucket[D]) {
g.lock.Lock()
if g.contents[key].Vector < b.Vector {
g.contents[key] = b
}
g.lock.Unlock()
}
func (g *GMap[K, D]) get(key uint64) Bucket[D] {
g.lock.RLock()
bucket := g.contents[key]
g.lock.RUnlock()
return bucket
}
// Get: get the value associated with the given key
func (g *GMap[K, D]) Get(key K) D {
if !g.Contains(key) {
var def D
return def
}
return g.get(g.clock.hashFunc(key)).Contents
}
// Mark: marks the node, this means the status of the node
// is an undefined state
func (g *GMap[K, D]) Mark(key K) {
if !g.Contains(key) {
return
}
g.lock.Lock()
bucket := g.contents[g.clock.hashFunc(key)]
bucket.Gravestone = true
g.contents[g.clock.hashFunc(key)] = bucket
g.lock.Unlock()
}
// IsMarked: returns true if the node is marked (in an undefined state)
func (g *GMap[K, D]) IsMarked(key K) bool {
marked := false
g.lock.RLock()
bucket, ok := g.contents[g.clock.hashFunc(key)]
if ok {
marked = bucket.Gravestone
}
g.lock.RUnlock()
return marked
}
// Keys: return all the keys in the grow-only map
func (g *GMap[K, D]) Keys() []uint64 {
g.lock.RLock()
contents := make([]uint64, len(g.contents))
index := 0
for key := range g.contents {
contents[index] = key
index++
}
g.lock.RUnlock()
return contents
}
// Save: saves the grow only map
func (g *GMap[K, D]) Save() map[uint64]Bucket[D] {
buckets := make(map[uint64]Bucket[D])
g.lock.RLock()
for key, value := range g.contents {
buckets[key] = value
}
g.lock.RUnlock()
return buckets
}
// SaveWithKeys: get all the values corresponding with the provided keys
func (g *GMap[K, D]) SaveWithKeys(keys []uint64) map[uint64]Bucket[D] {
buckets := make(map[uint64]Bucket[D])
g.lock.RLock()
for _, key := range keys {
buckets[key] = g.contents[key]
}
g.lock.RUnlock()
return buckets
}
// GetClock: get all the vector clocks in the g_map
func (g *GMap[K, D]) GetClock() map[uint64]uint64 {
clock := make(map[uint64]uint64)
g.lock.RLock()
for key, bucket := range g.contents {
clock[key] = bucket.Vector
}
g.lock.RUnlock()
return clock
}
// GetHash: get the hash of the g_map representing its state
func (g *GMap[K, D]) GetHash() uint64 {
hash := uint64(0)
g.lock.RLock()
for _, value := range g.contents {
hash += value.Vector
}
g.lock.RUnlock()
return hash
}
// Prune: prune all stale entries
func (g *GMap[K, D]) Prune() {
stale := g.clock.getStale()
g.lock.Lock()
for _, outlier := range stale {
delete(g.contents, outlier)
}
g.lock.Unlock()
}
func NewGMap[K cmp.Ordered, D any](clock *VectorClock[K]) *GMap[K, D] {
return &GMap[K, D]{
contents: make(map[uint64]Bucket[D]),
clock: clock,
}
}

224
pkg/crdt/g_map_test.go Normal file
View File

@ -0,0 +1,224 @@
// crdt_test unit tests the crdt implementations
package crdt
import (
"hash/fnv"
"slices"
"testing"
"time"
"github.com/tim-beatham/smegmesh/pkg/lib"
)
func NewGmap() *GMap[string, bool] {
vectorClock := NewVectorClock("a", func(key string) uint64 {
hash := fnv.New64a()
hash.Write([]byte(key))
return hash.Sum64()
}, 1) // 1 second stale time
gMap := NewGMap[string, bool](vectorClock)
return gMap
}
func TestGMapPutInsertsItems(t *testing.T) {
gMap := NewGmap()
gMap.Put("bruh1234", true)
if !gMap.Contains("bruh1234") {
t.Fatalf(`value not added to map`)
}
}
func TestGMapPutReplacesItems(t *testing.T) {
gMap := NewGmap()
gMap.Put("bruh1234", true)
gMap.Put("bruh1234", false)
value := gMap.Get("bruh1234")
if value {
t.Fatalf(`value should ahve been replaced to false`)
}
}
func TestContainsValueNotPresent(t *testing.T) {
gMap := NewGmap()
if gMap.Contains("sdhjsdhsdj") {
t.Fatalf(`value should not be present in the map`)
}
}
func TestContainsValuePresent(t *testing.T) {
gMap := NewGmap()
key := "hehehehe"
gMap.Put(key, false)
if !gMap.Contains(key) {
t.Fatalf(`%s should not be present in the map`, key)
}
}
func TestGMapGetNotPresentReturnsError(t *testing.T) {
gMap := NewGmap()
value := gMap.Get("bruh123")
if value != false {
t.Fatalf(`value should be default type false`)
}
}
func TestGMapGetReturnsValue(t *testing.T) {
gMap := NewGmap()
gMap.Put("bobdylan", true)
value := gMap.Get("bobdylan")
if !value {
t.Fatalf("value should be true but was false")
}
}
func TestMarkMarksTheValue(t *testing.T) {
gMap := NewGmap()
gMap.Put("hello123", true)
gMap.Mark("hello123")
if !gMap.IsMarked("hello123") {
t.Fatal(`hello123 should be marked`)
}
}
func TestMarkValueNotPresent(t *testing.T) {
gMap := NewGmap()
gMap.Mark("ok123456")
}
func TestKeysMapEmpty(t *testing.T) {
gMap := NewGmap()
keys := gMap.Keys()
if len(keys) != 0 {
t.Fatal(`list of keys was not empty but should be empty`)
}
}
func TestKeysMapReturnsKeysInMap(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
keys := gMap.Keys()
if len(keys) != 3 {
t.Fatal(`key length should be 3`)
}
}
func TestSaveMapEmptyReturnsEmptyMap(t *testing.T) {
gMap := NewGmap()
saveMap := gMap.Save()
if len(saveMap) != 0 {
t.Fatal(`saves should be empty`)
}
}
func TestSaveMapReturnsMapOfBuckets(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
saveMap := gMap.Save()
if len(saveMap) != 3 {
t.Fatalf(`save length should be 3`)
}
}
func TestSaveWithKeysNoKeysReturnsEmptyBucket(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
saveMap := gMap.SaveWithKeys([]uint64{})
if len(saveMap) != 0 {
t.Fatalf(`save map should be empty`)
}
}
func TestSaveWithKeysReturnsIntersection(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
clock := lib.MapKeys(gMap.GetClock())
clock = clock[:len(clock)-1]
values := gMap.SaveWithKeys(clock)
if len(values) != len(clock) {
t.Fatalf(`intersection not returned`)
}
}
func TestGetClockMapEmptyReturnsEmptyClock(t *testing.T) {
gMap := NewGmap()
clocks := gMap.GetClock()
if len(clocks) != 0 {
t.Fatalf(`vector clock is not empty`)
}
}
func TestGetClockReturnsAllCLocks(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
clocks := lib.MapValues(gMap.GetClock())
slices.Sort(clocks)
if !slices.Equal([]uint64{0, 1, 2}, clocks) {
t.Fatalf(`clocks are invalid`)
}
}
func TestGetHashChangesHashOnValueAdded(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
prevHash := gMap.GetHash()
gMap.Put("b", true)
if prevHash == gMap.GetHash() {
t.Fatalf(`hash should be different`)
}
}
func TestPruneGarbageCollectsValuesThatHaveNotBeenUpdated(t *testing.T) {
gMap := NewGmap()
gMap.clock.Put("c", 12)
gMap.Put("c", false)
gMap.Put("a", false)
time.Sleep(4 * time.Second)
gMap.Put("a", true)
gMap.Prune()
if gMap.Contains("c") {
t.Fatalf(`a should have been pruned`)
}
}

229
pkg/crdt/two_phase_map.go Normal file
View File

@ -0,0 +1,229 @@
package crdt
import (
"cmp"
"github.com/tim-beatham/smegmesh/pkg/lib"
)
// TwoPhaseMap: comprises of two grow-only maps
type TwoPhaseMap[K cmp.Ordered, D any] struct {
addMap *GMap[K, D]
removeMap *GMap[K, bool]
Clock *VectorClock[K]
processId K
}
type TwoPhaseMapSnapshot[K cmp.Ordered, D any] struct {
Add map[uint64]Bucket[D]
Remove map[uint64]Bucket[bool]
}
// Contains checks whether the value exists in the map
func (m *TwoPhaseMap[K, D]) Contains(key K) bool {
return m.contains(m.Clock.hashFunc(key))
}
// contains: checks whether the key exists in the map
func (m *TwoPhaseMap[K, D]) contains(key uint64) bool {
if !m.addMap.contains(key) {
return false
}
addValue := m.addMap.get(key)
if !m.removeMap.contains(key) {
return true
}
removeValue := m.removeMap.get(key)
return addValue.Vector >= removeValue.Vector
}
// Get: get the value corresponding with the given key
func (m *TwoPhaseMap[K, D]) Get(key K) D {
var result D
if !m.Contains(key) {
return result
}
return m.addMap.Get(key)
}
func (m *TwoPhaseMap[K, D]) get(key uint64) D {
var result D
if !m.contains(key) {
return result
}
return m.addMap.get(key).Contents
}
// Put: places the key K in the map with the associated data D
func (m *TwoPhaseMap[K, D]) Put(key K, data D) {
msgSequence := m.Clock.IncrementClock()
m.Clock.Put(key, msgSequence)
m.addMap.Put(key, data)
}
// Mark: marks the status of the node as undetermiend
func (m *TwoPhaseMap[K, D]) Mark(key K) {
m.addMap.Mark(key)
}
// Remove: removes the value from the map
func (m *TwoPhaseMap[K, D]) Remove(key K) {
m.removeMap.Put(key, true)
}
func (m *TwoPhaseMap[K, D]) keys() []uint64 {
keys := make([]uint64, 0)
addKeys := m.addMap.Keys()
for _, key := range addKeys {
if !m.contains(key) {
continue
}
keys = append(keys, key)
}
return keys
}
// AsList: convert the map to a list
func (m *TwoPhaseMap[K, D]) AsList() []D {
theList := make([]D, 0)
keys := m.keys()
for _, key := range keys {
theList = append(theList, m.get(key))
}
return theList
}
// Snapshot: convert the map into an immutable snapshot.
// contains the contents of the add and remove map
func (m *TwoPhaseMap[K, D]) Snapshot() *TwoPhaseMapSnapshot[K, D] {
return &TwoPhaseMapSnapshot[K, D]{
Add: m.addMap.Save(),
Remove: m.removeMap.Save(),
}
}
// SnapshotFromState: create a snapshot of the intersection of values provided
// in the given state
func (m *TwoPhaseMap[K, D]) SnapShotFromState(state *TwoPhaseMapState[K]) *TwoPhaseMapSnapshot[K, D] {
addKeys := lib.MapKeys(state.AddContents)
removeKeys := lib.MapKeys(state.RemoveContents)
return &TwoPhaseMapSnapshot[K, D]{
Add: m.addMap.SaveWithKeys(addKeys),
Remove: m.removeMap.SaveWithKeys(removeKeys),
}
}
// TwoPhaseMapState: encapsulates the state of the map
// without specifying the data that is stored
type TwoPhaseMapState[K cmp.Ordered] struct {
// Vectors: the vector ID of each process
Vectors map[uint64]uint64
// AddContents: the contents of the add map
AddContents map[uint64]uint64
// RemoveContents: the contents of the remove map
RemoveContents map[uint64]uint64
}
// IsMarked: returns true if the given value is marked in an undetermined state
func (m *TwoPhaseMap[K, D]) IsMarked(key K) bool {
return m.addMap.IsMarked(key)
}
// GetHash: Get the hash of the current state of the map
// Sums the current values of the vectors. Provides good approximation
// of increasing numbers
func (m *TwoPhaseMap[K, D]) GetHash() uint64 {
return (m.addMap.GetHash() + 1) * (m.removeMap.GetHash() + 1)
}
// GetState: get the current vector clock of the add and remove
// map
func (m *TwoPhaseMap[K, D]) GenerateMessage() *TwoPhaseMapState[K] {
addContents := m.addMap.GetClock()
removeContents := m.removeMap.GetClock()
return &TwoPhaseMapState[K]{
Vectors: m.Clock.GetClock(),
AddContents: addContents,
RemoveContents: removeContents,
}
}
// Difference: compute the set difference between the two states.
// highestStale represents the highest vector clock that has been marked as stale
func (m *TwoPhaseMapState[K]) Difference(highestStale uint64, state *TwoPhaseMapState[K]) *TwoPhaseMapState[K] {
mapState := &TwoPhaseMapState[K]{
AddContents: make(map[uint64]uint64),
RemoveContents: make(map[uint64]uint64),
}
for key, value := range state.AddContents {
otherValue, ok := m.AddContents[key]
if value > highestStale && (!ok || otherValue < value) {
mapState.AddContents[key] = value
}
}
for key, value := range state.RemoveContents {
otherValue, ok := m.RemoveContents[key]
if value > highestStale && (!ok || otherValue < value) {
mapState.RemoveContents[key] = value
}
}
return mapState
}
// Merge: merge a snapshot into the map
func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) {
for key, value := range snapshot.Add {
// Gravestone is local only to that node.
// Discover ourselves if the node is alive
m.addMap.put(key, value)
m.Clock.put(key, value.Vector)
}
for key, value := range snapshot.Remove {
m.removeMap.put(key, value)
m.Clock.put(key, value.Vector)
}
}
// Prune: garbage collect all stale entries in the map
func (m *TwoPhaseMap[K, D]) Prune() {
m.addMap.Prune()
m.removeMap.Prune()
m.Clock.Prune()
}
// NewTwoPhaseMap: create a new two phase map. Consists of two maps
// a grow map and a remove map. If both timestamps equal then favour keeping
// it in the map
func NewTwoPhaseMap[K cmp.Ordered, D any](processId K, hashKey func(K) uint64, staleTime uint64) *TwoPhaseMap[K, D] {
m := TwoPhaseMap[K, D]{
processId: processId,
Clock: NewVectorClock(processId, hashKey, staleTime),
}
m.addMap = NewGMap[K, D](m.Clock)
m.removeMap = NewGMap[K, bool](m.Clock)
return &m
}

View File

@ -0,0 +1,190 @@
package crdt
import (
"bytes"
"encoding/gob"
logging "github.com/tim-beatham/smegmesh/pkg/log"
)
type SyncState int
const (
HASH SyncState = iota
PREPARE
PRESENT
EXCHANGE
MERGE
FINISHED
)
// TwoPhaseSyncer is a type to sync a TwoPhase data store
type TwoPhaseSyncer struct {
manager *TwoPhaseStoreMeshManager
generateMessageFSM SyncFSM
state SyncState
mapState *TwoPhaseMapState[string]
peerMsg []byte
}
type TwoPhaseHash struct {
Hash uint64
}
type SyncFSM map[SyncState]func(*TwoPhaseSyncer) ([]byte, bool)
func hash(syncer *TwoPhaseSyncer) ([]byte, bool) {
hash := TwoPhaseHash{
Hash: syncer.manager.store.Clock.GetHash(),
}
var buffer bytes.Buffer
enc := gob.NewEncoder(&buffer)
err := enc.Encode(hash)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
syncer.IncrementState()
return buffer.Bytes(), true
}
func prepare(syncer *TwoPhaseSyncer) ([]byte, bool) {
var recvBuffer = bytes.NewBuffer(syncer.peerMsg)
dec := gob.NewDecoder(recvBuffer)
var hash TwoPhaseHash
err := dec.Decode(&hash)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
// If vector clocks are equal then no need to merge state
// Helps to reduce bandwidth by detecting early
if hash.Hash == syncer.manager.store.Clock.GetHash() {
return nil, false
}
// Increment the clock here so the clock gets
// distributed to everyone else in the mesh
syncer.manager.store.Clock.IncrementClock()
var buffer bytes.Buffer
enc := gob.NewEncoder(&buffer)
mapState := syncer.manager.store.GenerateMessage()
syncer.mapState = mapState
err = enc.Encode(*syncer.mapState)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
syncer.IncrementState()
return buffer.Bytes(), true
}
func present(syncer *TwoPhaseSyncer) ([]byte, bool) {
if syncer.peerMsg == nil {
panic("peer msg is nil")
}
var recvBuffer = bytes.NewBuffer(syncer.peerMsg)
dec := gob.NewDecoder(recvBuffer)
var mapState TwoPhaseMapState[string]
err := dec.Decode(&mapState)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
difference := syncer.mapState.Difference(syncer.manager.store.Clock.GetStaleCount(), &mapState)
syncer.manager.store.Clock.Merge(mapState.Vectors)
var sendBuffer bytes.Buffer
enc := gob.NewEncoder(&sendBuffer)
enc.Encode(*difference)
syncer.IncrementState()
return sendBuffer.Bytes(), true
}
func exchange(syncer *TwoPhaseSyncer) ([]byte, bool) {
if syncer.peerMsg == nil {
panic("peer msg is nil")
}
var recvBuffer = bytes.NewBuffer(syncer.peerMsg)
dec := gob.NewDecoder(recvBuffer)
var mapState TwoPhaseMapState[string]
dec.Decode(&mapState)
snapshot := syncer.manager.store.SnapShotFromState(&mapState)
var sendBuffer bytes.Buffer
enc := gob.NewEncoder(&sendBuffer)
enc.Encode(*snapshot)
syncer.IncrementState()
return sendBuffer.Bytes(), true
}
func merge(syncer *TwoPhaseSyncer) ([]byte, bool) {
if syncer.peerMsg == nil {
panic("peer msg is nil")
}
var recvBuffer = bytes.NewBuffer(syncer.peerMsg)
dec := gob.NewDecoder(recvBuffer)
var snapshot TwoPhaseMapSnapshot[string, MeshNode]
dec.Decode(&snapshot)
syncer.manager.store.Merge(snapshot)
return nil, false
}
func (t *TwoPhaseSyncer) IncrementState() {
t.state = min(t.state+1, FINISHED)
}
func (t *TwoPhaseSyncer) GenerateMessage() ([]byte, bool) {
fsmFunc, ok := t.generateMessageFSM[t.state]
if !ok {
panic("state not handled")
}
return fsmFunc(t)
}
func (t *TwoPhaseSyncer) RecvMessage(msg []byte) error {
t.peerMsg = msg
return nil
}
func (t *TwoPhaseSyncer) Complete() {
logging.Log.WriteInfof("SYNC COMPLETED")
}
func NewTwoPhaseSyncer(manager *TwoPhaseStoreMeshManager) *TwoPhaseSyncer {
var generateMessageFsm SyncFSM = SyncFSM{
HASH: hash,
PREPARE: prepare,
PRESENT: present,
EXCHANGE: exchange,
MERGE: merge,
}
return &TwoPhaseSyncer{
manager: manager,
state: HASH,
generateMessageFSM: generateMessageFsm,
}
}

View File

@ -0,0 +1,214 @@
package crdt
import (
"hash/fnv"
"slices"
"testing"
)
func NewMap(processId string) *TwoPhaseMap[string, string] {
theMap := NewTwoPhaseMap[string, string](processId, func(key string) uint64 {
hash := fnv.New64a()
hash.Write([]byte(key))
return hash.Sum64()
}, 1)
return theMap
}
func TestTwoPhaseMapEmpty(t *testing.T) {
theMap := NewMap("a")
if theMap.Contains("a") {
t.Fatalf(`a should not be present in the map`)
}
}
func TestTwoPhaseMapValuePresent(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "")
if !theMap.Contains("a") {
t.Fatalf(`should be present within the map`)
}
}
func TestTwoPhaseMapValueNotPresent(t *testing.T) {
theMap := NewMap("a")
theMap.Put("b", "")
if theMap.Contains("a") {
t.Fatalf(`a should not be present in the map`)
}
}
func TestTwoPhaseMapPutThenRemove(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "")
theMap.Remove("a")
if theMap.Contains("a") {
t.Fatalf(`a should not be present within the map`)
}
}
func TestTwoPhaseMapPutThenRemoveThenPut(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "")
theMap.Remove("a")
theMap.Put("a", "")
if !theMap.Contains("a") {
t.Fatalf(`a should be present within the map`)
}
}
func TestMarkMarksTheValueIn2PMap(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "")
theMap.Mark("a")
if !theMap.IsMarked("a") {
t.Fatalf(`a should be marked`)
}
}
func TestAsListReturnsItemsInList(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "bob")
theMap.Put("b", "dylan")
keys := theMap.AsList()
slices.Sort(keys)
if !slices.Equal([]string{"bob", "dylan"}, keys) {
t.Fatalf(`values should be bob, dylan`)
}
}
func TestSnapShotRemoveMapEmpty(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "bob")
theMap.Put("b", "dylan")
snapshot := theMap.Snapshot()
if len(snapshot.Add) != 2 {
t.Fatalf(`add values length should be 2`)
}
if len(snapshot.Remove) != 0 {
t.Fatalf(`remove map length should be 0`)
}
}
func TestSnapshotMapEmpty(t *testing.T) {
theMap := NewMap("a")
snapshot := theMap.Snapshot()
if len(snapshot.Add) != 0 || len(snapshot.Remove) != 0 {
t.Fatalf(`snapshot length should be 0`)
}
}
func TestSnapShotFromStateReturnsIntersection(t *testing.T) {
map1 := NewMap("a")
map1.Put("a", "heyy")
map2 := NewMap("b")
map2.Put("b", "hmmm")
message := map2.GenerateMessage()
snapShot := map1.SnapShotFromState(message)
if len(snapShot.Add) != 1 {
t.Fatalf(`add length should be 1`)
}
if len(snapShot.Remove) != 0 {
t.Fatalf(`remove length should be 0`)
}
}
func TestGetHashDifferentOnChange(t *testing.T) {
theMap := NewMap("a")
prevHash := theMap.GetHash()
theMap.Put("b", "hmmhmhmh")
if prevHash == theMap.GetHash() {
t.Fatalf(`hashes should not be the same`)
}
}
func TestGenerateMessageReturnsClocks(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "hmm")
theMap.Put("b", "hmm")
theMap.Remove("a")
message := theMap.GenerateMessage()
if len(message.AddContents) != 2 {
t.Fatalf(`two items added add should be 2`)
}
if len(message.RemoveContents) != 1 {
t.Fatalf(`a was removed remove map should be length 1`)
}
}
func TestDifferenceReturnsDifferenceOfMaps(t *testing.T) {
map1 := NewMap("a")
map1.Put("a", "ssms")
map1.Put("b", "sdmdsmd")
map2 := NewMap("b")
map2.Put("d", "eek")
map2.Put("c", "meh")
message1 := map1.GenerateMessage()
message2 := map2.GenerateMessage()
difference := message1.Difference(0, message2)
if len(difference.AddContents) != 2 {
t.Fatalf(`d and c are not in map1 they should be in add contents`)
}
if len(difference.RemoveContents) != 0 {
t.Fatalf(`remove should be empty`)
}
}
func TestMergeMergesValuesThatAreGreaterThanCurrentClock(t *testing.T) {
map1 := NewMap("a")
map1.Put("a", "ssms")
map1.Put("b", "sdmdsmd")
map2 := NewMap("b")
map2.Put("d", "eek")
map2.Put("c", "meh")
message1 := map1.GenerateMessage()
message2 := map2.GenerateMessage()
difference := message1.Difference(0, message2)
state := map2.SnapShotFromState(difference)
map1.Merge(*state)
if !map1.Contains("d") {
t.Fatalf(`d should be in the map`)
}
if !map2.Contains("c") {
t.Fatalf(`c should be in the map`)
}
}

180
pkg/crdt/vector_clock.go Normal file
View File

@ -0,0 +1,180 @@
package crdt
import (
"cmp"
"sync"
"time"
"github.com/tim-beatham/smegmesh/pkg/lib"
)
// VectorBucket: represents a vector clock in the bucket
// recording both the time changes were last seen
// and when the lastUpdate epoch was recorded
type VectorBucket struct {
// clock current value of the node's clock
clock uint64
// lastUpdate we've seen
lastUpdate uint64
}
// VectorClock: defines an abstract data type
// for a vector clock implementation. Including a mechanism to
// garbage collect stale entries
type VectorClock[K cmp.Ordered] struct {
vectors map[uint64]*VectorBucket
lock sync.RWMutex
processID K
staleTime uint64
hashFunc func(K) uint64
// highest update that's been garbage collected
highestStale uint64
}
// IncrementClock: increments the node's value in the vector clock
func (m *VectorClock[K]) IncrementClock() uint64 {
maxClock := uint64(0)
m.lock.Lock()
for _, value := range m.vectors {
maxClock = max(maxClock, value.clock)
}
newBucket := VectorBucket{
clock: maxClock + 1,
lastUpdate: uint64(time.Now().Unix()),
}
m.vectors[m.hashFunc(m.processID)] = &newBucket
m.lock.Unlock()
return maxClock
}
// GetHash: gets the hash of the vector clock used to determine if there
// are any changes
func (m *VectorClock[K]) GetHash() uint64 {
m.lock.RLock()
hash := uint64(0)
for key, bucket := range m.vectors {
hash += key * (bucket.clock + 1)
}
m.lock.RUnlock()
return hash
}
// Merge: merge two clocks together
func (m *VectorClock[K]) Merge(vectors map[uint64]uint64) {
for key, value := range vectors {
m.put(key, value)
}
}
// getStale: get all entries that are stale within the mesh
func (m *VectorClock[K]) getStale() []uint64 {
m.lock.RLock()
maxTimeStamp := lib.Reduce(0, lib.MapValues(m.vectors), func(i uint64, vb *VectorBucket) uint64 {
return max(i, vb.lastUpdate)
})
toRemove := make([]uint64, 0)
for key, bucket := range m.vectors {
if maxTimeStamp-bucket.lastUpdate > m.staleTime {
toRemove = append(toRemove, key)
m.highestStale = max(bucket.clock, m.highestStale)
}
}
m.lock.RUnlock()
return toRemove
}
// GetStaleCount: returns a vector clock which is considered to be stale.
// all updates must be greater than this
func (m *VectorClock[K]) GetStaleCount() uint64 {
m.lock.RLock()
staleCount := m.highestStale
m.lock.RUnlock()
return staleCount
}
// Prune: prunes all stale entries in the vector clock
func (m *VectorClock[K]) Prune() {
stale := m.getStale()
m.lock.Lock()
for _, key := range stale {
delete(m.vectors, key)
}
m.lock.Unlock()
}
// GetTimeStamp: get the last time the node was updated in UNIX
// epoch time
func (m *VectorClock[K]) GetTimestamp(processId K) uint64 {
m.lock.RLock()
lastUpdate := m.vectors[m.hashFunc(m.processID)].lastUpdate
m.lock.RUnlock()
return lastUpdate
}
// Put: places the key with vector clock in the clock of the given
// process
func (m *VectorClock[K]) Put(key K, value uint64) {
m.put(m.hashFunc(key), value)
}
func (m *VectorClock[K]) put(key uint64, value uint64) {
clockValue := uint64(0)
m.lock.Lock()
bucket, ok := m.vectors[key]
if ok {
clockValue = bucket.clock
}
// Make sure that entries that were garbage collected don't get
// highestStale represents the highest vector clock that has been
// invalidated
if value > clockValue && value > m.highestStale {
newBucket := VectorBucket{
clock: value,
lastUpdate: uint64(time.Now().Unix()),
}
m.vectors[key] = &newBucket
}
m.lock.Unlock()
}
// GetClock: serialize the vector clock into an immutable map
func (m *VectorClock[K]) GetClock() map[uint64]uint64 {
clock := make(map[uint64]uint64)
m.lock.RLock()
for key, value := range m.vectors {
clock[key] = value.clock
}
m.lock.RUnlock()
return clock
}
func NewVectorClock[K cmp.Ordered](processID K, hashFunc func(K) uint64, staleTime uint64) *VectorClock[K] {
return &VectorClock[K]{
vectors: make(map[uint64]*VectorBucket),
processID: processID,
staleTime: staleTime,
hashFunc: hashFunc,
}
}

View File

@ -1,22 +1,23 @@
package ctrlserver package ctrlserver
import ( import (
crdt "github.com/tim-beatham/wgmesh/pkg/automerge" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/conn" "github.com/tim-beatham/smegmesh/pkg/crdt"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query" "github.com/tim-beatham/smegmesh/pkg/query"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/smegmesh/pkg/rpc"
"github.com/tim-beatham/wgmesh/pkg/wg" "github.com/tim-beatham/smegmesh/pkg/sync"
"github.com/tim-beatham/smegmesh/pkg/wg"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
) )
// NewCtrlServerParams are the params requried to create a new ctrl server // NewCtrlServerParams are the params requried to create a new ctrl server
type NewCtrlServerParams struct { type NewCtrlServerParams struct {
Conf *conf.WgMeshConfiguration Conf *conf.DaemonConfiguration
Client *wgctrl.Client Client *wgctrl.Client
CtrlProvider rpc.MeshCtrlServerServer CtrlProvider rpc.MeshCtrlServerServer
SyncProvider rpc.SyncServiceServer SyncProvider rpc.SyncServiceServer
@ -27,25 +28,38 @@ type NewCtrlServerParams struct {
// operation failed // operation failed
func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) { func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
ctrlServer := new(MeshCtrlServer) ctrlServer := new(MeshCtrlServer)
meshFactory := crdt.CrdtProviderFactory{} meshFactory := &crdt.TwoPhaseMapFactory{
nodeFactory := crdt.MeshNodeFactory{ Config: params.Conf,
}
nodeFactory := &crdt.MeshNodeFactory{
Config: *params.Conf, Config: *params.Conf,
} }
idGenerator := &lib.UUIDGenerator{} idGenerator := &lib.ShortIDGenerator{}
ipAllocator := &ip.ULABuilder{} ipAllocator := &ip.ULABuilder{}
interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client) interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client)
ctrlServer.timers = make([]*lib.Timer, 0)
configApplyer := mesh.NewWgMeshConfigApplyer() configApplyer := mesh.NewWgMeshConfigApplyer()
var syncer sync.Syncer
meshManagerParams := &mesh.NewMeshManagerParams{ meshManagerParams := &mesh.NewMeshManagerParams{
Conf: *params.Conf, Conf: *params.Conf,
Client: params.Client, Client: params.Client,
MeshProvider: &meshFactory, MeshProvider: meshFactory,
NodeFactory: &nodeFactory, NodeFactory: nodeFactory,
IdGenerator: idGenerator, IdGenerator: idGenerator,
IPAllocator: ipAllocator, IPAllocator: ipAllocator,
InterfaceManipulator: interfaceManipulator, InterfaceManipulator: interfaceManipulator,
ConfigApplyer: configApplyer, ConfigApplyer: configApplyer,
OnDelete: func(mesh mesh.MeshProvider) {
_, err := syncer.Sync(mesh)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
},
} }
ctrlServer.MeshManager = mesh.NewMeshManager(meshManagerParams) ctrlServer.MeshManager = mesh.NewMeshManager(meshManagerParams)
@ -79,13 +93,41 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
return nil, err return nil, err
} }
syncer = sync.NewSyncer(&sync.NewSyncerParams{
MeshManager: ctrlServer.MeshManager,
ConnectionManager: ctrlServer.ConnectionManager,
Configuration: params.Conf,
})
// Check any syncs every 1 second
syncTimer := lib.NewTimer(func() error {
err = syncer.SyncMeshes()
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
return nil
}, 1)
heartbeatTimer := lib.NewTimer(func() error {
logging.Log.WriteInfof("checking heartbeat")
return ctrlServer.MeshManager.UpdateTimeStamp()
}, params.Conf.Heartbeat)
ctrlServer.timers = append(ctrlServer.timers, syncTimer, heartbeatTimer)
ctrlServer.Querier = query.NewJmesQuerier(ctrlServer.MeshManager) ctrlServer.Querier = query.NewJmesQuerier(ctrlServer.MeshManager)
ctrlServer.ConnectionServer = connServer ctrlServer.ConnectionServer = connServer
for _, timer := range ctrlServer.timers {
go timer.Run()
}
return ctrlServer, nil return ctrlServer, nil
} }
func (s *MeshCtrlServer) GetConfiguration() *conf.WgMeshConfiguration { func (s *MeshCtrlServer) GetConfiguration() *conf.DaemonConfiguration {
return s.Conf return s.Conf
} }
@ -119,5 +161,13 @@ func (s *MeshCtrlServer) Close() error {
logging.Log.WriteErrorf(err.Error()) logging.Log.WriteErrorf(err.Error())
} }
for _, timer := range s.timers {
err := timer.Stop()
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
}
return nil return nil
} }

View File

@ -1,32 +1,57 @@
package ctrlserver package ctrlserver
import ( import (
"github.com/tim-beatham/wgmesh/pkg/conf" "net"
"github.com/tim-beatham/wgmesh/pkg/conn" "time"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/query"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// Represents a WireGuard MeshNode // MeshRoute: represents a route in the mesh that is
// available to client applications
type MeshRoute struct {
Destination string
Path []string
}
// WireGuardStats: Represents the WireGuard configuration attached to the node
type WireGuardStats struct {
AllowedIPs []string
TransmitBytes int64
ReceivedBytes int64
PersistentKeepAliveInterval time.Duration
}
// MeshNode: represents a node in the WireGuard mesh that can be
// sent to ip chandlers
type MeshNode struct { type MeshNode struct {
HostEndpoint string HostEndpoint string
WgEndpoint string WgEndpoint string
PublicKey string PublicKey string
WgHost string WgHost string
Timestamp int64 Timestamp int64
Routes []string Routes []MeshRoute
Description string
Alias string
Services map[string]string
Stats WireGuardStats
} }
// Represents a WireGuard Mesh // Mesh: Represents a WireGuard Mesh network that can be sent
// along ipc to client frameworks
type Mesh struct { type Mesh struct {
SharedKey *wgtypes.Key
Nodes map[string]MeshNode Nodes map[string]MeshNode
} }
// CtrlServer: Encapsulates th ctrlserver
type CtrlServer interface { type CtrlServer interface {
GetConfiguration() *conf.WgMeshConfiguration GetConfiguration() *conf.DaemonConfiguration
GetClient() *wgctrl.Client GetClient() *wgctrl.Client
GetQuerier() query.Querier GetQuerier() query.Querier
GetMeshManager() mesh.MeshManager GetMeshManager() mesh.MeshManager
@ -34,12 +59,63 @@ type CtrlServer interface {
GetConnectionManager() conn.ConnectionManager GetConnectionManager() conn.ConnectionManager
} }
// Represents a ctrlserver to be used in WireGuard // MeshCtrlServer: Represents a ctrlserver to be used in WireGuard
type MeshCtrlServer struct { type MeshCtrlServer struct {
Client *wgctrl.Client Client *wgctrl.Client
MeshManager mesh.MeshManager MeshManager mesh.MeshManager
ConnectionManager conn.ConnectionManager ConnectionManager conn.ConnectionManager
ConnectionServer *conn.ConnectionServer ConnectionServer *conn.ConnectionServer
Conf *conf.WgMeshConfiguration Conf *conf.DaemonConfiguration
Querier query.Querier Querier query.Querier
timers []*lib.Timer
}
// NewCtrlNode create an instance of a ctrl node to send over an
// IPC call
func NewCtrlNode(provider mesh.MeshProvider, node mesh.MeshNode) *MeshNode {
pubKey, _ := node.GetPublicKey()
ctrlNode := MeshNode{
HostEndpoint: node.GetHostEndpoint(),
WgEndpoint: node.GetWgEndpoint(),
PublicKey: pubKey.String(),
WgHost: node.GetWgHost().String(),
Timestamp: node.GetTimeStamp(),
Routes: lib.Map(node.GetRoutes(), func(r mesh.Route) MeshRoute {
return MeshRoute{
Destination: r.GetDestination().String(),
Path: r.GetPath(),
}
}),
Description: node.GetDescription(),
Alias: node.GetAlias(),
Services: node.GetServices(),
}
device, err := provider.GetDevice()
if err != nil {
return &ctrlNode
}
peers := lib.Filter(device.Peers, func(p wgtypes.Peer) bool {
return p.PublicKey.String() == pubKey.String()
})
if len(peers) > 0 {
peer := peers[0]
stats := WireGuardStats{
AllowedIPs: lib.Map(peer.AllowedIPs, func(i net.IPNet) string {
return i.String()
}),
TransmitBytes: peer.TransmitBytes,
ReceivedBytes: peer.ReceiveBytes,
PersistentKeepAliveInterval: peer.PersistentKeepaliveInterval,
}
ctrlNode.Stats = stats
}
return &ctrlNode
} }

View File

@ -1,10 +1,10 @@
package ctrlserver package ctrlserver
import ( import (
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn" "github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query" "github.com/tim-beatham/smegmesh/pkg/query"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
) )
@ -23,10 +23,10 @@ func NewCtrlServerStub() *CtrlServerStub {
} }
} }
func (c *CtrlServerStub) GetConfiguration() *conf.WgMeshConfiguration { func (c *CtrlServerStub) GetConfiguration() *conf.DaemonConfiguration {
return &conf.WgMeshConfiguration{ return &conf.DaemonConfiguration{
GrpcPort: "8080", GrpcPort: 8080,
Endpoint: "abc.com", BaseConfiguration: conf.WgConfiguration{},
} }
} }

114
pkg/dns/dns.go Normal file
View File

@ -0,0 +1,114 @@
// smegdns: example of how to implement dns in the mesh
package smegdns
import (
"encoding/json"
"fmt"
"net"
"github.com/miekg/dns"
"github.com/tim-beatham/smegmesh/pkg/ipc"
"github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/query"
)
const MeshRegularExpression = `(?P<meshId>.+)\.(?P<alias>.+)\.smeg\.`
type DNSHandler struct {
client *ipc.SmegmeshIpc
server *dns.Server
}
// queryMesh: queries the mesh network for the given meshId and node
// with alias
func (d *DNSHandler) queryMesh(meshId, alias string) net.IP {
var reply string
err := d.client.Query(ipc.QueryMesh{
MeshId: meshId,
Query: fmt.Sprintf("[?alias == '%s'] | [0]", alias),
}, &reply)
if err != nil {
return nil
}
var node *query.QueryNode
err = json.Unmarshal([]byte(reply), &node)
if err != nil || node == nil {
return nil
}
ip, _, _ := net.ParseCIDR(node.WgHost)
return ip
}
// handleQuery: handles a DNS query
func (d *DNSHandler) handleQuery(m *dns.Msg) {
for _, q := range m.Question {
switch q.Qtype {
case dns.TypeAAAA:
logging.Log.WriteInfof("Query for %s", q.Name)
groups := lib.MatchCaptureGroup(MeshRegularExpression, q.Name)
if len(groups) == 0 {
continue
}
ip := d.queryMesh(groups["meshId"], groups["alias"])
rr, err := dns.NewRR(fmt.Sprintf("%s AAAA %s", q.Name, ip))
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
if err == nil {
m.Answer = append(m.Answer, rr)
}
}
}
}
// handleDNS query: handle a DNS request
func (h *DNSHandler) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
msg := new(dns.Msg)
msg.SetReply(r)
msg.Authoritative = true
switch r.Opcode {
case dns.OpcodeQuery:
h.handleQuery(msg)
}
w.WriteMsg(msg)
}
func (h *DNSHandler) Listen() error {
return h.server.ListenAndServe()
}
func (h *DNSHandler) Close() error {
return h.server.Shutdown()
}
func NewDns(udpPort int) (*DNSHandler, error) {
client, err := ipc.NewClientIpc()
if err != nil {
return nil, err
}
dnsHander := DNSHandler{
client: client,
}
dns.HandleFunc("smeg.", dnsHander.handleDnsRequest)
dnsHander.server = &dns.Server{Addr: fmt.Sprintf(":%d", udpPort), Net: "udp"}
return &dnsHander, nil
}

249
pkg/dot/dot.go Normal file
View File

@ -0,0 +1,249 @@
// Graph allows the definition of a DOT graph in golang
package graph
import (
"fmt"
"hash/fnv"
"strings"
"github.com/tim-beatham/smegmesh/pkg/lib"
)
type GraphType string
type Shape string
const (
GRAPH GraphType = "graph"
DIGRAPH GraphType = "digraph"
)
const (
CIRCLE Shape = "circle"
STAR Shape = "star"
HEXAGON Shape = "hexagon"
PARALLELOGRAM Shape = "parallelogram"
)
type Graph interface {
Dottable
GetType() GraphType
}
// Cluster: represents a subgraph in the graphs
type Cluster struct {
Type GraphType
Name string
Label string
nodes map[string]*Node
edges map[string]Edge
}
// RootGraph: Represents the top level graph
type RootGraph struct {
Type GraphType
Label string
nodes map[string]*Node
clusters map[string]*Cluster
edges map[string]Edge
}
// Node: represents a graphviz not
type Node struct {
Name string
Label string
Shape Shape
Size int
}
// Edge: represents an edge between adjacent nodes
type Edge interface {
Dottable
}
// DirectEdge: contains a directed edge between any two nodes
type DirectedEdge struct {
Name string
Label string
From string
To string
}
// UndirectedEdge: contains an undirected edge between any two
// nodes
type UndirectedEdge struct {
Name string
Label string
From string
To string
}
// Dottable means an implementer can convert the struct to DOT representation
type Dottable interface {
GetDOT() (string, error)
}
// PutNode: puts a node in the root graph
func (g *RootGraph) PutNode(name, label string, size int, shape Shape) error {
_, exists := g.nodes[name]
if exists {
// If exists no need to add the ndoe
return nil
}
g.nodes[name] = &Node{Name: name, Label: label, Size: size, Shape: shape}
return nil
}
// PutCluster: puts a cluster in the root graph
func (g *RootGraph) PutCluster(graph *Cluster) {
g.clusters[graph.Label] = graph
}
func writeContituents[D Dottable](result *strings.Builder, elements ...D) error {
for _, node := range elements {
dot, err := node.GetDOT()
if err != nil {
return err
}
_, err = result.WriteString(dot)
if err != nil {
return err
}
}
return nil
}
// GetDOT: convert the root graph into dot format
func (g *RootGraph) GetDOT() (string, error) {
var result strings.Builder
result.WriteString(fmt.Sprintf("%s {\n", g.Type))
result.WriteString("node [colorscheme=set312];\n")
result.WriteString("layout = fdp;\n")
nodes := lib.MapValues(g.nodes)
edges := lib.MapValues(g.edges)
writeContituents(&result, nodes...)
writeContituents(&result, edges...)
for _, cluster := range g.clusters {
clusterDOT, err := cluster.GetDOT()
if err != nil {
return "", err
}
result.WriteString(clusterDOT)
}
result.WriteString("}")
return result.String(), nil
}
// GetType: get the graph type. DIRECTED|UNDIRECTED
func (r *RootGraph) GetType() GraphType {
return r.Type
}
func constructEdge(graph Graph, name, label, from, to string) Edge {
switch graph.GetType() {
case DIGRAPH:
return &DirectedEdge{Name: name, Label: label, From: from, To: to}
default:
return &UndirectedEdge{Name: name, Label: label, From: from, To: to}
}
}
// AddEdge: adds an edge between two nodes in the root graph
func (g *RootGraph) AddEdge(name string, label string, from string, to string) error {
g.edges[name] = constructEdge(g, name, label, from, to)
return nil
}
const numColours = 12
func (n *Node) hash() int {
h := fnv.New32a()
h.Write([]byte(n.Name))
return (int(h.Sum32()) % numColours) + 1
}
// GetDOT: convert the node into DOT format
func (n *Node) GetDOT() (string, error) {
return fmt.Sprintf("node[label=\"%s\",shape=%s, style=\"filled\", fillcolor=%d, width=%d, height=%d, fixedsize=true] \"%s\";\n",
n.Label, n.Shape, n.hash(), n.Size, n.Size, n.Name), nil
}
// GetDOT: Convert a directed edge into dot format
func (e *DirectedEdge) GetDOT() (string, error) {
return fmt.Sprintf("\"%s\" -> \"%s\" [label=\"%s\"];\n", e.From, e.To, e.Label), nil
}
// GetDOT: convert an undirected edge into dot format
func (e *UndirectedEdge) GetDOT() (string, error) {
return fmt.Sprintf("\"%s\" -- \"%s\" [label=\"%s\"];\n", e.From, e.To, e.Label), nil
}
// AddEdge: adds an edge between two nodes in the graph
func (g *Cluster) AddEdge(name string, label string, from string, to string) error {
g.edges[name] = constructEdge(g, name, label, from, to)
return nil
}
// PutNode: puts a node in the graph
func (g *Cluster) PutNode(name, label string, size int, shape Shape) error {
_, exists := g.nodes[name]
if exists {
// If exists no need to add the ndoe
return nil
}
g.nodes[name] = &Node{Name: name, Label: label, Shape: shape, Size: size}
return nil
}
// GetDOT: convert the cluster into dot format
func (g *Cluster) GetDOT() (string, error) {
var builder strings.Builder
builder.WriteString(fmt.Sprintf("subgraph \"cluster%s\" {\n", g.Label))
builder.WriteString(fmt.Sprintf("label = \"%s\"\n", g.Label))
nodes := lib.MapValues(g.nodes)
edges := lib.MapValues(g.edges)
writeContituents(&builder, nodes...)
writeContituents(&builder, edges...)
builder.WriteString("}\n")
return builder.String(), nil
}
// GetType: get the type of the subgraph (directed|undirected)
func (g *Cluster) GetType() GraphType {
return g.Type
}
// NewSubGraph: instantiate a new subgraph
func NewSubGraph(name string, label string, graphType GraphType) *Cluster {
return &Cluster{
Label: name,
Type: graphType,
Name: name,
nodes: make(map[string]*Node),
edges: make(map[string]Edge),
}
}
// NewGraph: create a new root graph
func NewGraph(label string, graphType GraphType) *RootGraph {
return &RootGraph{
Type: graphType,
Label: label,
clusters: map[string]*Cluster{},
nodes: make(map[string]*Node),
edges: make(map[string]Edge),
}
}

116
pkg/dot/wg.go Normal file
View File

@ -0,0 +1,116 @@
package graph
import (
"fmt"
"slices"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
)
// MeshGraphConverter converts a mesh to a graph
type MeshGraphConverter interface {
// convert the mesh to textual form
Generate() (string, error)
}
type MeshDOTConverter struct {
meshes map[string][]ctrlserver.MeshNode
destinations map[string]interface{}
}
func (c *MeshDOTConverter) Generate() (string, error) {
g := NewGraph("Smegmesh", GRAPH)
for meshId := range c.meshes {
err := c.generateMesh(g, meshId)
if err != nil {
return "", err
}
}
for mesh := range c.meshes {
g.PutNode(mesh, mesh, 1, CIRCLE)
}
for destination := range c.destinations {
g.PutNode(destination, destination, 1, HEXAGON)
}
return g.GetDOT()
}
func (c *MeshDOTConverter) generateMesh(g *RootGraph, meshId string) error {
nodes := c.meshes[meshId]
g.PutNode(meshId, meshId, 1, CIRCLE)
for _, node := range nodes {
c.graphNode(g, node, meshId)
}
for _, node := range nodes {
g.AddEdge(fmt.Sprintf("%s to %s", node.PublicKey, meshId), "", node.PublicKey, meshId)
}
return nil
}
// graphNode: graphs a node within the mesh
func (c *MeshDOTConverter) graphNode(g *RootGraph, node ctrlserver.MeshNode, meshId string) {
alias := node.Alias
if alias == "" {
alias = node.WgHost[1:len(node.WgHost)-20] + "\\n" + node.WgHost[len(node.WgHost)-20:len(node.WgHost)]
}
g.PutNode(node.PublicKey, alias, 2, CIRCLE)
for _, route := range node.Routes {
if len(route.Path) == 0 {
g.AddEdge(route.Destination, "", node.PublicKey, route.Destination)
continue
}
reversedPath := slices.Clone(route.Path)
slices.Reverse(reversedPath)
g.AddEdge(fmt.Sprintf("%s to %s", node.PublicKey, reversedPath[0]), "", node.PublicKey, reversedPath[0])
for _, mesh := range route.Path {
if _, ok := c.meshes[mesh]; !ok {
c.destinations[mesh] = struct{}{}
}
}
for index := range reversedPath[0 : len(reversedPath)-1] {
routeID := fmt.Sprintf("%s to %s", reversedPath[index], reversedPath[index+1])
g.AddEdge(routeID, "", reversedPath[index], reversedPath[index+1])
}
if route.Destination == "::/0" {
c.destinations[route.Destination] = struct{}{}
lastMesh := reversedPath[len(reversedPath)-1]
routeID := fmt.Sprintf("%s to %s", lastMesh, route.Destination)
g.AddEdge(routeID, "", lastMesh, route.Destination)
}
}
for service := range node.Services {
c.putService(g, service, meshId, node)
}
}
// putService: construct a service node and a link between the nodes
func (c *MeshDOTConverter) putService(g *RootGraph, key, meshId string, node ctrlserver.MeshNode) {
serviceID := fmt.Sprintf("%s%s%s", key, node.PublicKey, meshId)
g.PutNode(serviceID, key, 1, PARALLELOGRAM)
g.AddEdge(fmt.Sprintf("%s to %s", node.PublicKey, serviceID), "", node.PublicKey, serviceID)
}
func NewMeshGraphConverter(meshes map[string][]ctrlserver.MeshNode) MeshGraphConverter {
return &MeshDOTConverter{
meshes: meshes,
destinations: make(map[string]interface{}),
}
}

View File

@ -1,178 +0,0 @@
// Graph allows the definition of a DOT graph in golang
package graph
import (
"errors"
"fmt"
"hash/fnv"
"strings"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
type GraphType string
type Shape string
const (
GRAPH GraphType = "graph"
DIGRAPH = "digraph"
)
const (
CIRCLE Shape = "circle"
STAR Shape = "star"
HEXAGON Shape = "hexagon"
)
type Graph struct {
Type GraphType
Label string
nodes map[string]*Node
edges []Edge
}
type Node struct {
Name string
Shape Shape
}
type Edge interface {
Dottable
}
type DirectedEdge struct {
Label string
From *Node
To *Node
}
type UndirectedEdge struct {
Label string
From *Node
To *Node
}
// Dottable means an implementer can convert the struct to DOT representation
type Dottable interface {
GetDOT() (string, error)
}
func NewGraph(label string, graphType GraphType) *Graph {
return &Graph{Type: graphType, Label: label, nodes: make(map[string]*Node), edges: make([]Edge, 0)}
}
// PutNode: puts a node in the graph
func (g *Graph) PutNode(label string, shape Shape) error {
_, exists := g.nodes[label]
if exists {
// If exists no need to add the ndoe
return nil
}
g.nodes[label] = &Node{Name: label, Shape: shape}
return nil
}
func writeContituents[D Dottable](result *strings.Builder, elements ...D) error {
for _, node := range elements {
dot, err := node.GetDOT()
if err != nil {
return err
}
_, err = result.WriteString(dot)
if err != nil {
return err
}
}
return nil
}
func (g *Graph) GetDOT() (string, error) {
var result strings.Builder
_, err := result.WriteString(fmt.Sprintf("%s {\n", g.Type))
if err != nil {
return "", err
}
_, err = result.WriteString("node [colorscheme=set312];\n")
if err != nil {
return "", err
}
nodes := lib.MapValues(g.nodes)
err = writeContituents(&result, nodes...)
if err != nil {
return "", err
}
err = writeContituents(&result, g.edges...)
if err != nil {
return "", err
}
_, err = result.WriteString("}")
if err != nil {
return "", err
}
return result.String(), nil
}
func (g *Graph) constructEdge(label string, from *Node, to *Node) Edge {
switch g.Type {
case DIGRAPH:
return &DirectedEdge{Label: label, From: from, To: to}
default:
return &UndirectedEdge{Label: label, From: from, To: to}
}
}
// AddEdge: adds an edge between two nodes in the graph
func (g *Graph) AddEdge(label string, from string, to string) error {
fromNode, exists := g.nodes[from]
if !exists {
return errors.New(fmt.Sprintf("Node %s does not exist", from))
}
toNode, exists := g.nodes[to]
if !exists {
return errors.New(fmt.Sprintf("Node %s does not exist", to))
}
g.edges = append(g.edges, g.constructEdge(label, fromNode, toNode))
return nil
}
const numColours = 12
func (n *Node) hash() int {
h := fnv.New32a()
h.Write([]byte(n.Name))
return (int(h.Sum32()) % numColours) + 1
}
func (n *Node) GetDOT() (string, error) {
return fmt.Sprintf("node[shape=%s, style=\"filled\", fillcolor=%d] %s;\n",
n.Shape, n.hash(), n.Name), nil
}
func (e *DirectedEdge) GetDOT() (string, error) {
return fmt.Sprintf("%s -> %s;\n", e.From.Name, e.To.Name), nil
}
func (e *UndirectedEdge) GetDOT() (string, error) {
return fmt.Sprintf("%s -- %s;\n", e.From.Name, e.To.Name), nil
}

212
pkg/grpc/ctrlserver.pb.go Normal file
View File

@ -0,0 +1,212 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.1
// protoc v3.21.12
// source: pkg/grpc/ctrlserver.proto
package rpc
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type GetMeshRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
MeshId string `protobuf:"bytes,1,opt,name=meshId,proto3" json:"meshId,omitempty"`
}
func (x *GetMeshRequest) Reset() {
*x = GetMeshRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *GetMeshRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GetMeshRequest) ProtoMessage() {}
func (x *GetMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GetMeshRequest.ProtoReflect.Descriptor instead.
func (*GetMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_proto_rawDescGZIP(), []int{0}
}
func (x *GetMeshRequest) GetMeshId() string {
if x != nil {
return x.MeshId
}
return ""
}
type GetMeshReply struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Mesh []byte `protobuf:"bytes,1,opt,name=mesh,proto3" json:"mesh,omitempty"`
}
func (x *GetMeshReply) Reset() {
*x = GetMeshReply{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *GetMeshReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GetMeshReply) ProtoMessage() {}
func (x *GetMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GetMeshReply.ProtoReflect.Descriptor instead.
func (*GetMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_proto_rawDescGZIP(), []int{1}
}
func (x *GetMeshReply) GetMesh() []byte {
if x != nil {
return x.Mesh
}
return nil
}
var File_pkg_grpc_ctrlserver_proto protoreflect.FileDescriptor
var file_pkg_grpc_ctrlserver_proto_rawDesc = []byte{
0x0a, 0x19, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73,
0x65, 0x72, 0x76, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x22, 0x28, 0x0a, 0x0e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68,
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49,
0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x22,
0x22, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12,
0x12, 0x0a, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x6d,
0x65, 0x73, 0x68, 0x32, 0x4f, 0x0a, 0x0e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x74, 0x72, 0x6c, 0x53,
0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x3d, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68,
0x12, 0x18, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d,
0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70,
0x6c, 0x79, 0x22, 0x00, 0x42, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62,
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_pkg_grpc_ctrlserver_proto_rawDescOnce sync.Once
file_pkg_grpc_ctrlserver_proto_rawDescData = file_pkg_grpc_ctrlserver_proto_rawDesc
)
func file_pkg_grpc_ctrlserver_proto_rawDescGZIP() []byte {
file_pkg_grpc_ctrlserver_proto_rawDescOnce.Do(func() {
file_pkg_grpc_ctrlserver_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_grpc_ctrlserver_proto_rawDescData)
})
return file_pkg_grpc_ctrlserver_proto_rawDescData
}
var file_pkg_grpc_ctrlserver_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_pkg_grpc_ctrlserver_proto_goTypes = []interface{}{
(*GetMeshRequest)(nil), // 0: rpctypes.GetMeshRequest
(*GetMeshReply)(nil), // 1: rpctypes.GetMeshReply
}
var file_pkg_grpc_ctrlserver_proto_depIdxs = []int32{
0, // 0: rpctypes.MeshCtrlServer.GetMesh:input_type -> rpctypes.GetMeshRequest
1, // 1: rpctypes.MeshCtrlServer.GetMesh:output_type -> rpctypes.GetMeshReply
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_pkg_grpc_ctrlserver_proto_init() }
func file_pkg_grpc_ctrlserver_proto_init() {
if File_pkg_grpc_ctrlserver_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_pkg_grpc_ctrlserver_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetMeshRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_ctrlserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetMeshReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_pkg_grpc_ctrlserver_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_pkg_grpc_ctrlserver_proto_goTypes,
DependencyIndexes: file_pkg_grpc_ctrlserver_proto_depIdxs,
MessageInfos: file_pkg_grpc_ctrlserver_proto_msgTypes,
}.Build()
File_pkg_grpc_ctrlserver_proto = out.File
file_pkg_grpc_ctrlserver_proto_rawDesc = nil
file_pkg_grpc_ctrlserver_proto_goTypes = nil
file_pkg_grpc_ctrlserver_proto_depIdxs = nil
}

16
pkg/grpc/ctrlserver.proto Normal file
View File

@ -0,0 +1,16 @@
syntax = "proto3";
package rpctypes;
option go_package = "pkg/rpc";
service MeshCtrlServer {
rpc GetMesh(GetMeshRequest) returns (GetMeshReply) {}
}
message GetMeshRequest {
string meshId = 1;
}
message GetMeshReply {
bytes mesh = 1;
}

View File

@ -1,18 +0,0 @@
syntax = "proto3";
package rpctypes;
option go_package = "pkg/rpc";
service Authentication {
rpc JoinMesh(JoinAuthMeshRequest) returns (JoinAuthMeshReply) {}
}
message JoinAuthMeshRequest {
string meshId = 1;
string alias = 2;
}
message JoinAuthMeshReply {
bool success = 1;
optional string token = 2;
}

View File

@ -1,16 +0,0 @@
syntax = "proto3";
package rpctypes;
option go_package = "pkg/rpc";
service MeshCtrlServer {
rpc JoinMesh(JoinMeshRequest) returns (JoinMeshReply) {}
}
message JoinMeshRequest {
string meshId = 2;
}
message JoinMeshReply {
bool success = 1;
}

View File

@ -0,0 +1,105 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.2.0
// - protoc v3.21.12
// source: pkg/grpc/ctrlserver.proto
package rpc
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// MeshCtrlServerClient is the client API for MeshCtrlServer service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type MeshCtrlServerClient interface {
GetMesh(ctx context.Context, in *GetMeshRequest, opts ...grpc.CallOption) (*GetMeshReply, error)
}
type meshCtrlServerClient struct {
cc grpc.ClientConnInterface
}
func NewMeshCtrlServerClient(cc grpc.ClientConnInterface) MeshCtrlServerClient {
return &meshCtrlServerClient{cc}
}
func (c *meshCtrlServerClient) GetMesh(ctx context.Context, in *GetMeshRequest, opts ...grpc.CallOption) (*GetMeshReply, error) {
out := new(GetMeshReply)
err := c.cc.Invoke(ctx, "/rpctypes.MeshCtrlServer/GetMesh", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// MeshCtrlServerServer is the server API for MeshCtrlServer service.
// All implementations must embed UnimplementedMeshCtrlServerServer
// for forward compatibility
type MeshCtrlServerServer interface {
GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error)
mustEmbedUnimplementedMeshCtrlServerServer()
}
// UnimplementedMeshCtrlServerServer must be embedded to have forward compatible implementations.
type UnimplementedMeshCtrlServerServer struct {
}
func (UnimplementedMeshCtrlServerServer) GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetMesh not implemented")
}
func (UnimplementedMeshCtrlServerServer) mustEmbedUnimplementedMeshCtrlServerServer() {}
// UnsafeMeshCtrlServerServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to MeshCtrlServerServer will
// result in compilation errors.
type UnsafeMeshCtrlServerServer interface {
mustEmbedUnimplementedMeshCtrlServerServer()
}
func RegisterMeshCtrlServerServer(s grpc.ServiceRegistrar, srv MeshCtrlServerServer) {
s.RegisterService(&MeshCtrlServer_ServiceDesc, srv)
}
func _MeshCtrlServer_GetMesh_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(GetMeshRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(MeshCtrlServerServer).GetMesh(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/rpctypes.MeshCtrlServer/GetMesh",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(MeshCtrlServerServer).GetMesh(ctx, req.(*GetMeshRequest))
}
return interceptor(ctx, in, info, handler)
}
// MeshCtrlServer_ServiceDesc is the grpc.ServiceDesc for MeshCtrlServer service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var MeshCtrlServer_ServiceDesc = grpc.ServiceDesc{
ServiceName: "rpctypes.MeshCtrlServer",
HandlerType: (*MeshCtrlServerServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "GetMesh",
Handler: _MeshCtrlServer_GetMesh_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "pkg/grpc/ctrlserver.proto",
}

233
pkg/grpc/syncservice.pb.go Normal file
View File

@ -0,0 +1,233 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.1
// protoc v3.21.12
// source: pkg/grpc/syncservice.proto
package rpc
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type SyncMeshRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
MeshId string `protobuf:"bytes,1,opt,name=meshId,proto3" json:"meshId,omitempty"`
Changes []byte `protobuf:"bytes,2,opt,name=changes,proto3" json:"changes,omitempty"`
}
func (x *SyncMeshRequest) Reset() {
*x = SyncMeshRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_syncservice_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SyncMeshRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SyncMeshRequest) ProtoMessage() {}
func (x *SyncMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_syncservice_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SyncMeshRequest.ProtoReflect.Descriptor instead.
func (*SyncMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_syncservice_proto_rawDescGZIP(), []int{0}
}
func (x *SyncMeshRequest) GetMeshId() string {
if x != nil {
return x.MeshId
}
return ""
}
func (x *SyncMeshRequest) GetChanges() []byte {
if x != nil {
return x.Changes
}
return nil
}
type SyncMeshReply struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"`
Changes []byte `protobuf:"bytes,2,opt,name=changes,proto3" json:"changes,omitempty"`
}
func (x *SyncMeshReply) Reset() {
*x = SyncMeshReply{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_syncservice_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SyncMeshReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SyncMeshReply) ProtoMessage() {}
func (x *SyncMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_syncservice_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SyncMeshReply.ProtoReflect.Descriptor instead.
func (*SyncMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_syncservice_proto_rawDescGZIP(), []int{1}
}
func (x *SyncMeshReply) GetSuccess() bool {
if x != nil {
return x.Success
}
return false
}
func (x *SyncMeshReply) GetChanges() []byte {
if x != nil {
return x.Changes
}
return nil
}
var File_pkg_grpc_syncservice_proto protoreflect.FileDescriptor
var file_pkg_grpc_syncservice_proto_rawDesc = []byte{
0x0a, 0x1a, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x73, 0x79, 0x6e, 0x63, 0x73,
0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0b, 0x73, 0x79,
0x6e, 0x63, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x22, 0x43, 0x0a, 0x0f, 0x53, 0x79, 0x6e,
0x63, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06,
0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65,
0x73, 0x68, 0x49, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18,
0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x22, 0x43,
0x0a, 0x0d, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12,
0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08,
0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x68, 0x61,
0x6e, 0x67, 0x65, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x63, 0x68, 0x61, 0x6e,
0x67, 0x65, 0x73, 0x32, 0x59, 0x0a, 0x0b, 0x53, 0x79, 0x6e, 0x63, 0x53, 0x65, 0x72, 0x76, 0x69,
0x63, 0x65, 0x12, 0x4a, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x1c,
0x2e, 0x73, 0x79, 0x6e, 0x63, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x53, 0x79, 0x6e,
0x63, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x73,
0x79, 0x6e, 0x63, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x4d,
0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x09,
0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x33,
}
var (
file_pkg_grpc_syncservice_proto_rawDescOnce sync.Once
file_pkg_grpc_syncservice_proto_rawDescData = file_pkg_grpc_syncservice_proto_rawDesc
)
func file_pkg_grpc_syncservice_proto_rawDescGZIP() []byte {
file_pkg_grpc_syncservice_proto_rawDescOnce.Do(func() {
file_pkg_grpc_syncservice_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_grpc_syncservice_proto_rawDescData)
})
return file_pkg_grpc_syncservice_proto_rawDescData
}
var file_pkg_grpc_syncservice_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_pkg_grpc_syncservice_proto_goTypes = []interface{}{
(*SyncMeshRequest)(nil), // 0: syncservice.SyncMeshRequest
(*SyncMeshReply)(nil), // 1: syncservice.SyncMeshReply
}
var file_pkg_grpc_syncservice_proto_depIdxs = []int32{
0, // 0: syncservice.SyncService.SyncMesh:input_type -> syncservice.SyncMeshRequest
1, // 1: syncservice.SyncService.SyncMesh:output_type -> syncservice.SyncMeshReply
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_pkg_grpc_syncservice_proto_init() }
func file_pkg_grpc_syncservice_proto_init() {
if File_pkg_grpc_syncservice_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_pkg_grpc_syncservice_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SyncMeshRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_syncservice_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SyncMeshReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_pkg_grpc_syncservice_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_pkg_grpc_syncservice_proto_goTypes,
DependencyIndexes: file_pkg_grpc_syncservice_proto_depIdxs,
MessageInfos: file_pkg_grpc_syncservice_proto_msgTypes,
}.Build()
File_pkg_grpc_syncservice_proto = out.File
file_pkg_grpc_syncservice_proto_rawDesc = nil
file_pkg_grpc_syncservice_proto_goTypes = nil
file_pkg_grpc_syncservice_proto_depIdxs = nil
}

View File

@ -4,18 +4,9 @@ package syncservice;
option go_package = "pkg/rpc"; option go_package = "pkg/rpc";
service SyncService { service SyncService {
rpc GetConf(GetConfRequest) returns (GetConfReply) {}
rpc SyncMesh(stream SyncMeshRequest) returns (stream SyncMeshReply) {} rpc SyncMesh(stream SyncMeshRequest) returns (stream SyncMeshReply) {}
} }
message GetConfRequest {
string meshId = 1;
}
message GetConfReply {
bytes mesh = 1;
}
message SyncMeshRequest { message SyncMeshRequest {
string meshId = 1; string meshId = 1;
bytes changes = 2; bytes changes = 2;

View File

@ -0,0 +1,137 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.2.0
// - protoc v3.21.12
// source: pkg/grpc/syncservice.proto
package rpc
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// SyncServiceClient is the client API for SyncService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type SyncServiceClient interface {
SyncMesh(ctx context.Context, opts ...grpc.CallOption) (SyncService_SyncMeshClient, error)
}
type syncServiceClient struct {
cc grpc.ClientConnInterface
}
func NewSyncServiceClient(cc grpc.ClientConnInterface) SyncServiceClient {
return &syncServiceClient{cc}
}
func (c *syncServiceClient) SyncMesh(ctx context.Context, opts ...grpc.CallOption) (SyncService_SyncMeshClient, error) {
stream, err := c.cc.NewStream(ctx, &SyncService_ServiceDesc.Streams[0], "/syncservice.SyncService/SyncMesh", opts...)
if err != nil {
return nil, err
}
x := &syncServiceSyncMeshClient{stream}
return x, nil
}
type SyncService_SyncMeshClient interface {
Send(*SyncMeshRequest) error
Recv() (*SyncMeshReply, error)
grpc.ClientStream
}
type syncServiceSyncMeshClient struct {
grpc.ClientStream
}
func (x *syncServiceSyncMeshClient) Send(m *SyncMeshRequest) error {
return x.ClientStream.SendMsg(m)
}
func (x *syncServiceSyncMeshClient) Recv() (*SyncMeshReply, error) {
m := new(SyncMeshReply)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// SyncServiceServer is the server API for SyncService service.
// All implementations must embed UnimplementedSyncServiceServer
// for forward compatibility
type SyncServiceServer interface {
SyncMesh(SyncService_SyncMeshServer) error
mustEmbedUnimplementedSyncServiceServer()
}
// UnimplementedSyncServiceServer must be embedded to have forward compatible implementations.
type UnimplementedSyncServiceServer struct {
}
func (UnimplementedSyncServiceServer) SyncMesh(SyncService_SyncMeshServer) error {
return status.Errorf(codes.Unimplemented, "method SyncMesh not implemented")
}
func (UnimplementedSyncServiceServer) mustEmbedUnimplementedSyncServiceServer() {}
// UnsafeSyncServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to SyncServiceServer will
// result in compilation errors.
type UnsafeSyncServiceServer interface {
mustEmbedUnimplementedSyncServiceServer()
}
func RegisterSyncServiceServer(s grpc.ServiceRegistrar, srv SyncServiceServer) {
s.RegisterService(&SyncService_ServiceDesc, srv)
}
func _SyncService_SyncMesh_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(SyncServiceServer).SyncMesh(&syncServiceSyncMeshServer{stream})
}
type SyncService_SyncMeshServer interface {
Send(*SyncMeshReply) error
Recv() (*SyncMeshRequest, error)
grpc.ServerStream
}
type syncServiceSyncMeshServer struct {
grpc.ServerStream
}
func (x *syncServiceSyncMeshServer) Send(m *SyncMeshReply) error {
return x.ServerStream.SendMsg(m)
}
func (x *syncServiceSyncMeshServer) Recv() (*SyncMeshRequest, error) {
m := new(SyncMeshRequest)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// SyncService_ServiceDesc is the grpc.ServiceDesc for SyncService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var SyncService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "syncservice.SyncService",
HandlerType: (*SyncServiceServer)(nil),
Methods: []grpc.MethodDesc{},
Streams: []grpc.StreamDesc{
{
StreamName: "SyncMesh",
Handler: _SyncService_SyncMesh_Handler,
ServerStreams: true,
ClientStreams: true,
},
},
Metadata: "pkg/grpc/syncservice.proto",
}

View File

@ -1,8 +1,7 @@
package ip package ip
/* // Generates a CGA see RFC 3972
* Use a WireGuard public key to generate a unique interface ID // https://datatracker.ietf.org/doc/html/rfc3972
*/
import ( import (
"crypto/rand" "crypto/rand"
@ -22,19 +21,23 @@ const (
InterfaceIdLen = 8 InterfaceIdLen = 8
) )
/* // CGAParameters: parameters used to create a new cryotpgraphically generated
* Cga parameters used to generate an IPV6 interface ID // address
*/
type CgaParameters struct { type CgaParameters struct {
Modifier [ModifierLength]byte Modifier [ModifierLength]byte
// SubnetPrefix: prefix of the subnetwork
SubnetPrefix [2 * InterfaceIdLen]byte SubnetPrefix [2 * InterfaceIdLen]byte
// CollisionCount: total number of times we have atempted to generate a porefix
CollisionCount uint8 CollisionCount uint8
// PublicKey: WireGuard public key of our interface
PublicKey wgtypes.Key PublicKey wgtypes.Key
// interfaceId: the generated interfaceId
interfaceId [2 * InterfaceIdLen]byte interfaceId [2 * InterfaceIdLen]byte
// flag: represents whether or not an IP address has been generated
flag byte flag byte
} }
func NewCga(key wgtypes.Key, subnetPrefix [2 * InterfaceIdLen]byte) (*CgaParameters, error) { func NewCga(key wgtypes.Key, collisionCount uint8, subnetPrefix [2 * InterfaceIdLen]byte) (*CgaParameters, error) {
var params CgaParameters var params CgaParameters
_, err := rand.Read(params.Modifier[:]) _, err := rand.Read(params.Modifier[:])
@ -45,25 +48,10 @@ func NewCga(key wgtypes.Key, subnetPrefix [2 * InterfaceIdLen]byte) (*CgaParamet
params.PublicKey = key params.PublicKey = key
params.SubnetPrefix = subnetPrefix params.SubnetPrefix = subnetPrefix
params.CollisionCount = collisionCount
return &params, nil return &params, nil
} }
func (c *CgaParameters) generateHash2() []byte {
var byteVal [hash2Length]byte
for i := 0; i < ModifierLength; i++ {
byteVal[i] = c.Modifier[i]
}
for i := 0; i < wgtypes.KeyLen; i++ {
byteVal[ModifierLength+ZeroLength+i] = c.PublicKey[i]
}
hash := sha1.Sum(byteVal[:])
return hash[:Hash2Prefix]
}
func (c *CgaParameters) generateHash1() []byte { func (c *CgaParameters) generateHash1() []byte {
var byteVal [hash1Length]byte var byteVal [hash1Length]byte
@ -78,7 +66,6 @@ func (c *CgaParameters) generateHash1() []byte {
byteVal[hash1Length-1] = c.CollisionCount byteVal[hash1Length-1] = c.CollisionCount
hash := sha1.Sum(byteVal[:]) hash := sha1.Sum(byteVal[:])
return hash[:Hash1Prefix] return hash[:Hash1Prefix]
} }
@ -90,9 +77,6 @@ func clearBit(num, pos int) byte {
} }
func (c *CgaParameters) generateInterface() []byte { func (c *CgaParameters) generateInterface() []byte {
// TODO: On duplicate address detection increment collision.
// Also incorporate SEC
hash1 := c.generateHash1() hash1 := c.generateHash1()
var interfaceId []byte = make([]byte, InterfaceIdLen) var interfaceId []byte = make([]byte, InterfaceIdLen)
@ -101,7 +85,6 @@ func (c *CgaParameters) generateInterface() []byte {
interfaceId[0] = clearBit(int(interfaceId[0]), 6) interfaceId[0] = clearBit(int(interfaceId[0]), 6)
interfaceId[0] = clearBit(int(interfaceId[1]), 7) interfaceId[0] = clearBit(int(interfaceId[1]), 7)
return interfaceId return interfaceId
} }

View File

@ -6,6 +6,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// IPAllocator: abstracts the process of creating an IP address
type IPAllocator interface { type IPAllocator interface {
GetIP(key wgtypes.Key, meshId string) (net.IP, error) GetIP(key wgtypes.Key, meshId string, collisionCount uint8) (net.IP, error)
} }

View File

@ -8,6 +8,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// ULABuilder: Create a new ULA in WireGuard
type ULABuilder struct{} type ULABuilder struct{}
func getMeshPrefix(meshId string) [16]byte { func getMeshPrefix(meshId string) [16]byte {
@ -39,10 +40,10 @@ func (u *ULABuilder) GetIPNet(meshId string) (*net.IPNet, error) {
return net, nil return net, nil
} }
func (u *ULABuilder) GetIP(key wgtypes.Key, meshId string) (net.IP, error) { func (u *ULABuilder) GetIP(key wgtypes.Key, meshId string, collisionCount uint8) (net.IP, error) {
ulaPrefix := getMeshPrefix(meshId) ulaPrefix := getMeshPrefix(meshId)
c, err := NewCga(key, ulaPrefix) c, err := NewCga(key, collisionCount, ulaPrefix)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -5,65 +5,195 @@ import (
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
ipcRPC "net/rpc"
"os" "os"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
) )
type NewMeshArgs struct { const SockAddr = "/tmp/smeg.sock"
// IfName is the interface that the mesh instance will run on
IfName string type MeshIpc interface {
CreateMesh(args *NewMeshArgs, reply *string) error
ListMeshes(name string, reply *ListMeshReply) error
JoinMesh(args *JoinMeshArgs, reply *string) error
LeaveMesh(meshId string, reply *string) error
GetMesh(meshId string, reply *GetMeshReply) error
Query(query QueryMesh, reply *string) error
PutDescription(args PutDescriptionArgs, reply *string) error
PutAlias(args PutAliasArgs, reply *string) error
PutService(args PutServiceArgs, reply *string) error
DeleteService(args DeleteServiceArgs, reply *string) error
}
// WireGuardArgs are provided args specific to WireGuard
type WireGuardArgs struct {
// WgPort is the WireGuard port to expose // WgPort is the WireGuard port to expose
WgPort int WgPort int
// KeepAliveWg is the number of seconds to keep alive
// for WireGuard NAT/firewall traversal
KeepAliveWg int
// AdvertiseRoutes whether or not to advertise routes to and from the
// mesh network
AdvertiseRoutes bool
// AdvertiseDefaultRoute whether or not to advertise the default route
// into the mesh network
AdvertiseDefaultRoute bool
// Endpoint is the routable alias of the machine. Can be an IP // Endpoint is the routable alias of the machine. Can be an IP
// or DNS entry // or DNS entry
Endpoint string Endpoint string
// Role is the role of the individual in the mesh
Role string
}
type NewMeshArgs struct {
// WgArgs are specific WireGuard args to use
WgArgs WireGuardArgs
} }
type JoinMeshArgs struct { type JoinMeshArgs struct {
// MeshId is the ID of the mesh to join // MeshId is the ID of the mesh to join
MeshId string MeshId string
// IpAddress is a routable IP in another mesh // IpAddress is a routable IP in another mesh
IpAdress string IpAddress string
// IfName is the interface name of the mesh // WgArgs is the WireGuard parameters to use.
IfName string WgArgs WireGuardArgs
// Port is the WireGuard port to expose
Port int
// Endpoint is the routable address of this machine. If not provided
// defaults to the default address
Endpoint string
} }
// PutServiceArgs: args to place a service into the data store
type PutServiceArgs struct {
Service string
Value string
MeshId string
}
// DeleteServiceArgs: args to remove a service from the data store
type DeleteServiceArgs struct {
Service string
MeshId string
}
// PutAliasArgs: args to assign an alias to a node
type PutAliasArgs struct {
// Alias: represents the alias of the node
Alias string
// MeshId: represents the meshID of the node
MeshId string
}
// PutDescriptionArgs: args to assign a description to a node
type PutDescriptionArgs struct {
// Description: descriptio to add to the network
Description string
// MeshID to add to the mesh network
MeshId string
}
// GetMeshReply: ipc reply to get the mesh network
type GetMeshReply struct { type GetMeshReply struct {
Nodes []ctrlserver.MeshNode Nodes []ctrlserver.MeshNode
} }
// ListMeshReply: ipc reply of the networks the node is part of
type ListMeshReply struct { type ListMeshReply struct {
Meshes []string Meshes []string
} }
// Querymesh: ipc args to query a mesh network
type QueryMesh struct { type QueryMesh struct {
// MeshId: id of the mesh to query
MeshId string MeshId string
Query string // JMESPath: query string to query
Query string
} }
type MeshIpc interface { // ClientIpc: Framework to invoke ipc calls to the daemon
type ClientIpc interface {
// CreateMesh: create a mesh network, return an error if the operation failed
CreateMesh(args *NewMeshArgs, reply *string) error CreateMesh(args *NewMeshArgs, reply *string) error
ListMeshes(name string, reply *ListMeshReply) error // ListMesh: list mesh network the node is a part of, return an error if the operation failed
ListMeshes(args *ListMeshReply, reply *string) error
// JoinMesh: join a mesh network return an error if the operation failed
JoinMesh(args JoinMeshArgs, reply *string) error JoinMesh(args JoinMeshArgs, reply *string) error
// LeaveMesh: leave a mesh network, return an error if the operation failed
LeaveMesh(meshId string, reply *string) error LeaveMesh(meshId string, reply *string) error
// GetMesh: get the given mesh network, return an error if the operation failed
GetMesh(meshId string, reply *GetMeshReply) error GetMesh(meshId string, reply *GetMeshReply) error
EnableInterface(meshId string, reply *string) error // Query: query the given mesh network
GetDOT(meshId string, reply *string) error
Query(query QueryMesh, reply *string) error Query(query QueryMesh, reply *string) error
PutDescription(description string, reply *string) error // PutDescription: assign a description to yourself
PutDescription(args PutDescriptionArgs, reply *string) error
// PutAlias: assign an alias to yourself
PutAlias(args PutAliasArgs, reply *string) error
// PutService: assign a service to yourself
PutService(args PutServiceArgs, reply *string) error
// DeleteService: retract a service
DeleteService(args DeleteServiceArgs, reply *string) error
} }
const SockAddr = "/tmp/wgmesh_ipc.sock" type SmegmeshIpc struct {
client *ipcRPC.Client
}
func NewClientIpc() (*SmegmeshIpc, error) {
client, err := ipcRPC.DialHTTP("unix", SockAddr)
if err != nil {
return nil, err
}
return &SmegmeshIpc{
client: client,
}, nil
}
func (c *SmegmeshIpc) CreateMesh(args *NewMeshArgs, reply *string) error {
return c.client.Call("IpcHandler.CreateMesh", args, reply)
}
func (c *SmegmeshIpc) ListMeshes(reply *ListMeshReply) error {
return c.client.Call("IpcHandler.ListMeshes", "", reply)
}
func (c *SmegmeshIpc) JoinMesh(args JoinMeshArgs, reply *string) error {
return c.client.Call("IpcHandler.JoinMesh", &args, reply)
}
func (c *SmegmeshIpc) LeaveMesh(meshId string, reply *string) error {
return c.client.Call("IpcHandler.LeaveMesh", &meshId, reply)
}
func (c *SmegmeshIpc) GetMesh(meshId string, reply *GetMeshReply) error {
return c.client.Call("IpcHandler.GetMesh", &meshId, reply)
}
func (c *SmegmeshIpc) Query(query QueryMesh, reply *string) error {
return c.client.Call("IpcHandler.Query", &query, reply)
}
func (c *SmegmeshIpc) PutDescription(args PutDescriptionArgs, reply *string) error {
return c.client.Call("IpcHandler.PutDescription", &args, reply)
}
func (c *SmegmeshIpc) PutAlias(args PutAliasArgs, reply *string) error {
return c.client.Call("IpcHandler.PutAlias", &args, reply)
}
func (c *SmegmeshIpc) PutService(args PutServiceArgs, reply *string) error {
return c.client.Call("IpcHandler.PutService", &args, reply)
}
func (c *SmegmeshIpc) DeleteService(args DeleteServiceArgs, reply *string) error {
return c.client.Call("IpcHandler.DeleteService", &args, reply)
}
func (c *SmegmeshIpc) Close() error {
return c.client.Close()
}
func RunIpcHandler(server MeshIpc) error { func RunIpcHandler(server MeshIpc) error {
if err := os.RemoveAll(SockAddr); err != nil { if err := os.RemoveAll(SockAddr); err != nil {
return errors.New("Could not find to address") return errors.New("could not find to address")
} }
rpc.Register(server) rpc.Register(server)

View File

@ -1,11 +1,34 @@
package lib package lib
import "cmp"
// MapToSlice converts a map to a slice in go // MapToSlice converts a map to a slice in go
func MapValues[K comparable, V any](m map[K]V) []V { func MapValues[K cmp.Ordered, V any](m map[K]V) []V {
return MapValuesWithExclude(m, map[K]struct{}{}) return MapValuesWithExclude(m, map[K]struct{}{})
} }
func MapValuesWithExclude[K comparable, V any](m map[K]V, exclude map[K]struct{}) []V { type MapItemsEntry[K cmp.Ordered, V any] struct {
Key K
Value V
}
func MapItems[K cmp.Ordered, V any](m map[K]V) []MapItemsEntry[K, V] {
keys := MapKeys(m)
values := MapValues(m)
vs := make([]MapItemsEntry[K, V], len(keys))
for index, _ := range keys {
vs[index] = MapItemsEntry[K, V]{
Key: keys[index],
Value: values[index],
}
}
return vs
}
func MapValuesWithExclude[K cmp.Ordered, V any](m map[K]V, exclude map[K]struct{}) []V {
values := make([]V, len(m)-len(exclude)) values := make([]V, len(m)-len(exclude))
i := 0 i := 0
@ -26,7 +49,7 @@ func MapValuesWithExclude[K comparable, V any](m map[K]V, exclude map[K]struct{}
return values return values
} }
func MapKeys[K comparable, V any](m map[K]V) []K { func MapKeys[K cmp.Ordered, V any](m map[K]V) []K {
values := make([]K, len(m)) values := make([]K, len(m))
i := 0 i := 0
@ -66,3 +89,23 @@ func Filter[V any](list []V, f filterFunc[V]) []V {
return newList return newList
} }
func Contains[V any](list []V, proposition func(V) bool) bool {
for _, elem := range list {
if proposition(elem) {
return true
}
}
return false
}
func Reduce[A any, V any](start A, values []V, reduce func(A, V) A) A {
accum := start
for _, elem := range values {
accum = reduce(accum, elem)
}
return accum
}

48
pkg/lib/hashing.go Normal file
View File

@ -0,0 +1,48 @@
package lib
import (
"hash/fnv"
"sort"
)
type consistentHashRecord[V any] struct {
record V
value int
}
func HashString(value string) int {
f := fnv.New32a()
f.Write([]byte(value))
return int(f.Sum32())
}
// ConsistentHash implementation. Traverse the values until we find a key
// less than ours.
func ConsistentHash[V any, K any](values []V, client K, bucketFunc func(V) int, keyFunc func(K) int) V {
if len(values) == 0 {
panic("values is empty")
}
vs := Map(values, func(v V) consistentHashRecord[V] {
return consistentHashRecord[V]{
v,
bucketFunc(v),
}
})
sort.SliceStable(vs, func(i, j int) bool {
return vs[i].value < vs[j].value
})
ourKey := keyFunc(client)
idx := sort.Search(len(vs), func(i int) bool {
return vs[i].value >= ourKey
})
if idx == len(vs) {
return vs[0].record
}
return vs[idx].record
}

View File

@ -1,6 +1,10 @@
package lib package lib
import "github.com/google/uuid" import (
"github.com/anandvarma/namegen"
"github.com/google/uuid"
"github.com/lithammer/shortuuid"
)
// IdGenerator generates unique ids // IdGenerator generates unique ids
type IdGenerator interface { type IdGenerator interface {
@ -15,3 +19,19 @@ func (g *UUIDGenerator) GetId() (string, error) {
id := uuid.New() id := uuid.New()
return id.String(), nil return id.String(), nil
} }
type ShortIDGenerator struct {
}
func (g *ShortIDGenerator) GetId() (string, error) {
id := shortuuid.New()
return id, nil
}
type IDNameGenerator struct {
}
func (i *IDNameGenerator) GetId() (string, error) {
name_schema := namegen.New()
return name_schema.Get(), nil
}

View File

@ -1,17 +1,61 @@
package lib package lib
import ( import (
"encoding/json"
"io"
"log" "log"
"net" "net"
"net/http"
) )
// GetOutboundIP: gets the oubound IP of this packet // GetOutboundIP: gets the oubound IP of this packet
func GetOutboundIP() net.IP { func GetOutboundIP() (net.IP, error) {
conn, err := net.Dial("udp", "8.8.8.8:80") conn, err := net.Dial("udp", "8.8.8.8:80")
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
defer conn.Close() defer conn.Close()
localAddr := conn.LocalAddr().(*net.UDPAddr) localAddr := conn.LocalAddr().(*net.UDPAddr)
return localAddr.IP return localAddr.IP, nil
}
const IP_SERVICE = "https://api.ipify.org?format=json"
type IpResponse struct {
Ip string `json:"ip"`
}
func (i *IpResponse) GetIP() net.IP {
return net.ParseIP(i.Ip)
}
// GetPublicIP: get the nodes public IP address. For when a node is behind NAT
func GetPublicIP() (net.IP, error) {
req, err := http.NewRequest(http.MethodGet, IP_SERVICE, nil)
if err != nil {
return nil, err
}
res, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
resBody, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
var jsonResponse IpResponse
err = json.Unmarshal([]byte(resBody), &jsonResponse)
if err != nil {
return nil, err
}
return jsonResponse.GetIP(), nil
} }

19
pkg/lib/regex.go Normal file
View File

@ -0,0 +1,19 @@
package lib
import "regexp"
func MatchCaptureGroup(pattern, payload string) map[string]string {
patterns := make(map[string]string)
expr := regexp.MustCompile(pattern)
match := expr.FindStringSubmatch(payload)
for i, name := range expr.SubexpNames() {
if i != 0 && name != "" {
patterns[name] = match[i]
}
}
return patterns
}

View File

@ -6,27 +6,21 @@ import (
"net" "net"
"github.com/jsimonetti/rtnetlink" "github.com/jsimonetti/rtnetlink"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// Maximum MTU to assin to WireGuard
// This isn't configurable
const WIREGUARD_MTU = 1420
// RtNetlinkConfig: represents an rtnetlkink configuration instance
type RtNetlinkConfig struct { type RtNetlinkConfig struct {
// conn: connection to the rtnetlink API
conn *rtnetlink.Conn conn *rtnetlink.Conn
} }
func NewRtNetlinkConfig() (*RtNetlinkConfig, error) { // CreateLink: Create a netlink interface if it does not exist. ifName is the name of the netlink interface
conn, err := rtnetlink.Dial(nil)
if err != nil {
return nil, err
}
return &RtNetlinkConfig{conn: conn}, nil
}
const WIREGUARD_MTU = 1420
// Create a netlink interface if it does not exist. ifName is the name of the netlink interface
func (c *RtNetlinkConfig) CreateLink(ifName string) error { func (c *RtNetlinkConfig) CreateLink(ifName string) error {
_, err := net.InterfaceByName(ifName) _, err := net.InterfaceByName(ifName)
@ -51,7 +45,7 @@ func (c *RtNetlinkConfig) CreateLink(ifName string) error {
return nil return nil
} }
// Delete link delete the specified interface // DeleteLink: delete the specified interface
func (c *RtNetlinkConfig) DeleteLink(ifName string) error { func (c *RtNetlinkConfig) DeleteLink(ifName string) error {
iface, err := net.InterfaceByName(ifName) iface, err := net.InterfaceByName(ifName)
@ -68,7 +62,7 @@ func (c *RtNetlinkConfig) DeleteLink(ifName string) error {
return nil return nil
} }
// AddAddress adds an address to the given interface. // AddAddress: adds an address to the given interface.
func (c *RtNetlinkConfig) AddAddress(ifName string, address string) error { func (c *RtNetlinkConfig) AddAddress(ifName string, address string) error {
iface, err := net.InterfaceByName(ifName) iface, err := net.InterfaceByName(ifName)
@ -140,32 +134,44 @@ func (c *RtNetlinkConfig) AddRoute(ifName string, route Route) error {
family = unix.AF_INET family = unix.AF_INET
} }
attr := rtnetlink.RouteAttributes{ routes, err := c.listRoutes(ifName, family)
Dst: dst.IP,
OutIface: uint32(iface.Index),
Gateway: gw,
}
ones, _ := dst.Mask.Size()
err = c.conn.Route.Replace(&rtnetlink.RouteMessage{
Family: family,
Table: unix.RT_TABLE_MAIN,
Protocol: unix.RTPROT_BOOT,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
DstLength: uint8(ones),
Attributes: attr,
})
if err != nil { if err != nil {
return fmt.Errorf("failed to add route %w", err) return err
}
// If it already exists no need to add the route
if !Contains(routes, func(prevRoute rtnetlink.RouteMessage) bool {
return prevRoute.Attributes.Dst.Equal(route.Destination.IP) &&
prevRoute.Attributes.Gateway.Equal(route.Gateway)
}) {
attr := rtnetlink.RouteAttributes{
Dst: dst.IP,
OutIface: uint32(iface.Index),
Gateway: gw,
}
ones, _ := dst.Mask.Size()
err = c.conn.Route.Replace(&rtnetlink.RouteMessage{
Family: family,
Table: unix.RT_TABLE_MAIN,
Protocol: unix.RTPROT_BOOT,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
DstLength: uint8(ones),
Attributes: attr,
})
if err != nil {
return fmt.Errorf("failed to add route %w", err)
}
} }
return nil return nil
} }
// DeleteRoute deletes routes with the gateway and destination // DeleteRoute: deletes routes with the gateway and destination
func (c *RtNetlinkConfig) DeleteRoute(ifName string, route Route) error { func (c *RtNetlinkConfig) DeleteRoute(ifName string, route Route) error {
iface, err := net.InterfaceByName(ifName) iface, err := net.InterfaceByName(ifName)
@ -201,40 +207,37 @@ func (c *RtNetlinkConfig) DeleteRoute(ifName string, route Route) error {
}) })
if err != nil { if err != nil {
return fmt.Errorf("failed to delete route %w", err) return fmt.Errorf("failed to delete route %s", dst.IP.String())
} }
return nil return nil
} }
// route: represents a rout to add to the RIB
type Route struct { type Route struct {
Gateway net.IP Gateway net.IP
Destination net.IPNet Destination net.IPNet
} }
func (r1 Route) equal(r2 Route) bool { func (r1 Route) equal(r2 Route) bool {
mask1Ones, _ := r1.Destination.Mask.Size()
mask2Ones, _ := r2.Destination.Mask.Size()
return r1.Gateway.String() == r2.Gateway.String() && return r1.Gateway.String() == r2.Gateway.String() &&
r1.Destination.String() == r2.Destination.String() (mask1Ones == 0 && mask2Ones == 0 || r1.Destination.IP.Equal(r2.Destination.IP))
} }
// DeleteRoutes deletes all routes not in exclude // DeleteRoutes: deletes all routes not in exclude on the given interface
func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...Route) error { func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...Route) error {
routes := make([]rtnetlink.RouteMessage, 0) routes, err := c.listRoutes(ifName, family)
if len(exclude) != 0 { if err != nil {
lRoutes, err := c.listRoutes(ifName, family, exclude[0].Gateway) return err
if err != nil {
return err
}
routes = lRoutes
} }
ifRoutes := make([]Route, 0) ifRoutes := make([]Route, 0)
for _, rtRoute := range routes { for _, rtRoute := range routes {
logging.Log.WriteInfof("Routes: %s", rtRoute.Attributes.Dst.String())
maskSize := 128 maskSize := 128
if family == unix.AF_INET { if family == unix.AF_INET {
@ -252,17 +255,18 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
shouldExclude := func(r Route) bool { shouldExclude := func(r Route) bool {
for _, route := range exclude { for _, route := range exclude {
if route.equal(r) { if r.equal(route) {
return false return false
} }
} }
return true return true
} }
toDelete := Filter(ifRoutes, shouldExclude) toDelete := Filter(ifRoutes, shouldExclude)
for _, route := range toDelete { for _, route := range toDelete {
logging.Log.WriteInfof("Deleting route %s", route.Destination.String()) logging.Log.WriteInfof("Deleting route: %s", route.Destination.String())
err := c.DeleteRoute(ifName, route) err := c.DeleteRoute(ifName, route)
if err != nil { if err != nil {
@ -273,8 +277,8 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
return nil return nil
} }
// listRoutes lists all routes on the interface // listRoutes: lists all routes on the interface
func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8, gateway net.IP) ([]rtnetlink.RouteMessage, error) { func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8) ([]rtnetlink.RouteMessage, error) {
iface, err := net.InterfaceByName(ifName) iface, err := net.InterfaceByName(ifName)
if err != nil { if err != nil {
@ -288,13 +292,25 @@ func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8, gateway net.IP
} }
filterFunc := func(r rtnetlink.RouteMessage) bool { filterFunc := func(r rtnetlink.RouteMessage) bool {
return r.Attributes.Gateway.Equal(gateway) && r.Attributes.OutIface == uint32(iface.Index) return r.Attributes.Gateway != nil && r.Attributes.OutIface == uint32(iface.Index)
} }
routes = Filter(routes, filterFunc) routes = Filter(routes, filterFunc)
return routes, nil return routes, nil
} }
// Close: close the Rtnetlink API
func (c *RtNetlinkConfig) Close() error { func (c *RtNetlinkConfig) Close() error {
return c.conn.Close() return c.conn.Close()
} }
// newRtNetlinkConfig: connect to the RtnetlinkAPI
func NewRtNetlinkConfig() (*RtNetlinkConfig, error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return nil, err
}
return &RtNetlinkConfig{conn: conn}, nil
}

View File

@ -2,9 +2,11 @@
package logging package logging
import ( import (
"io"
"os" "os"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/tim-beatham/smegmesh/pkg/conf"
) )
var ( var (
@ -15,6 +17,7 @@ type Logger interface {
WriteInfof(msg string, args ...interface{}) WriteInfof(msg string, args ...interface{})
WriteErrorf(msg string, args ...interface{}) WriteErrorf(msg string, args ...interface{})
WriteWarnf(msg string, args ...interface{}) WriteWarnf(msg string, args ...interface{})
Writer() io.Writer
} }
type LogrusLogger struct { type LogrusLogger struct {
@ -33,17 +36,33 @@ func (l *LogrusLogger) WriteWarnf(msg string, args ...interface{}) {
l.logger.Warnf(msg, args...) l.logger.Warnf(msg, args...)
} }
func NewLogrusLogger() *LogrusLogger { func (l *LogrusLogger) Writer() io.Writer {
return l.logger.Writer()
}
func NewLogrusLogger(confLevel conf.LogLevel) *LogrusLogger {
var level logrus.Level
switch confLevel {
case conf.ERROR:
level = logrus.ErrorLevel
case conf.WARNING:
level = logrus.WarnLevel
case conf.INFO:
level = logrus.InfoLevel
}
logger := logrus.New() logger := logrus.New()
logger.SetFormatter(&logrus.TextFormatter{FullTimestamp: true}) logger.SetFormatter(&logrus.TextFormatter{FullTimestamp: true})
logger.SetOutput(os.Stdout) logger.SetOutput(os.Stdout)
logger.SetLevel(logrus.InfoLevel) logger.SetLevel(level)
return &LogrusLogger{logger: logger} return &LogrusLogger{logger: logger}
} }
func init() { func init() {
SetLogger(NewLogrusLogger()) SetLogger(NewLogrusLogger(conf.INFO))
} }
func SetLogger(l Logger) { func SetLogger(l Logger) {

View File

@ -3,117 +3,506 @@ package mesh
import ( import (
"fmt" "fmt"
"net" "net"
"slices"
"strings"
"time"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/route"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// MeshConfigApplyer abstracts applying the mesh configuration // MeshConfigApplyer abstracts applying the mesh configuration
type MeshConfigApplyer interface { type MeshConfigApplyer interface {
// ApplyConfig: apply the configurtation
ApplyConfig() error ApplyConfig() error
RemovePeers(meshId string) error // SetMeshManager: sets the associated manager
SetMeshManager(manager MeshManager) SetMeshManager(manager MeshManager)
} }
// WgMeshConfigApplyer applies WireGuard configuration // WgMeshConfigApplyer: applies WireGuard configuration
type WgMeshConfigApplyer struct { type WgMeshConfigApplyer struct {
meshManager MeshManager meshManager MeshManager
routeInstaller route.RouteInstaller
hashFunc func(MeshNode) int
} }
func convertMeshNode(node MeshNode) (*wgtypes.PeerConfig, error) { type routeNode struct {
endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint()) gateway string
route Route
}
if err != nil { type convertMeshNodeParams struct {
return nil, err node MeshNode
} self MeshNode
mesh MeshProvider
device *wgtypes.Device
peerToClients map[string][]net.IPNet
routes map[string][]routeNode
}
pubKey, err := node.GetPublicKey() func (m *WgMeshConfigApplyer) convertMeshNode(params convertMeshNodeParams) (*wgtypes.PeerConfig, error) {
pubKey, err := params.node.GetPublicKey()
if err != nil { if err != nil {
return nil, err return nil, err
} }
allowedips := make([]net.IPNet, 1) allowedips := make([]net.IPNet, 1)
allowedips[0] = *node.GetWgHost() allowedips[0] = *params.node.GetWgHost()
for _, route := range node.GetRoutes() { clients, ok := params.peerToClients[pubKey.String()]
_, ipnet, _ := net.ParseCIDR(route)
allowedips = append(allowedips, *ipnet) if ok {
allowedips = append(allowedips, clients...)
}
for _, route := range params.node.GetRoutes() {
bestRoutes := params.routes[route.GetDestination().String()]
var pickedRoute routeNode
if len(bestRoutes) == 1 {
pickedRoute = bestRoutes[0]
} else if len(bestRoutes) > 1 {
bucketFunc := func(rn routeNode) int {
return lib.HashString(rn.gateway)
}
// Else there is more than one candidate so consistently hash
pickedRoute = lib.ConsistentHash(bestRoutes, params.self, bucketFunc, m.hashFunc)
}
if pickedRoute.gateway == pubKey.String() {
allowedips = append(allowedips, *pickedRoute.route.GetDestination())
}
}
config := params.mesh.GetConfiguration()
var keepAlive time.Duration = time.Duration(0)
if config.KeepAliveWg != nil {
keepAlive = time.Duration(*config.KeepAliveWg) * time.Second
}
existing := slices.IndexFunc(params.device.Peers, func(p wgtypes.Peer) bool {
pubKey, _ := params.node.GetPublicKey()
return p.PublicKey.String() == pubKey.String()
})
var endpoint *net.UDPAddr = nil
if params.node.GetType() == conf.PEER_ROLE {
endpoint, err = net.ResolveUDPAddr("udp", params.node.GetWgEndpoint())
}
if err != nil {
return nil, err
}
// Don't override the existing IP in case it already exists
if existing != -1 {
endpoint = params.device.Peers[existing].Endpoint
} }
peerConfig := wgtypes.PeerConfig{ peerConfig := wgtypes.PeerConfig{
PublicKey: pubKey, PublicKey: pubKey,
Endpoint: endpoint, Endpoint: endpoint,
AllowedIPs: allowedips, AllowedIPs: allowedips,
PersistentKeepaliveInterval: &keepAlive,
ReplaceAllowedIPs: true,
} }
return &peerConfig, nil return &peerConfig, nil
} }
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { // getRoutes: finds the routes with the least hop distance. If more than one route exists
snap, err := mesh.GetMesh() // consistently hash to evenly spread the distribution of traffic
func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) (map[string][]routeNode, error) {
mesh, err := meshProvider.GetMesh()
if err != nil { if err != nil {
return err return nil, err
} }
nodes := snap.GetNodes() routes := make(map[string][]routeNode)
peerConfigs := make([]wgtypes.PeerConfig, len(nodes))
var count int = 0 peers := lib.Filter(lib.MapValues(mesh.GetNodes()), func(p MeshNode) bool {
return p.GetType() == conf.PEER_ROLE
})
for _, n := range nodes { meshPrefixes := lib.Map(lib.MapValues(m.meshManager.GetMeshes()), func(mesh MeshProvider) *net.IPNet {
peer, err := convertMeshNode(n) ula := &ip.ULABuilder{}
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
return ipNet
})
if err != nil { for _, node := range mesh.GetNodes() {
return err pubKey, _ := node.GetPublicKey()
for _, route := range node.GetRoutes() {
if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool {
if prefix.IP.Equal(net.IPv6zero) && *meshProvider.GetConfiguration().AdvertiseDefaultRoute {
return true
}
return prefix.Contains(route.GetDestination().IP)
}) {
continue
}
destination := route.GetDestination().String()
otherRoute, ok := routes[destination]
rn := routeNode{
gateway: pubKey.String(),
route: route,
}
// Client's only acessible by another peer
if node.GetType() == conf.CLIENT_ROLE {
peer := m.getCorrespondingPeer(peers, node)
self, err := meshProvider.GetNode(m.meshManager.GetPublicKey().String())
if err != nil {
return nil, err
}
if !NodeEquals(peer, self) {
peerPub, _ := peer.GetPublicKey()
rn.gateway = peerPub.String()
rn.route = &RouteStub{
Destination: rn.route.GetDestination(),
Path: append(rn.route.GetPath(), peer.GetWgHost().IP.String()),
}
}
}
if !ok {
otherRoute = make([]routeNode, 1)
otherRoute[0] = rn
routes[destination] = otherRoute
} else if route.GetHopCount() < otherRoute[0].route.GetHopCount() {
otherRoute[0] = rn
} else if otherRoute[0].route.GetHopCount() == route.GetHopCount() {
routes[destination] = append(otherRoute, rn)
}
}
}
return routes, nil
}
// getCorrespondignPeer: gets the peer corresponding to the client
func (m *WgMeshConfigApplyer) getCorrespondingPeer(peers []MeshNode, client MeshNode) MeshNode {
peer := lib.ConsistentHash(peers, client, m.hashFunc, m.hashFunc)
return peer
}
// getPeerCfgsToRemove: remove peer configurations that are no longer in the mesh
func (m *WgMeshConfigApplyer) getPeerCfgsToRemove(dev *wgtypes.Device, newPeers []wgtypes.PeerConfig) []wgtypes.PeerConfig {
peers := dev.Peers
peers = lib.Filter(peers, func(p1 wgtypes.Peer) bool {
return !lib.Contains(newPeers, func(p2 wgtypes.PeerConfig) bool {
return p1.PublicKey.String() == p2.PublicKey.String()
})
})
return lib.Map(peers, func(p wgtypes.Peer) wgtypes.PeerConfig {
return wgtypes.PeerConfig{
PublicKey: p.PublicKey,
Remove: true,
}
})
}
type GetConfigParams struct {
mesh MeshProvider
peers []MeshNode
clients []MeshNode
dev *wgtypes.Device
routes map[string][]routeNode
}
// getClientConfig: if the node is a client get their configuration
func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes.Config, error) {
ula := &ip.ULABuilder{}
meshNet, _ := ula.GetIPNet(params.mesh.GetMeshId())
routesForMesh := lib.Map(lib.MapValues(params.routes), func(rns []routeNode) []routeNode {
return lib.Filter(rns, func(rn routeNode) bool {
node, err := params.mesh.GetNode(rn.gateway)
return node != nil && err == nil
})
})
routesForMesh = lib.Filter(routesForMesh, func(rns []routeNode) bool {
return len(rns) != 0
})
routes := lib.Map(routesForMesh, func(rs []routeNode) net.IPNet {
return *rs[0].route.GetDestination()
})
routes = append(routes, *meshNet)
self, err := params.mesh.GetNode(m.meshManager.GetPublicKey().String())
if err != nil {
return nil, err
}
if len(params.peers) == 0 {
return nil, fmt.Errorf("no peers in the mesh")
}
peer := m.getCorrespondingPeer(params.peers, self)
pubKey, _ := peer.GetPublicKey()
config := params.mesh.GetConfiguration()
keepAlive := time.Duration(*config.KeepAliveWg) * time.Second
endpoint, err := net.ResolveUDPAddr("udp", peer.GetWgEndpoint())
if err != nil {
return nil, err
}
peerCfgs := make([]wgtypes.PeerConfig, 1)
peerCfgs[0] = wgtypes.PeerConfig{
PublicKey: pubKey,
Endpoint: endpoint,
PersistentKeepaliveInterval: &keepAlive,
AllowedIPs: routes,
ReplaceAllowedIPs: true,
}
installedRoutes := make([]lib.Route, 0)
for _, route := range peerCfgs[0].AllowedIPs {
// Don't install routes that we are directly apart
// Dont install default route wgctrl handles this for us
if !meshNet.Contains(route.IP) {
installedRoutes = append(installedRoutes, lib.Route{
Gateway: peer.GetWgHost().IP,
Destination: route,
})
}
}
cfg := wgtypes.Config{
Peers: peerCfgs,
}
m.routeInstaller.InstallRoutes(params.dev.Name, installedRoutes...)
return &cfg, err
}
// getRoutesToInstall: work out if the given node is advertising routes that should be installed into the
// RIB
func (m *WgMeshConfigApplyer) getRoutesToInstall(wgNode *wgtypes.PeerConfig, mesh MeshProvider, node MeshNode) []lib.Route {
routes := make([]lib.Route, 0)
for _, route := range wgNode.AllowedIPs {
ula := &ip.ULABuilder{}
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
// Check there is no overlap in network and its not the default route
if !ipNet.Contains(route.IP) {
routes = append(routes, lib.Route{
Gateway: node.GetWgHost().IP,
Destination: route,
})
}
}
return routes
}
// getPeerConfig: creates the WireGuard configuration for a peer
func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.Config, error) {
peerToClients := make(map[string][]net.IPNet)
installedRoutes := make([]lib.Route, 0)
peerConfigs := make([]wgtypes.PeerConfig, 0)
self, err := params.mesh.GetNode(m.meshManager.GetPublicKey().String())
if err != nil {
return nil, err
}
for _, n := range params.clients {
if len(params.peers) > 0 {
peer := m.getCorrespondingPeer(params.peers, n)
pubKey, _ := peer.GetPublicKey()
clients, ok := peerToClients[pubKey.String()]
if !ok {
clients = make([]net.IPNet, 0)
peerToClients[pubKey.String()] = clients
}
peerToClients[pubKey.String()] = append(clients, *n.GetWgHost())
if NodeEquals(self, peer) {
cfg, err := m.convertMeshNode(convertMeshNodeParams{
node: n,
self: self,
mesh: params.mesh,
device: params.dev,
peerToClients: peerToClients,
routes: params.routes,
})
if err != nil {
return nil, err
}
installedRoutes = append(installedRoutes, m.getRoutesToInstall(cfg, params.mesh, n)...)
peerConfigs = append(peerConfigs, *cfg)
}
}
}
for _, n := range params.peers {
if NodeEquals(n, self) {
continue
} }
peerConfigs[count] = *peer peer, err := m.convertMeshNode(convertMeshNodeParams{
count++ node: n,
self: self,
mesh: params.mesh,
peerToClients: peerToClients,
routes: params.routes,
device: params.dev,
})
if err != nil {
return nil, err
}
installedRoutes = append(installedRoutes, m.getRoutesToInstall(peer, params.mesh, n)...)
peerConfigs = append(peerConfigs, *peer)
} }
cfg := wgtypes.Config{ cfg := wgtypes.Config{
Peers: peerConfigs, Peers: peerConfigs,
} }
dev, err := mesh.GetDevice() err = m.routeInstaller.InstallRoutes(params.dev.Name, installedRoutes...)
return &cfg, err
}
// updateWgConf: update the WireGuard configuration
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string][]routeNode) error {
snap, err := mesh.GetMesh()
if err != nil { if err != nil {
return err return err
} }
return m.meshManager.GetClient().ConfigureDevice(dev.Name, cfg) nodes := lib.MapValues(snap.GetNodes())
} dev, _ := mesh.GetDevice()
func (m *WgMeshConfigApplyer) ApplyConfig() error { slices.SortFunc(nodes, func(a, b MeshNode) int {
for _, mesh := range m.meshManager.GetMeshes() { return strings.Compare(string(a.GetType()), string(b.GetType()))
err := m.updateWgConf(mesh) })
if err != nil { peers := lib.Filter(nodes, func(mn MeshNode) bool {
return err return mn.GetType() == conf.PEER_ROLE
} })
clients := lib.Filter(nodes, func(mn MeshNode) bool {
return mn.GetType() == conf.CLIENT_ROLE
})
self, err := mesh.GetNode(m.meshManager.GetPublicKey().String())
if err != nil {
return err
}
var cfg *wgtypes.Config = nil
configParams := &GetConfigParams{
mesh: mesh,
peers: peers,
clients: clients,
dev: dev,
routes: routes,
}
switch self.GetType() {
case conf.PEER_ROLE:
cfg, err = m.getPeerConfig(configParams)
case conf.CLIENT_ROLE:
cfg, err = m.getClientConfig(configParams)
}
if err != nil {
return err
}
toRemove := m.getPeerCfgsToRemove(dev, cfg.Peers)
cfg.Peers = append(cfg.Peers, toRemove...)
err = m.meshManager.GetClient().ConfigureDevice(dev.Name, *cfg)
if err != nil {
return err
} }
return nil return nil
} }
func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error { // getAllRoutes: works out all the routes to install out of all the routes in the
mesh := m.meshManager.GetMesh(meshId) // set of networks the node is a part of
func (m *WgMeshConfigApplyer) getAllRoutes() (map[string][]routeNode, error) {
allRoutes := make(map[string][]routeNode)
if mesh == nil { for _, mesh := range m.meshManager.GetMeshes() {
return fmt.Errorf("mesh %s does not exist", meshId) routes, err := m.getRoutes(mesh)
if err != nil {
return nil, err
}
for destination, route := range routes {
_, ok := allRoutes[destination]
if !ok {
allRoutes[destination] = route
continue
}
if allRoutes[destination][0].route.GetHopCount() == route[0].route.GetHopCount() {
allRoutes[destination] = append(allRoutes[destination], route...)
} else if route[0].route.GetHopCount() < allRoutes[destination][0].route.GetHopCount() {
allRoutes[destination] = route
}
}
} }
dev, err := mesh.GetDevice() return allRoutes, nil
}
// ApplyConfig: apply the WireGuard configuration
func (m *WgMeshConfigApplyer) ApplyConfig() error {
allRoutes, err := m.getAllRoutes()
if err != nil { if err != nil {
return err return err
} }
m.meshManager.GetClient().ConfigureDevice(dev.Name, wgtypes.Config{ for _, mesh := range m.meshManager.GetMeshes() {
ReplacePeers: true, err := m.updateWgConf(mesh, allRoutes)
Peers: make([]wgtypes.PeerConfig, 1),
}) if err != nil {
return err
}
}
return nil return nil
} }
@ -123,5 +512,11 @@ func (m *WgMeshConfigApplyer) SetMeshManager(manager MeshManager) {
} }
func NewWgMeshConfigApplyer() MeshConfigApplyer { func NewWgMeshConfigApplyer() MeshConfigApplyer {
return &WgMeshConfigApplyer{} return &WgMeshConfigApplyer{
routeInstaller: route.NewRouteInstaller(),
hashFunc: func(mn MeshNode) int {
pubKey, _ := mn.GetPublicKey()
return lib.HashString(pubKey.String())
},
}
} }

View File

@ -1,77 +0,0 @@
package mesh
import (
"errors"
"fmt"
"github.com/tim-beatham/wgmesh/pkg/graph"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
// MeshGraphConverter converts a mesh to a graph
type MeshGraphConverter interface {
// convert the mesh to textual form
Generate(meshId string) (string, error)
}
type MeshDOTConverter struct {
manager MeshManager
}
func (c *MeshDOTConverter) Generate(meshId string) (string, error) {
mesh := c.manager.GetMesh(meshId)
if mesh == nil {
return "", errors.New("mesh does not exist")
}
g := graph.NewGraph(meshId, graph.GRAPH)
snapshot, err := mesh.GetMesh()
if err != nil {
return "", err
}
for _, node := range snapshot.GetNodes() {
c.graphNode(g, node, meshId)
}
nodes := lib.MapValues(snapshot.GetNodes())
for i, node1 := range nodes[:len(nodes)-1] {
for _, node2 := range nodes[i+1:] {
if node1.GetWgEndpoint() == node2.GetWgEndpoint() {
continue
}
node1Id := fmt.Sprintf("\"%s\"", node1.GetIdentifier())
node2Id := fmt.Sprintf("\"%s\"", node2.GetIdentifier())
g.AddEdge(fmt.Sprintf("%s to %s", node1Id, node2Id), node1Id, node2Id)
}
}
return g.GetDOT()
}
// graphNode: graphs a node within the mesh
func (c *MeshDOTConverter) graphNode(g *graph.Graph, node MeshNode, meshId string) {
nodeId := fmt.Sprintf("\"%s\"", node.GetIdentifier())
g.PutNode(nodeId, graph.CIRCLE)
self, _ := c.manager.GetSelf(meshId)
if node.GetHostEndpoint() == self.GetHostEndpoint() {
return
}
for _, route := range node.GetRoutes() {
routeId := fmt.Sprintf("\"%s\"", route)
g.PutNode(routeId, graph.HEXAGON)
g.AddEdge(fmt.Sprintf("%s to %s", nodeId, routeId), nodeId, routeId)
}
}
func NewMeshDotConverter(m MeshManager) MeshGraphConverter {
return &MeshDOTConverter{manager: m}
}

View File

@ -3,114 +3,228 @@ package mesh
import ( import (
"errors" "errors"
"fmt" "fmt"
"net"
"sync"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/cmd"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/ip"
logging "github.com/tim-beatham/wgmesh/pkg/log" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/wg" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/wg"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// MeshManager: abstracts maanging meshes, including installing the WireGuard configuration
// to the device, and adding and removing nodes
type MeshManager interface { type MeshManager interface {
CreateMesh(devName string, port int) (string, error) CreateMesh(params *CreateMeshParams) (string, error)
AddMesh(params *AddMeshParams) error AddMesh(params *AddMeshParams) error
HasChanges(meshid string) bool HasChanges(meshid string) bool
GetMesh(meshId string) MeshProvider GetMesh(meshId string) MeshProvider
EnableInterface(meshId string) error GetPublicKey() *wgtypes.Key
GetPublicKey(meshId string) (*wgtypes.Key, error)
AddSelf(params *AddSelfParams) error AddSelf(params *AddSelfParams) error
LeaveMesh(meshId string) error LeaveMesh(meshId string) error
GetSelf(meshId string) (MeshNode, error) GetSelf(meshId string) (MeshNode, error)
ApplyConfig() error ApplyConfig() error
SetDescription(description string) error SetDescription(meshId, description string) error
SetAlias(meshId, alias string) error
SetService(meshId, service, value string) error
RemoveService(meshId, service string) error
UpdateTimeStamp() error UpdateTimeStamp() error
GetClient() *wgctrl.Client GetClient() *wgctrl.Client
GetMeshes() map[string]MeshProvider GetMeshes() map[string]MeshProvider
Prune() error
Close() error Close() error
GetNode(string, string) MeshNode
GetRouteManager() RouteManager
} }
type MeshManagerImpl struct { type MeshManagerImpl struct {
Meshes map[string]MeshProvider meshLock sync.RWMutex
RouteManager RouteManager meshes map[string]MeshProvider
Client *wgctrl.Client RouteManager RouteManager
// HostParameters contains information that uniquely locates Client *wgctrl.Client
// the node in the mesh network.
HostParameters *HostParameters HostParameters *HostParameters
conf *conf.WgMeshConfiguration conf *conf.DaemonConfiguration
meshProviderFactory MeshProviderFactory meshProviderFactory MeshProviderFactory
nodeFactory MeshNodeFactory nodeFactory MeshNodeFactory
configApplyer MeshConfigApplyer configApplyer MeshConfigApplyer
idGenerator lib.IdGenerator idGenerator lib.IdGenerator
ipAllocator ip.IPAllocator ipAllocator ip.IPAllocator
interfaceManipulator wg.WgInterfaceManipulator interfaceManipulator wg.WgInterfaceManipulator
cmdRunner cmd.CmdRunner
OnDelete func(MeshProvider)
} }
// Prune implements MeshManager. func (m *MeshManagerImpl) GetRouteManager() RouteManager {
func (m *MeshManagerImpl) Prune() error { return m.RouteManager
for _, mesh := range m.Meshes { }
err := mesh.Prune(m.conf.PruneTime)
// RemoveService: remove a service from the given mesh.
func (m *MeshManagerImpl) RemoveService(meshId, service string) error {
mesh := m.GetMesh(meshId)
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
if !mesh.NodeExists(m.HostParameters.GetPublicKey()) {
return fmt.Errorf("node %s does not exist in the mesh", meshId)
}
return mesh.RemoveService(m.HostParameters.GetPublicKey(), service)
}
// SetService: add a service to the given mesh
func (m *MeshManagerImpl) SetService(meshId, service, value string) error {
mesh := m.GetMesh(meshId)
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
if !mesh.NodeExists(m.HostParameters.GetPublicKey()) {
return fmt.Errorf("node %s does not exist in the mesh", meshId)
}
return mesh.AddService(m.HostParameters.GetPublicKey(), service, value)
}
// GetNode: gets the node with given id in the mesh network
func (m *MeshManagerImpl) GetNode(meshid, nodeId string) MeshNode {
mesh, ok := m.meshes[meshid]
if !ok {
return nil
}
node, err := mesh.GetNode(nodeId)
if err != nil {
return nil
}
return node
}
// CreateMeshParams contains the parameters required to create a mesh
type CreateMeshParams struct {
Port int
Conf *conf.WgConfiguration
}
// getConf: gets the new configuration with the base configuration overriden
// from the recent
func (m *MeshManagerImpl) getConf(override *conf.WgConfiguration) (*conf.WgConfiguration, error) {
meshConfiguration := m.conf.BaseConfiguration
if override != nil {
newConf, err := conf.MergeMeshConfiguration(meshConfiguration, *override)
if err != nil {
return nil, err
}
meshConfiguration = newConf
}
return &meshConfiguration, nil
}
// CreateMesh: Creates a new mesh, stores it and returns the mesh id
func (m *MeshManagerImpl) CreateMesh(args *CreateMeshParams) (string, error) {
meshConfiguration, err := m.getConf(args.Conf)
if err != nil {
return "", err
}
if *meshConfiguration.Role == conf.CLIENT_ROLE {
return "", fmt.Errorf("cannot create mesh as a client")
}
meshId, err := m.idGenerator.GetId()
var ifName string = ""
if err != nil {
return "", err
}
m.cmdRunner.RunCommands(m.conf.BaseConfiguration.PreUp...)
if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(args.Port, m.HostParameters.PrivateKey)
if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err)
}
}
nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: ifName,
Port: args.Port,
Conf: meshConfiguration,
Client: m.Client,
MeshId: meshId,
DaemonConf: m.conf,
NodeID: m.HostParameters.GetPublicKey(),
})
if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err)
}
m.meshLock.Lock()
m.meshes[meshId] = nodeManager
m.meshLock.Unlock()
m.cmdRunner.RunCommands(m.conf.BaseConfiguration.PostUp...)
return meshId, nil
}
type AddMeshParams struct {
MeshId string
WgPort int
MeshBytes []byte
Conf *conf.WgConfiguration
}
// AddMesh: Add a new mesh network to the list of addresses
func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
var ifName string
var err error
meshConfiguration, err := m.getConf(params.Conf)
if err != nil {
return err
}
m.cmdRunner.RunCommands(meshConfiguration.PreUp...)
if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(params.WgPort, m.HostParameters.PrivateKey)
if err != nil { if err != nil {
return err return err
} }
} }
return nil
}
// CreateMesh: Creates a new mesh, stores it and returns the mesh id
func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) {
meshId, err := m.idGenerator.GetId()
if err != nil {
return "", err
}
nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: devName,
Port: port,
Conf: m.conf,
Client: m.Client,
MeshId: meshId,
})
if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err)
}
err = m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: devName,
Port: port,
})
if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err)
}
m.Meshes[meshId] = nodeManager
return meshId, nil
}
type AddMeshParams struct {
MeshId string
DevName string
WgPort int
MeshBytes []byte
}
// AddMesh: Add the mesh to the list of meshes
func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{ meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: params.DevName, DevName: ifName,
Port: params.WgPort, Port: params.WgPort,
Conf: m.conf, Conf: meshConfiguration,
Client: m.Client, Client: m.Client,
MeshId: params.MeshId, MeshId: params.MeshId,
DaemonConf: m.conf,
NodeID: m.HostParameters.GetPublicKey(),
}) })
m.cmdRunner.RunCommands(meshConfiguration.PostUp...)
if err != nil { if err != nil {
return err return err
} }
@ -121,70 +235,41 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
return err return err
} }
m.Meshes[params.MeshId] = meshProvider m.meshLock.Lock()
m.meshes[params.MeshId] = meshProvider
return m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{ m.meshLock.Unlock()
IfName: params.DevName,
Port: params.WgPort,
})
}
// HasChanges returns true if the mesh has changes
func (m *MeshManagerImpl) HasChanges(meshId string) bool {
return m.Meshes[meshId].HasChanges()
}
// GetMesh returns the mesh with the given meshid
func (m *MeshManagerImpl) GetMesh(meshId string) MeshProvider {
theMesh := m.Meshes[meshId]
return theMesh
}
// EnableInterface: Enables the given WireGuard interface.
func (s *MeshManagerImpl) EnableInterface(meshId string) error {
err := s.configApplyer.ApplyConfig()
if err != nil {
return err
}
err = s.RouteManager.InstallRoutes()
if err != nil {
return err
}
return nil return nil
} }
// GetPublicKey: Gets the public key of the WireGuard mesh // HasChanges: returns true if the mesh has changes
func (s *MeshManagerImpl) GetPublicKey(meshId string) (*wgtypes.Key, error) { func (m *MeshManagerImpl) HasChanges(meshId string) bool {
mesh, ok := s.Meshes[meshId] return m.meshes[meshId].HasChanges()
if !ok {
return nil, errors.New("mesh does not exist")
}
dev, err := mesh.GetDevice()
if err != nil {
return nil, err
}
return &dev.PublicKey, nil
} }
// GetMesh: returns the mesh with the given meshid
func (m *MeshManagerImpl) GetMesh(meshId string) MeshProvider {
theMesh := m.meshes[meshId]
return theMesh
}
// GetPublicKey: Gets the public key of the WireGuard mesh
func (s *MeshManagerImpl) GetPublicKey() *wgtypes.Key {
key := s.HostParameters.PrivateKey.PublicKey()
return &key
}
// AddSelfParams: parameters required to add yourself to a mesh
// network
type AddSelfParams struct { type AddSelfParams struct {
// MeshId is the ID of the mesh to add this instance to // MeshId is the ID of the mesh to add this instance to
MeshId string MeshId string
// WgPort is the WireGuard port to advertise // WgPort is the WireGuard port to advertise
WgPort int WgPort int
// Endpoint is the alias of the machine to send routable packets // Endpoint is the alias of the machine to send routable packets
// to
Endpoint string Endpoint string
} }
// AddSelf adds this host to the mesh // AddSelf: adds this host to the mesh
func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error { func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
mesh := s.GetMesh(params.MeshId) mesh := s.GetMesh(params.MeshId)
@ -192,123 +277,179 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
return fmt.Errorf("addself: mesh %s does not exist", params.MeshId) return fmt.Errorf("addself: mesh %s does not exist", params.MeshId)
} }
pubKey, err := s.GetPublicKey(params.MeshId) if params.WgPort == 0 && !s.conf.StubWg {
device, err := mesh.GetDevice()
if err != nil {
return err
}
nodeIP, err := s.ipAllocator.GetIP(*pubKey, params.MeshId)
if err != nil {
return err
}
node := s.nodeFactory.Build(&MeshNodeFactoryParams{
PublicKey: pubKey,
NodeIP: nodeIP,
WgPort: params.WgPort,
Endpoint: params.Endpoint,
})
device, err := mesh.GetDevice()
if err != nil {
return fmt.Errorf("failed to get device %w", err)
}
err = s.interfaceManipulator.AddAddress(device.Name, fmt.Sprintf("%s/64", nodeIP))
if err != nil {
return fmt.Errorf("addSelf: failed to add address to dev %w", err)
}
s.Meshes[params.MeshId].AddNode(node)
return s.RouteManager.UpdateRoutes()
}
// LeaveMesh leaves the mesh network
func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
mesh, exists := s.Meshes[meshId]
if !exists {
return fmt.Errorf("mesh %s does not exist", meshId)
}
err := s.RouteManager.RemoveRoutes(meshId)
if err != nil {
return err
}
device, err := mesh.GetDevice()
if err != nil {
return err
}
err = s.interfaceManipulator.RemoveInterface(device.Name)
delete(s.Meshes, meshId)
return err
}
func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
meshInstance, ok := s.Meshes[meshId]
if !ok {
return nil, fmt.Errorf("mesh %s does not exist", meshId)
}
snapshot, err := meshInstance.GetMesh()
if err != nil {
return nil, err
}
node, ok := snapshot.GetNodes()[s.HostParameters.HostEndpoint]
if !ok {
return nil, errors.New("the node doesn't exist in the mesh")
}
return node, nil
}
func (s *MeshManagerImpl) ApplyConfig() error {
err := s.configApplyer.ApplyConfig()
if err != nil {
return err
}
return s.RouteManager.InstallRoutes()
}
func (s *MeshManagerImpl) SetDescription(description string) error {
for _, mesh := range s.Meshes {
err := mesh.SetDescription(s.HostParameters.HostEndpoint, description)
if err != nil { if err != nil {
return err return err
} }
params.WgPort = device.ListenPort
} }
return nil pubKey := s.HostParameters.PrivateKey.PublicKey()
}
collisionCount := uint8(0)
var nodeIP net.IP
// Perform Duplicate Address Detection with the nodes
// that are already in the network
for {
generatedIP, err := s.ipAllocator.GetIP(pubKey, params.MeshId, collisionCount)
if err != nil {
return err
}
// UpdateTimeStamp updates the timestamp of this node in all meshes
func (s *MeshManagerImpl) UpdateTimeStamp() error {
for _, mesh := range s.Meshes {
snapshot, err := mesh.GetMesh() snapshot, err := mesh.GetMesh()
if err != nil { if err != nil {
return err return err
} }
_, exists := snapshot.GetNodes()[s.HostParameters.HostEndpoint] proposition := func(node MeshNode) bool {
ipNet := node.GetWgHost()
return ipNet.IP.Equal(nodeIP)
}
if exists { if lib.Contains(lib.MapValues(snapshot.GetNodes()), proposition) {
err = mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint) collisionCount++
} else {
nodeIP = generatedIP
break
}
}
node := s.nodeFactory.Build(&MeshNodeFactoryParams{
PublicKey: &pubKey,
NodeIP: nodeIP,
WgPort: params.WgPort,
Endpoint: params.Endpoint,
MeshConfig: mesh.GetConfiguration(),
})
if !s.conf.StubWg {
device, err := mesh.GetDevice()
if err != nil {
return fmt.Errorf("failed to get device %w", err)
}
err = s.interfaceManipulator.AddAddress(device.Name, fmt.Sprintf("%s/64", nodeIP))
if err != nil {
return fmt.Errorf("addSelf: failed to add address to dev %w", err)
}
}
s.meshes[params.MeshId].AddNode(node)
return nil
}
// LeaveMesh: leaves the mesh network and force a synchronsiation
func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
mesh := s.GetMesh(meshId)
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
err := mesh.RemoveNode(s.HostParameters.GetPublicKey())
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
if s.OnDelete != nil {
s.OnDelete(mesh)
}
s.meshLock.Lock()
delete(s.meshes, meshId)
s.meshLock.Unlock()
s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PreDown...)
if !s.conf.StubWg {
device, err := mesh.GetDevice()
if err != nil {
return err
}
err = s.interfaceManipulator.RemoveInterface(device.Name)
if err != nil {
return err
}
}
s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PostDown...)
return err
}
func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
meshInstance, ok := s.meshes[meshId]
if !ok {
return nil, fmt.Errorf("mesh %s does not exist", meshId)
}
node, err := meshInstance.GetNode(s.HostParameters.GetPublicKey())
if err != nil {
return nil, errors.New("the node doesn't exist in the mesh")
}
return node, nil
}
// ApplyConfig: applies the WireGuard configuration
// adds routes to the RIB and so forth.
func (s *MeshManagerImpl) ApplyConfig() error {
if s.conf.StubWg {
return nil
}
return s.configApplyer.ApplyConfig()
}
func (s *MeshManagerImpl) SetDescription(meshId, description string) error {
mesh := s.GetMesh(meshId)
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
if !mesh.NodeExists(s.HostParameters.GetPublicKey()) {
return fmt.Errorf("node %s does not exist in the mesh", meshId)
}
return mesh.SetDescription(s.HostParameters.GetPublicKey(), description)
}
// SetAlias sets the alias of the node for the given meshid
func (s *MeshManagerImpl) SetAlias(meshId, alias string) error {
mesh := s.GetMesh(meshId)
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
if !mesh.NodeExists(s.HostParameters.GetPublicKey()) {
return fmt.Errorf("node %s does not exist in the mesh", meshId)
}
return mesh.SetAlias(s.HostParameters.GetPublicKey(), alias)
}
// UpdateTimeStamp: updates the timestamp of this node in all meshes
// essentially performs heartbeat if the node is the leader
func (s *MeshManagerImpl) UpdateTimeStamp() error {
meshes := s.GetMeshes()
for _, mesh := range meshes {
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.UpdateTimeStamp(s.HostParameters.GetPublicKey())
if err != nil { if err != nil {
return err return err
@ -323,12 +464,30 @@ func (s *MeshManagerImpl) GetClient() *wgctrl.Client {
return s.Client return s.Client
} }
// GetMeshes: get all meshes the node is part of
func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider { func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider {
return s.Meshes meshes := make(map[string]MeshProvider)
// GetMesh: copies the map of meshes to a new map
// to prevent a whole range of concurrency issues
// due to iteration and modification
s.meshLock.RLock()
for id, mesh := range s.meshes {
meshes[id] = mesh
}
s.meshLock.RUnlock()
return meshes
} }
// Close: close the mesh manager
func (s *MeshManagerImpl) Close() error { func (s *MeshManagerImpl) Close() error {
for _, mesh := range s.Meshes { if s.conf.StubWg {
return nil
}
for _, mesh := range s.meshes {
dev, err := mesh.GetDevice() dev, err := mesh.GetDevice()
if err != nil { if err != nil {
@ -345,9 +504,9 @@ func (s *MeshManagerImpl) Close() error {
return nil return nil
} }
// NewMeshManagerParams params required to create an instance of a mesh manager // NewMeshManagerParams: params required to create an instance of a mesh manager
type NewMeshManagerParams struct { type NewMeshManagerParams struct {
Conf conf.WgMeshConfiguration Conf conf.DaemonConfiguration
Client *wgctrl.Client Client *wgctrl.Client
MeshProvider MeshProviderFactory MeshProvider MeshProviderFactory
NodeFactory MeshNodeFactory NodeFactory MeshNodeFactory
@ -356,23 +515,19 @@ type NewMeshManagerParams struct {
InterfaceManipulator wg.WgInterfaceManipulator InterfaceManipulator wg.WgInterfaceManipulator
ConfigApplyer MeshConfigApplyer ConfigApplyer MeshConfigApplyer
RouteManager RouteManager RouteManager RouteManager
CommandRunner cmd.CmdRunner
OnDelete func(MeshProvider)
} }
// Creates a new instance of a mesh manager with the given parameters // NewMeshManager: Creates a new instance of a mesh manager with the given parameters
func NewMeshManager(params *NewMeshManagerParams) *MeshManagerImpl { func NewMeshManager(params *NewMeshManagerParams) MeshManager {
hostParams := HostParameters{} privateKey, _ := wgtypes.GeneratePrivateKey()
hostParams := HostParameters{
switch params.Conf.Endpoint { PrivateKey: &privateKey,
case "":
hostParams.HostEndpoint = fmt.Sprintf("%s:%s", lib.GetOutboundIP().String(), params.Conf.GrpcPort)
default:
hostParams.HostEndpoint = fmt.Sprintf("%s:%s", params.Conf.Endpoint, params.Conf.GrpcPort)
} }
logging.Log.WriteInfof("Endpoint %s", hostParams.HostEndpoint)
m := &MeshManagerImpl{ m := &MeshManagerImpl{
Meshes: make(map[string]MeshProvider), meshes: make(map[string]MeshProvider),
HostParameters: &hostParams, HostParameters: &hostParams,
meshProviderFactory: params.MeshProvider, meshProviderFactory: params.MeshProvider,
nodeFactory: params.NodeFactory, nodeFactory: params.NodeFactory,
@ -387,8 +542,14 @@ func NewMeshManager(params *NewMeshManagerParams) *MeshManagerImpl {
m.RouteManager = NewRouteManager(m) m.RouteManager = NewRouteManager(m)
} }
if params.CommandRunner == nil {
m.cmdRunner = &cmd.UnixCmdRunner{}
}
m.idGenerator = params.IdGenerator m.idGenerator = params.IdGenerator
m.ipAllocator = params.IPAllocator m.ipAllocator = params.IPAllocator
m.interfaceManipulator = params.InterfaceManipulator m.interfaceManipulator = params.InterfaceManipulator
m.OnDelete = params.OnDelete
return m return m
} }

View File

@ -3,26 +3,43 @@ package mesh
import ( import (
"testing" "testing"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/wg" "github.com/tim-beatham/smegmesh/pkg/wg"
) )
func getMeshConfiguration() *conf.WgMeshConfiguration { func getMeshConfiguration() *conf.DaemonConfiguration {
return &conf.WgMeshConfiguration{ advertiseRoutes := true
GrpcPort: "8080", advertiseDefaultRoute := true
Endpoint: "abc.com", ipDiscovery := conf.PUBLIC_IP_DISCOVERY
ClusterSize: 64, role := conf.PEER_ROLE
SyncRate: 4,
BranchRate: 3, return &conf.DaemonConfiguration{
InterClusterChance: 0.15, GrpcPort: 8080,
InfectionCount: 2, CertificatePath: "./somecertificatepath",
KeepAliveTime: 60, PrivateKeyPath: "./someprivatekeypath",
CaCertificatePath: "./somecacertificatepath",
SkipCertVerification: true,
Timeout: 5,
Profile: false,
StubWg: true,
SyncInterval: 2,
Heartbeat: 60,
ClusterSize: 64,
InterClusterChance: 0.15,
Branch: 3,
InfectionCount: 3,
BaseConfiguration: conf.WgConfiguration{
IPDiscovery: &ipDiscovery,
AdvertiseRoutes: &advertiseRoutes,
AdvertiseDefaultRoute: &advertiseDefaultRoute,
Role: &role,
},
} }
} }
func getMeshManager() *MeshManagerImpl { func getMeshManager() MeshManager {
manager := NewMeshManager(&NewMeshManagerParams{ manager := NewMeshManager(&NewMeshManagerParams{
Conf: *getMeshConfiguration(), Conf: *getMeshConfiguration(),
Client: nil, Client: nil,
@ -41,7 +58,10 @@ func getMeshManager() *MeshManagerImpl {
func TestCreateMeshCreatesANewMeshProvider(t *testing.T) { func TestCreateMeshCreatesANewMeshProvider(t *testing.T) {
manager := getMeshManager() manager := getMeshManager()
meshId, err := manager.CreateMesh("wg0", 5000) meshId, err := manager.CreateMesh(&CreateMeshParams{
Port: 0,
Conf: &conf.WgConfiguration{},
})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -51,7 +71,7 @@ func TestCreateMeshCreatesANewMeshProvider(t *testing.T) {
t.Fatal(`meshId should not be empty`) t.Fatal(`meshId should not be empty`)
} }
_, exists := manager.Meshes[meshId] _, exists := manager.GetMeshes()[meshId]
if !exists { if !exists {
t.Fatal(`mesh was not created when it should be`) t.Fatal(`mesh was not created when it should be`)
@ -64,7 +84,6 @@ func TestAddMeshAddsAMesh(t *testing.T) {
manager.AddMesh(&AddMeshParams{ manager.AddMesh(&AddMeshParams{
MeshId: meshId, MeshId: meshId,
DevName: "wg0",
WgPort: 6000, WgPort: 6000,
MeshBytes: make([]byte, 0), MeshBytes: make([]byte, 0),
}) })
@ -83,7 +102,6 @@ func TestAddMeshMeshAlreadyExistsReplacesIt(t *testing.T) {
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
err := manager.AddMesh(&AddMeshParams{ err := manager.AddMesh(&AddMeshParams{
MeshId: meshId, MeshId: meshId,
DevName: "wg0",
WgPort: 6000, WgPort: 6000,
MeshBytes: make([]byte, 0), MeshBytes: make([]byte, 0),
}) })
@ -106,7 +124,6 @@ func TestAddSelfAddsSelfToTheMesh(t *testing.T) {
err := manager.AddMesh(&AddMeshParams{ err := manager.AddMesh(&AddMeshParams{
MeshId: meshId, MeshId: meshId,
DevName: "wg0",
WgPort: 6000, WgPort: 6000,
MeshBytes: make([]byte, 0), MeshBytes: make([]byte, 0),
}) })
@ -131,7 +148,7 @@ func TestAddSelfAddsSelfToTheMesh(t *testing.T) {
t.Error(err) t.Error(err)
} }
_, ok := mesh.GetNodes()["abc.com"] _, ok := mesh.GetNodes()[manager.GetPublicKey().String()]
if !ok { if !ok {
t.Fatalf(`node has not been added`) t.Fatalf(`node has not been added`)
@ -175,7 +192,6 @@ func TestLeaveMeshDeletesMesh(t *testing.T) {
err := manager.AddMesh(&AddMeshParams{ err := manager.AddMesh(&AddMeshParams{
MeshId: meshId, MeshId: meshId,
DevName: "wg0",
WgPort: 6000, WgPort: 6000,
MeshBytes: make([]byte, 0), MeshBytes: make([]byte, 0),
}) })
@ -190,43 +206,87 @@ func TestLeaveMeshDeletesMesh(t *testing.T) {
t.Fatalf("%s", err.Error()) t.Fatalf("%s", err.Error())
} }
_, exists := manager.Meshes[meshId] _, exists := manager.GetMeshes()[meshId]
if exists { if exists {
t.Fatalf(`expected mesh to have been deleted`) t.Fatalf(`expected mesh to have been deleted`)
} }
} }
func TestSetDescription(t *testing.T) { func TestSetAliasUpdatesAliasOfNode(t *testing.T) {
manager := getMeshManager()
alias := "Firpo"
meshId, _ := manager.CreateMesh(&CreateMeshParams{
Port: 5000,
Conf: &conf.WgConfiguration{},
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
err := manager.SetAlias(meshId, alias)
if err != nil {
t.Fatalf(`failed to set the alias`)
}
self, err := manager.GetSelf(meshId)
if err != nil {
t.Fatalf(`failed to set the alias err: %s`, err.Error())
}
if alias != self.GetAlias() {
t.Fatalf(`alias should be %s was %s`, alias, self.GetAlias())
}
}
func TestSetDescriptionSetsTheDescriptionOfTheNode(t *testing.T) {
manager := getMeshManager() manager := getMeshManager()
description := "wooooo" description := "wooooo"
meshId1, _ := manager.CreateMesh("wg0", 5000) meshId1, _ := manager.CreateMesh(&CreateMeshParams{
meshId2, _ := manager.CreateMesh("wg0", 5001) Port: 5000,
Conf: &conf.WgConfiguration{},
})
manager.AddSelf(&AddSelfParams{ manager.AddSelf(&AddSelfParams{
MeshId: meshId1, MeshId: meshId1,
WgPort: 5000, WgPort: 5000,
Endpoint: "abc.com:8080", Endpoint: "abc.com:8080",
}) })
manager.AddSelf(&AddSelfParams{
MeshId: meshId2,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
err := manager.SetDescription(description) err := manager.SetDescription(meshId1, description)
if err != nil { if err != nil {
t.Fatalf(`failed to set the descriptions`) t.Fatalf(`failed to set the descriptions`)
} }
}
self1, err := manager.GetSelf(meshId1)
if err != nil {
t.Fatalf(`failed to set the description`)
}
if description != self1.GetDescription() {
t.Fatalf(`description should be %s was %s`, description, self1.GetDescription())
}
}
func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) { func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) {
manager := getMeshManager() manager := getMeshManager()
meshId1, _ := manager.CreateMesh("wg0", 5000) meshId1, _ := manager.CreateMesh(&CreateMeshParams{
meshId2, _ := manager.CreateMesh("wg0", 5001) Port: 5000,
Conf: &conf.WgConfiguration{},
})
meshId2, _ := manager.CreateMesh(&CreateMeshParams{
Port: 5001,
Conf: &conf.WgConfiguration{},
})
manager.AddSelf(&AddSelfParams{ manager.AddSelf(&AddSelfParams{
MeshId: meshId1, MeshId: meshId1,
@ -245,3 +305,68 @@ func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) {
t.Fatalf(`failed to update the timestamp`) t.Fatalf(`failed to update the timestamp`)
} }
} }
func TestAddServiceAddsServiceToTheMesh(t *testing.T) {
manager := getMeshManager()
meshId1, _ := manager.CreateMesh(&CreateMeshParams{
Port: 5000,
Conf: &conf.WgConfiguration{},
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId1,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
serviceName := "hello"
manager.SetService(meshId1, serviceName, "dave")
self, err := manager.GetSelf(meshId1)
if err != nil {
t.Fatalf(`error thrown %s:`, err.Error())
}
if _, ok := self.GetServices()[serviceName]; !ok {
t.Fatalf(`service not added`)
}
}
func TestRemoveServiceRemovesTheServiceFromTheMesh(t *testing.T) {
manager := getMeshManager()
meshId1, _ := manager.CreateMesh(&CreateMeshParams{
Port: 5000,
Conf: &conf.WgConfiguration{},
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId1,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
serviceName := "hello"
manager.SetService(meshId1, serviceName, "dave")
self, err := manager.GetSelf(meshId1)
if err != nil {
t.Fatalf(`error thrown %s:`, err.Error())
}
if _, ok := self.GetServices()[serviceName]; !ok {
t.Fatalf(`service not added`)
}
manager.RemoveService(meshId1, serviceName)
self, err = manager.GetSelf(meshId1)
if err != nil {
t.Fatalf(`error thrown %s:`, err.Error())
}
if _, ok := self.GetServices()[serviceName]; ok {
t.Fatalf(`service still exists`)
}
}

View File

@ -1,16 +0,0 @@
package mesh
import (
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
func pruneFunction(m MeshManager) lib.TimerFunc {
return func() error {
return m.Prune()
}
}
func NewPruner(m MeshManager, conf conf.WgMeshConfiguration) *lib.Timer {
return lib.NewTimer(pruneFunction(m), conf.PruneTime/2)
}

View File

@ -1,184 +1,127 @@
package mesh package mesh
import ( import (
"fmt"
"net" "net"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/route"
"golang.org/x/sys/unix"
) )
// RouteManager: manager that leaks routes between meshes
type RouteManager interface { type RouteManager interface {
// UpdateRoutes: leak all routes in each mesh
UpdateRoutes() error UpdateRoutes() error
InstallRoutes() error
RemoveRoutes(meshId string) error
} }
type RouteManagerImpl struct { type RouteManagerImpl struct {
meshManager MeshManager meshManager MeshManager
routeInstaller route.RouteInstaller
} }
func (r *RouteManagerImpl) UpdateRoutes() error { func (r *RouteManagerImpl) UpdateRoutes() error {
meshes := r.meshManager.GetMeshes() meshes := r.meshManager.GetMeshes()
ulaBuilder := new(ip.ULABuilder) routes := make(map[string][]Route)
for _, mesh := range meshes {
// Make empty routes so that routes are retracted
routes[mesh.GetMeshId()] = make([]Route, 0)
}
for _, mesh1 := range meshes { for _, mesh1 := range meshes {
if !*mesh1.GetConfiguration().AdvertiseRoutes {
continue
}
self, err := mesh1.GetNode(r.meshManager.GetPublicKey().String())
if err != nil {
return err
}
if _, ok := routes[mesh1.GetMeshId()]; !ok {
routes[mesh1.GetMeshId()] = make([]Route, 0)
}
if *mesh1.GetConfiguration().AdvertiseDefaultRoute {
_, ipv6Default, _ := net.ParseCIDR("::/0")
defaultRoute := &RouteStub{
Destination: ipv6Default,
Path: []string{mesh1.GetMeshId()},
}
mesh1.AddRoutes(NodeID(self), defaultRoute)
routes[mesh1.GetMeshId()] = append(routes[mesh1.GetMeshId()], defaultRoute)
}
routeMap, err := mesh1.GetRoutes(NodeID(self))
if err != nil {
return err
}
for _, mesh2 := range meshes { for _, mesh2 := range meshes {
routeValues, ok := routes[mesh2.GetMeshId()]
if !ok {
routeValues = make([]Route, 0)
}
if mesh1 == mesh2 { if mesh1 == mesh2 {
continue continue
} }
ipNet, err := ulaBuilder.GetIPNet(mesh2.GetMeshId()) mesh1IpNet, _ := (&ip.ULABuilder{}).GetIPNet(mesh1.GetMeshId())
if err != nil { routeValues = append(routeValues, &RouteStub{
logging.Log.WriteErrorf(err.Error()) Destination: mesh1IpNet,
return err Path: []string{mesh1.GetMeshId()},
} })
self, err := r.meshManager.GetSelf(mesh1.GetMeshId()) routeValues = append(routeValues, lib.MapValues(routeMap)...)
mesh2IpNet, _ := (&ip.ULABuilder{}).GetIPNet(mesh2.GetMeshId())
routeValues = lib.Filter(routeValues, func(r Route) bool {
pathNotMesh := func(s string) bool {
return s == mesh2.GetMeshId()
}
if err != nil { // Remove any potential routing loops
return err return !r.GetDestination().IP.Equal(mesh2IpNet.IP) &&
} !lib.Contains(r.GetPath()[1:], pathNotMesh)
})
err = mesh1.AddRoutes(self.GetHostEndpoint(), ipNet.String()) routes[mesh2.GetMeshId()] = routeValues
}
}
if err != nil { // Calculate the set different of each, working out routes to remove and to keep.
return err for meshId, meshRoutes := range routes {
mesh := meshes[meshId]
self, err := mesh.GetNode(r.meshManager.GetPublicKey().String())
if err != nil {
return err
}
toRemove := make([]Route, 0)
prevRoutes := self.GetRoutes()
for _, route := range prevRoutes {
if !lib.Contains(meshRoutes, func(r Route) bool {
return RouteEqual(r, route)
}) {
toRemove = append(toRemove, route)
} }
} }
}
return nil mesh.RemoveRoutes(NodeID(self), toRemove...)
} mesh.AddRoutes(NodeID(self), meshRoutes...)
// removeRoutes: removes all meshes we are no longer a part of
func (r *RouteManagerImpl) RemoveRoutes(meshId string) error {
ulaBuilder := new(ip.ULABuilder)
meshes := r.meshManager.GetMeshes()
ipNet, err := ulaBuilder.GetIPNet(meshId)
if err != nil {
return err
}
for _, mesh1 := range meshes {
self, err := r.meshManager.GetSelf(meshId)
if err != nil {
return err
}
mesh1.RemoveRoutes(self.GetHostEndpoint(), ipNet.String())
}
return nil
}
// AddRoute adds a route to the given interface
func (m *RouteManagerImpl) addRoute(ifName string, meshPrefix string, routes ...lib.Route) error {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return fmt.Errorf("failed to create config: %w", err)
}
defer rtnl.Close()
// Delete any routes that may be vacant
err = rtnl.DeleteRoutes(ifName, unix.AF_INET6, routes...)
if err != nil {
return err
}
for _, route := range routes {
if route.Destination.String() == meshPrefix {
continue
}
err = rtnl.AddRoute(ifName, route)
if err != nil {
return err
}
}
return nil
}
func (m *RouteManagerImpl) installRoute(ifName string, meshid string, node MeshNode) error {
routeMapFunc := func(route string) lib.Route {
_, cidr, _ := net.ParseCIDR(route)
r := lib.Route{
Destination: *cidr,
Gateway: node.GetWgHost().IP,
}
return r
}
ipBuilder := &ip.ULABuilder{}
ipNet, err := ipBuilder.GetIPNet(meshid)
if err != nil {
return err
}
routes := lib.Map(append(node.GetRoutes(), ipNet.String()), routeMapFunc)
return m.addRoute(ifName, ipNet.String(), routes...)
}
func (m *RouteManagerImpl) installRoutes(meshProvider MeshProvider) error {
mesh, err := meshProvider.GetMesh()
if err != nil {
return err
}
dev, err := meshProvider.GetDevice()
if err != nil {
return err
}
self, err := m.meshManager.GetSelf(meshProvider.GetMeshId())
if err != nil {
return err
}
for _, node := range mesh.GetNodes() {
if self.GetHostEndpoint() == node.GetHostEndpoint() {
continue
}
err = m.installRoute(dev.Name, meshProvider.GetMeshId(), node)
if err != nil {
return err
}
}
return nil
}
// InstallRoutes installs all routes to the RIB
func (r *RouteManagerImpl) InstallRoutes() error {
for _, mesh := range r.meshManager.GetMeshes() {
err := r.installRoutes(mesh)
if err != nil {
return err
}
} }
return nil return nil
} }
func NewRouteManager(m MeshManager) RouteManager { func NewRouteManager(m MeshManager) RouteManager {
return &RouteManagerImpl{meshManager: m, routeInstaller: route.NewRouteInstaller()} return &RouteManagerImpl{meshManager: m}
} }

View File

@ -5,7 +5,8 @@ import (
"net" "net"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/lib"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@ -16,9 +17,26 @@ type MeshNodeStub struct {
wgEndpoint string wgEndpoint string
wgHost *net.IPNet wgHost *net.IPNet
timeStamp int64 timeStamp int64
routes []string routes []Route
identifier string identifier string
description string description string
alias string
services map[string]string
}
// GetType implements MeshNode.
func (*MeshNodeStub) GetType() conf.NodeType {
return conf.PEER_ROLE
}
// GetServices implements MeshNode.
func (m *MeshNodeStub) GetServices() map[string]string {
return m.services
}
// GetAlias implements MeshNode.
func (s *MeshNodeStub) GetAlias() string {
return s.alias
} }
func (m *MeshNodeStub) GetHostEndpoint() string { func (m *MeshNodeStub) GetHostEndpoint() string {
@ -41,7 +59,7 @@ func (m *MeshNodeStub) GetTimeStamp() int64 {
return m.timeStamp return m.timeStamp
} }
func (m *MeshNodeStub) GetRoutes() []string { func (m *MeshNodeStub) GetRoutes() []Route {
return m.routes return m.routes
} }
@ -66,23 +84,105 @@ type MeshProviderStub struct {
snapshot *MeshSnapshotStub snapshot *MeshSnapshotStub
} }
// GetConfiguration implements MeshProvider.
func (*MeshProviderStub) GetConfiguration() *conf.WgConfiguration {
advertiseRoutes := true
advertiseDefaultRoute := true
ipDiscovery := conf.PUBLIC_IP_DISCOVERY
role := conf.PEER_ROLE
return &conf.WgConfiguration{
IPDiscovery: &ipDiscovery,
AdvertiseRoutes: &advertiseRoutes,
AdvertiseDefaultRoute: &advertiseDefaultRoute,
Role: &role,
}
}
// Mark implements MeshProvider.
func (*MeshProviderStub) Mark(nodeId string) {
}
// RemoveNode implements MeshProvider.
func (*MeshProviderStub) RemoveNode(nodeId string) error {
return nil
}
func (*MeshProviderStub) GetRoutes(targetId string) (map[string]Route, error) {
return nil, nil
}
// GetNodeIds implements MeshProvider.
func (*MeshProviderStub) GetPeers() []string {
return make([]string, 0)
}
// GetNode implements MeshProvider.
func (m *MeshProviderStub) GetNode(nodeId string) (MeshNode, error) {
return m.snapshot.nodes[nodeId], nil
}
// NodeExists implements MeshProvider.
func (m *MeshProviderStub) NodeExists(nodeId string) bool {
return m.snapshot.nodes[nodeId] != nil
}
// AddService implements MeshProvider.
func (m *MeshProviderStub) AddService(nodeId string, key string, value string) error {
node := (m.snapshot.nodes[nodeId]).(*MeshNodeStub)
node.services[key] = value
return nil
}
// RemoveService implements MeshProvider.
func (m *MeshProviderStub) RemoveService(nodeId string, key string) error {
node := (m.snapshot.nodes[nodeId]).(*MeshNodeStub)
delete(node.services, key)
return nil
}
// SetAlias implements MeshProvider.
func (m *MeshProviderStub) SetAlias(nodeId string, alias string) error {
node := (m.snapshot.nodes[nodeId]).(*MeshNodeStub)
node.alias = alias
return nil
}
// AddRoutes implements
func (m *MeshProviderStub) AddRoutes(nodeId string, route ...Route) error {
node := (m.snapshot.nodes[nodeId]).(*MeshNodeStub)
node.routes = append(node.routes, route...)
return nil
}
// RemoveRoutes implements MeshProvider. // RemoveRoutes implements MeshProvider.
func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error { func (m *MeshProviderStub) RemoveRoutes(nodeId string, route ...Route) error {
panic("unimplemented") node := (m.snapshot.nodes[nodeId]).(*MeshNodeStub)
newRoutes := lib.Filter(node.routes, func(r1 Route) bool {
return !lib.Contains(route, func(r2 Route) bool {
return RouteEqual(r1, r2)
})
})
node.routes = newRoutes
return nil
} }
// Prune implements MeshProvider. // Prune implements MeshProvider.
func (*MeshProviderStub) Prune(pruneAmount int) error { func (*MeshProviderStub) Prune() error {
return nil return nil
} }
// UpdateTimeStamp implements MeshProvider. // UpdateTimeStamp implements MeshProvider.
func (*MeshProviderStub) UpdateTimeStamp(nodeId string) error { func (m *MeshProviderStub) UpdateTimeStamp(nodeId string) error {
node := (m.snapshot.nodes[nodeId]).(*MeshNodeStub)
node.timeStamp = time.Now().Unix()
return nil return nil
} }
func (s *MeshProviderStub) AddNode(node MeshNode) { func (s *MeshProviderStub) AddNode(node MeshNode) {
s.snapshot.nodes[node.GetHostEndpoint()] = node pubKey, _ := node.GetPublicKey()
s.snapshot.nodes[pubKey.String()] = node
} }
func (s *MeshProviderStub) GetMesh() (MeshSnapshot, error) { func (s *MeshProviderStub) GetMesh() (MeshSnapshot, error) {
@ -114,15 +214,13 @@ func (s *MeshProviderStub) HasChanges() bool {
return false return false
} }
func (s *MeshProviderStub) AddRoutes(nodeId string, route ...string) error {
return nil
}
func (s *MeshProviderStub) GetSyncer() MeshSyncer { func (s *MeshProviderStub) GetSyncer() MeshSyncer {
return nil return nil
} }
func (s *MeshProviderStub) SetDescription(nodeId string, description string) error { func (s *MeshProviderStub) SetDescription(nodeId string, description string) error {
meshNode := (s.snapshot.nodes[nodeId]).(*MeshNodeStub)
meshNode.description = description
return nil return nil
} }
@ -136,7 +234,7 @@ func (s *StubMeshProviderFactory) CreateMesh(params *MeshProviderFactoryParams)
} }
type StubNodeFactory struct { type StubNodeFactory struct {
Config *conf.WgMeshConfiguration Config *conf.DaemonConfiguration
} }
func (s *StubNodeFactory) Build(params *MeshNodeFactoryParams) MeshNode { func (s *StubNodeFactory) Build(params *MeshNodeFactoryParams) MeshNode {
@ -145,12 +243,13 @@ func (s *StubNodeFactory) Build(params *MeshNodeFactoryParams) MeshNode {
return &MeshNodeStub{ return &MeshNodeStub{
hostEndpoint: params.Endpoint, hostEndpoint: params.Endpoint,
publicKey: *params.PublicKey, publicKey: *params.PublicKey,
wgEndpoint: fmt.Sprintf("%s:%s", params.Endpoint, s.Config.GrpcPort), wgEndpoint: fmt.Sprintf("%s:%d", params.Endpoint, s.Config.GrpcPort),
wgHost: wgHost, wgHost: wgHost,
timeStamp: time.Now().Unix(), timeStamp: time.Now().Unix(),
routes: make([]string, 0), routes: make([]Route, 0),
identifier: "abc", identifier: "abc",
description: "A Mesh Node Stub", description: "A Mesh Node Stub",
services: make(map[string]string),
} }
} }
@ -171,9 +270,34 @@ type MeshManagerStub struct {
meshes map[string]MeshProvider meshes map[string]MeshProvider
} }
// GetRouteManager implements MeshManager.
func (*MeshManagerStub) GetRouteManager() RouteManager {
return nil
}
// GetNode implements MeshManager.
func (*MeshManagerStub) GetNode(meshId, nodeId string) MeshNode {
return nil
}
// RemoveService implements MeshManager.
func (*MeshManagerStub) RemoveService(meshId, service string) error {
return nil
}
// SetService implements MeshManager.
func (*MeshManagerStub) SetService(meshId, service, value string) error {
return nil
}
// SetAlias implements MeshManager.
func (*MeshManagerStub) SetAlias(meshId, alias string) error {
return nil
}
// Close implements MeshManager. // Close implements MeshManager.
func (*MeshManagerStub) Close() error { func (*MeshManagerStub) Close() error {
panic("unimplemented") return nil
} }
// Prune implements MeshManager. // Prune implements MeshManager.
@ -185,7 +309,7 @@ func NewMeshManagerStub() MeshManager {
return &MeshManagerStub{meshes: make(map[string]MeshProvider)} return &MeshManagerStub{meshes: make(map[string]MeshProvider)}
} }
func (m *MeshManagerStub) CreateMesh(devName string, port int) (string, error) { func (m *MeshManagerStub) CreateMesh(*CreateMeshParams) (string, error) {
return "tim123", nil return "tim123", nil
} }
@ -208,13 +332,9 @@ func (m *MeshManagerStub) GetMesh(meshId string) MeshProvider {
snapshot: &MeshSnapshotStub{nodes: make(map[string]MeshNode)}} snapshot: &MeshSnapshotStub{nodes: make(map[string]MeshNode)}}
} }
func (m *MeshManagerStub) EnableInterface(meshId string) error { func (m *MeshManagerStub) GetPublicKey() *wgtypes.Key {
return nil
}
func (m *MeshManagerStub) GetPublicKey(meshId string) (*wgtypes.Key, error) {
key, _ := wgtypes.GenerateKey() key, _ := wgtypes.GenerateKey()
return &key, nil return &key
} }
func (m *MeshManagerStub) AddSelf(params *AddSelfParams) error { func (m *MeshManagerStub) AddSelf(params *AddSelfParams) error {
@ -229,7 +349,7 @@ func (m *MeshManagerStub) ApplyConfig() error {
return nil return nil
} }
func (m *MeshManagerStub) SetDescription(description string) error { func (m *MeshManagerStub) SetDescription(meshId, description string) error {
return nil return nil
} }

View File

@ -4,12 +4,45 @@ package mesh
import ( import (
"net" "net"
"slices"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
type Route interface {
// GetDestination: returns the destination of the route
GetDestination() *net.IPNet
// GetHopCount: get the total hopcount of the prefix
GetHopCount() int
// GetPath: get a list of AS paths to get to the destination
GetPath() []string
}
func RouteEqual(r1 Route, r2 Route) bool {
return r1.GetDestination().String() == r2.GetDestination().String() &&
r1.GetHopCount() == r2.GetHopCount() &&
slices.Equal(r1.GetPath(), r2.GetPath())
}
type RouteStub struct {
Destination *net.IPNet
Path []string
}
func (r *RouteStub) GetDestination() *net.IPNet {
return r.Destination
}
func (r *RouteStub) GetHopCount() int {
return len(r.Path)
}
func (r *RouteStub) GetPath() []string {
return r.Path
}
// MeshNode represents an implementation of a node in a mesh // MeshNode represents an implementation of a node in a mesh
type MeshNode interface { type MeshNode interface {
// GetHostEndpoint: gets the gRPC endpoint of the node // GetHostEndpoint: gets the gRPC endpoint of the node
@ -23,11 +56,34 @@ type MeshNode interface {
// GetTimestamp: get the UNIX time stamp of the ndoe // GetTimestamp: get the UNIX time stamp of the ndoe
GetTimeStamp() int64 GetTimeStamp() int64
// GetRoutes: returns the routes that the nodes provides // GetRoutes: returns the routes that the nodes provides
GetRoutes() []string GetRoutes() []Route
// GetIdentifier: returns the identifier of the node // GetIdentifier: returns the identifier of the node
GetIdentifier() string GetIdentifier() string
// GetDescription: returns the description for this node // GetDescription: returns the description for this node
GetDescription() string GetDescription() string
// GetAlias: associates the node with an alias. Potentially used
// for DNS and so forth.
GetAlias() string
// GetServices: returns a list of services offered by the node
GetServices() map[string]string
GetType() conf.NodeType
}
// NodeEquals: determines if two mesh nodes are equivalent to one another
func NodeEquals(node1, node2 MeshNode) bool {
key1, _ := node1.GetPublicKey()
key2, _ := node2.GetPublicKey()
if node1 == nil || node2 == nil {
return false
}
return key1.String() == key2.String()
}
func NodeID(node MeshNode) string {
key, _ := node.GetPublicKey()
return key.String()
} }
type MeshSnapshot interface { type MeshSnapshot interface {
@ -46,7 +102,7 @@ type MeshSyncer interface {
type MeshProvider interface { type MeshProvider interface {
// AddNode() adds a node to the mesh // AddNode() adds a node to the mesh
AddNode(node MeshNode) AddNode(node MeshNode)
// GetMesh() returns a snapshot of the mesh provided by the mesh provider // GetMesh() returns a snapshot of the mesh provided by the mesh provider.
GetMesh() (MeshSnapshot, error) GetMesh() (MeshSnapshot, error)
// GetMeshId() returns the ID of the mesh network // GetMeshId() returns the ID of the mesh network
GetMeshId() string GetMeshId() string
@ -63,30 +119,59 @@ type MeshProvider interface {
// UpdateTimeStamp: update the timestamp of the given node // UpdateTimeStamp: update the timestamp of the given node
UpdateTimeStamp(nodeId string) error UpdateTimeStamp(nodeId string) error
// AddRoutes: adds routes to the given node // AddRoutes: adds routes to the given node
AddRoutes(nodeId string, route ...string) error AddRoutes(nodeId string, route ...Route) error
// DeleteRoutes: deletes the routes from the node // DeleteRoutes: deletes the routes from the node
RemoveRoutes(nodeId string, route ...string) error RemoveRoutes(nodeId string, route ...Route) error
// GetSyncer: returns the automerge syncer for sync // GetSyncer: returns the automerge syncer for sync
GetSyncer() MeshSyncer GetSyncer() MeshSyncer
// GetNode get a particular not within the mesh
GetNode(string) (MeshNode, error)
// NodeExists: returns true if a particular node exists false otherwise
NodeExists(string) bool
// SetDescription: sets the description of this automerge data type // SetDescription: sets the description of this automerge data type
SetDescription(nodeId string, description string) error SetDescription(nodeId string, description string) error
// Prune: prunes all nodes that have not updated their timestamp in // SetAlias: set the alias of the nodeId
// pruneAmount seconds SetAlias(nodeId string, alias string) error
Prune(pruneAmount int) error // AddService: adds the service to the given node
AddService(nodeId, key, value string) error
// RemoveService: removes the service form the node. throws an error if the service does not exist
RemoveService(nodeId, key string) error
// Prune: prunes all nodes that have not updated their
// vector clock
Prune() error
// GetPeers: get a list of contactable peers
GetPeers() []string
// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen
GetRoutes(targetNode string) (map[string]Route, error)
// RemoveNode(): remove the node from the mesh
RemoveNode(nodeId string) error
// Mark: marks the node as unreachable. This is not broadcast to the entire
// this is not considered when syncing node state
Mark(nodeId string)
// GetConfiguration: gets the configuration parameters specific for this
// mesh network
GetConfiguration() *conf.WgConfiguration
} }
// HostParameters contains the IDs of a node // HostParameters contains the IDs of a node
type HostParameters struct { type HostParameters struct {
HostEndpoint string PrivateKey *wgtypes.Key
}
// GetPublicKey: gets the public key of the node
func (h *HostParameters) GetPublicKey() string {
return h.PrivateKey.PublicKey().String()
} }
// MeshProviderFactoryParams parameters required to build a mesh provider // MeshProviderFactoryParams parameters required to build a mesh provider
type MeshProviderFactoryParams struct { type MeshProviderFactoryParams struct {
DevName string DevName string
MeshId string MeshId string
Port int Port int
Conf *conf.WgMeshConfiguration Conf *conf.WgConfiguration
Client *wgctrl.Client DaemonConf *conf.DaemonConfiguration
Client *wgctrl.Client
NodeID string
} }
// MeshProviderFactory creates an instance of a mesh provider // MeshProviderFactory creates an instance of a mesh provider
@ -97,10 +182,11 @@ type MeshProviderFactory interface {
// MeshNodeFactoryParams are the parameters required to construct // MeshNodeFactoryParams are the parameters required to construct
// a mesh node // a mesh node
type MeshNodeFactoryParams struct { type MeshNodeFactoryParams struct {
PublicKey *wgtypes.Key PublicKey *wgtypes.Key
NodeIP net.IP NodeIP net.IP
WgPort int WgPort int
Endpoint string Endpoint string
MeshConfig *conf.WgConfiguration
} }
// MeshBuilder build the hosts mesh node for it to be added to the mesh // MeshBuilder build the hosts mesh node for it to be added to the mesh

View File

@ -3,10 +3,12 @@ package query
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"strings"
"github.com/jmespath/go-jmespath" "github.com/jmespath/go-jmespath"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/mesh"
) )
// Querier queries a data store for the given data // Querier queries a data store for the given data
@ -15,29 +17,42 @@ type Querier interface {
Query(meshId string, queryParams string) ([]byte, error) Query(meshId string, queryParams string) ([]byte, error)
} }
// JmesQuerier: queries the datstore in JMESPath syntax
type JmesQuerier struct { type JmesQuerier struct {
manager mesh.MeshManager manager mesh.MeshManager
} }
// QueryError: query error if something went wrong
type QueryError struct { type QueryError struct {
msg string msg string
} }
// QuerRoute: represents a route in the query
type QueryRoute struct {
Destination string `json:"destination"`
HopCount int `json:"hopCount"`
Path string `json:"path"`
}
// QueryNode: represents a single node in the query
type QueryNode struct { type QueryNode struct {
HostEndpoint string `json:"hostEndpoint"` HostEndpoint string `json:"hostEndpoint"`
PublicKey string `json:"publicKey"` PublicKey string `json:"publicKey"`
WgEndpoint string `json:"wgEndpoint"` WgEndpoint string `json:"wgEndpoint"`
WgHost string `json:"wgHost"` WgHost string `json:"wgHost"`
Timestamp int64 `json:"timestmap"` Timestamp int64 `json:"timestamp"`
Description string `json:"description"` Description string `json:"description"`
Routes []string `json:"routes"` Routes []QueryRoute `json:"routes"`
Alias string `json:"alias"`
Services map[string]string `json:"services"`
Type conf.NodeType `json:"type"`
} }
func (m *QueryError) Error() string { func (m *QueryError) Error() string {
return m.msg return m.msg
} }
// Query: queries the data // Query: queries the the datastore at the given meshid
func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) { func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
mesh, ok := j.manager.GetMeshes()[meshId] mesh, ok := j.manager.GetMeshes()[meshId]
@ -51,7 +66,7 @@ func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
return nil, err return nil, err
} }
nodes := lib.Map(lib.MapValues(snapshot.GetNodes()), meshNodeToQueryNode) nodes := lib.Map(lib.MapValues(snapshot.GetNodes()), MeshNodeToQueryNode)
result, err := jmespath.Search(queryParams, nodes) result, err := jmespath.Search(queryParams, nodes)
@ -63,7 +78,8 @@ func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
return bytes, err return bytes, err
} }
func meshNodeToQueryNode(node mesh.MeshNode) *QueryNode { // MeshNodeToQuerynode: convert the mesh node into a query abstraction
func MeshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode := new(QueryNode) queryNode := new(QueryNode)
queryNode.HostEndpoint = node.GetHostEndpoint() queryNode.HostEndpoint = node.GetHostEndpoint()
pubKey, _ := node.GetPublicKey() pubKey, _ := node.GetPublicKey()
@ -74,8 +90,18 @@ func meshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode.WgHost = node.GetWgHost().String() queryNode.WgHost = node.GetWgHost().String()
queryNode.Timestamp = node.GetTimeStamp() queryNode.Timestamp = node.GetTimeStamp()
queryNode.Routes = node.GetRoutes() queryNode.Routes = lib.Map(node.GetRoutes(), func(r mesh.Route) QueryRoute {
return QueryRoute{
Destination: r.GetDestination().String(),
HopCount: r.GetHopCount(),
Path: strings.Join(r.GetPath(), ","),
}
})
queryNode.Description = node.GetDescription() queryNode.Description = node.GetDescription()
queryNode.Alias = node.GetAlias()
queryNode.Services = node.GetServices()
queryNode.Type = node.GetType()
return queryNode return queryNode
} }

View File

@ -4,126 +4,173 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"strconv" "slices"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/rpc"
) )
// IpcHandler: represents a handler for ipc calls
type IpcHandler struct { type IpcHandler struct {
Server ctrlserver.CtrlServer Server ctrlserver.CtrlServer
} }
// getOverrideConfiguration: override any specific WireGuard configuration
func getOverrideConfiguration(args *ipc.WireGuardArgs) conf.WgConfiguration {
overrideConf := conf.WgConfiguration{}
if args.Role != "" {
role := conf.NodeType(args.Role)
overrideConf.Role = &role
}
if args.Endpoint != "" {
overrideConf.Endpoint = &args.Endpoint
}
if args.KeepAliveWg != 0 {
keepAliveWg := args.KeepAliveWg
overrideConf.KeepAliveWg = &keepAliveWg
}
overrideConf.AdvertiseRoutes = &args.AdvertiseRoutes
overrideConf.AdvertiseDefaultRoute = &args.AdvertiseDefaultRoute
return overrideConf
}
// CreateMesh: create a new mesh network
func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error { func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
meshId, err := n.Server.GetMeshManager().CreateMesh(args.IfName, args.WgPort) overrideConf := getOverrideConfiguration(&args.WgArgs)
meshId, err := n.Server.GetMeshManager().CreateMesh(&mesh.CreateMeshParams{
Port: args.WgArgs.WgPort,
Conf: &overrideConf,
})
if err != nil { if err != nil {
return err return errors.New("could not create mesh")
} }
err = n.Server.GetMeshManager().AddSelf(&mesh.AddSelfParams{ err = n.Server.GetMeshManager().AddSelf(&mesh.AddSelfParams{
MeshId: meshId, MeshId: meshId,
WgPort: args.WgPort, WgPort: args.WgArgs.WgPort,
Endpoint: args.Endpoint, Endpoint: args.WgArgs.Endpoint,
}) })
if err != nil { if err != nil {
return err return errors.New("could not create mesh")
} }
*reply = meshId *reply = meshId
return err return err
} }
// ListMeshes: list mesh networks
func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error { func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
meshNames := make([]string, len(n.Server.GetMeshManager().GetMeshes())) meshNames := make([]string, len(n.Server.GetMeshManager().GetMeshes()))
i := 0 i := 0
for meshId, _ := range n.Server.GetMeshManager().GetMeshes() { for meshId := range n.Server.GetMeshManager().GetMeshes() {
meshNames[i] = meshId meshNames[i] = meshId
i++ i++
} }
slices.Sort(meshNames)
*reply = ipc.ListMeshReply{Meshes: meshNames} *reply = ipc.ListMeshReply{Meshes: meshNames}
return nil return nil
} }
func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { // JoinMesh: join a mesh network
peerConnection, err := n.Server.GetConnectionManager().GetConnection(args.IpAdress) func (n *IpcHandler) JoinMesh(args *ipc.JoinMeshArgs, reply *string) error {
overrideConf := getOverrideConfiguration(&args.WgArgs)
if n.Server.GetMeshManager().GetMesh(args.MeshId) != nil {
return fmt.Errorf("user is already apart of the mesh")
}
peerConnection, err := n.Server.GetConnectionManager().GetConnection(args.IpAddress)
if err != nil { if err != nil {
return err return fmt.Errorf("could not join mesh %s", args.MeshId)
} }
client, err := peerConnection.GetClient() client, err := peerConnection.GetClient()
if err != nil { if err != nil {
return err return fmt.Errorf("could not join mesh %s", args.MeshId)
} }
c := rpc.NewMeshCtrlServerClient(client) c := rpc.NewMeshCtrlServerClient(client)
if err != nil { if err != nil {
return err return fmt.Errorf("could not join mesh %s", args.MeshId)
} }
ctx, cancel := context.WithTimeout(context.Background(), time.Second) configuration := n.Server.GetConfiguration()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(configuration.Timeout))
defer cancel() defer cancel()
meshReply, err := c.GetMesh(ctx, &rpc.GetMeshRequest{MeshId: args.MeshId}) meshReply, err := c.GetMesh(ctx, &rpc.GetMeshRequest{MeshId: args.MeshId})
if err != nil { if err != nil {
return err return fmt.Errorf("could not join mesh %s", args.MeshId)
} }
err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{ err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{
MeshId: args.MeshId, MeshId: args.MeshId,
DevName: args.IfName, WgPort: args.WgArgs.WgPort,
WgPort: args.Port,
MeshBytes: meshReply.Mesh, MeshBytes: meshReply.Mesh,
Conf: &overrideConf,
}) })
if err != nil { if err != nil {
return err return fmt.Errorf("could not join mesh %s", args.MeshId)
} }
err = n.Server.GetMeshManager().AddSelf(&mesh.AddSelfParams{ err = n.Server.GetMeshManager().AddSelf(&mesh.AddSelfParams{
MeshId: args.MeshId, MeshId: args.MeshId,
WgPort: args.Port, WgPort: args.WgArgs.WgPort,
Endpoint: args.Endpoint, Endpoint: args.WgArgs.Endpoint,
}) })
if err != nil { if err != nil {
return err return fmt.Errorf("could not join mesh %s", args.MeshId)
} }
*reply = strconv.FormatBool(true) *reply = fmt.Sprintf("Successfully Joined: %s", args.MeshId)
return nil return nil
} }
// LeaveMesh leaves a mesh network // LeaveMesh: leaves a mesh network
func (n *IpcHandler) LeaveMesh(meshId string, reply *string) error { func (n *IpcHandler) LeaveMesh(meshId string, reply *string) error {
err := n.Server.GetMeshManager().LeaveMesh(meshId) err := n.Server.GetMeshManager().LeaveMesh(meshId)
if err == nil { if err == nil {
*reply = fmt.Sprintf("Left Mesh %s", meshId) *reply = fmt.Sprintf("Left Mesh %s", meshId)
} }
return err return err
} }
// GetMesh: get a mesh network at the given meshid
func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error { func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
mesh := n.Server.GetMeshManager().GetMesh(meshId) theMesh := n.Server.GetMeshManager().GetMesh(meshId)
meshSnapshot, err := mesh.GetMesh()
if theMesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
meshSnapshot, err := theMesh.GetMesh()
if err != nil { if err != nil {
return err return err
} }
if mesh == nil { if theMesh == nil {
return errors.New("mesh does not exist") return errors.New("mesh does not exist")
} }
@ -131,22 +178,9 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
i := 0 i := 0
for _, node := range meshSnapshot.GetNodes() { for _, node := range meshSnapshot.GetNodes() {
pubKey, _ := node.GetPublicKey() node := ctrlserver.NewCtrlNode(theMesh, node)
if err != nil { nodes[i] = *node
return err
}
node := ctrlserver.MeshNode{
HostEndpoint: node.GetHostEndpoint(),
WgEndpoint: node.GetWgEndpoint(),
PublicKey: pubKey.String(),
WgHost: node.GetWgHost().String(),
Timestamp: node.GetTimeStamp(),
Routes: node.GetRoutes(),
}
nodes[i] = node
i += 1 i += 1
} }
@ -154,31 +188,7 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
return nil return nil
} }
func (n *IpcHandler) EnableInterface(meshId string, reply *string) error { // Query: perform a jmespath query
err := n.Server.GetMeshManager().EnableInterface(meshId)
if err != nil {
*reply = err.Error()
return err
}
*reply = "up"
return nil
}
func (n *IpcHandler) GetDOT(meshId string, reply *string) error {
g := mesh.NewMeshDotConverter(n.Server.GetMeshManager())
result, err := g.Generate(meshId)
if err != nil {
return err
}
*reply = result
return nil
}
func (n *IpcHandler) Query(params ipc.QueryMesh, reply *string) error { func (n *IpcHandler) Query(params ipc.QueryMesh, reply *string) error {
queryResponse, err := n.Server.GetQuerier().Query(params.MeshId, params.Query) queryResponse, err := n.Server.GetQuerier().Query(params.MeshId, params.Query)
@ -190,17 +200,59 @@ func (n *IpcHandler) Query(params ipc.QueryMesh, reply *string) error {
return nil return nil
} }
func (n *IpcHandler) PutDescription(description string, reply *string) error { // PutDescription: change your description in the mesh
err := n.Server.GetMeshManager().SetDescription(description) func (n *IpcHandler) PutDescription(args ipc.PutDescriptionArgs, reply *string) error {
err := n.Server.GetMeshManager().SetDescription(args.MeshId, args.Description)
if err != nil { if err != nil {
return err return err
} }
*reply = fmt.Sprintf("Set description to %s", description) *reply = fmt.Sprintf("set description to %s for %s", args.Description, args.MeshId)
return nil return nil
} }
// PutAlias: put your aliasin the mesh
func (n *IpcHandler) PutAlias(args ipc.PutAliasArgs, reply *string) error {
if args.Alias == "" {
return fmt.Errorf("alias not provided")
}
err := n.Server.GetMeshManager().SetAlias(args.MeshId, args.Alias)
if err != nil {
return fmt.Errorf("could not set alias: %s", args.Alias)
}
*reply = fmt.Sprintf("Set alias to %s", args.Alias)
return nil
}
// PutService: place a service in the mesh
func (n *IpcHandler) PutService(service ipc.PutServiceArgs, reply *string) error {
err := n.Server.GetMeshManager().SetService(service.MeshId, service.Service, service.Value)
if err != nil {
return err
}
*reply = fmt.Sprintf("Set service %s in %s to %s", service.Service, service.MeshId, service.Value)
return nil
}
// DeleteService: withtract a service in the mesh
func (n *IpcHandler) DeleteService(service ipc.DeleteServiceArgs, reply *string) error {
err := n.Server.GetMeshManager().RemoveService(service.MeshId, service.Service)
if err != nil {
return err
}
*reply = fmt.Sprintf("Removed service %s from %s", service.Service, service.MeshId)
return nil
}
// RobinIpcParams: parameters required to construct a new mesh network
type RobinIpcParams struct { type RobinIpcParams struct {
CtrlServer ctrlserver.CtrlServer CtrlServer ctrlserver.CtrlServer
} }

View File

@ -3,9 +3,10 @@ package robin
import ( import (
"testing" "testing"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/ipc"
"github.com/tim-beatham/smegmesh/pkg/mesh"
) )
func getRequester() *IpcHandler { func getRequester() *IpcHandler {
@ -17,9 +18,11 @@ func TestCreateMeshRepliesMeshId(t *testing.T) {
requester := getRequester() requester := getRequester()
err := requester.CreateMesh(&ipc.NewMeshArgs{ err := requester.CreateMesh(&ipc.NewMeshArgs{
IfName: "wg0", WgArgs: ipc.WireGuardArgs{
WgPort: 5000, WgPort: 500,
Endpoint: "abc.com", Endpoint: "abc.com:1234",
Role: "peer",
},
}, &reply) }, &reply)
if err != nil { if err != nil {
@ -52,9 +55,8 @@ func TestListMeshesMeshesNotEmpty(t *testing.T) {
requester.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{ requester.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{
MeshId: "tim123", MeshId: "tim123",
DevName: "wg0",
WgPort: 5000,
MeshBytes: make([]byte, 0), MeshBytes: make([]byte, 0),
Conf: &conf.WgConfiguration{},
}) })
err := requester.ListMeshes("", &reply) err := requester.ListMeshes("", &reply)

View File

@ -4,15 +4,17 @@ import (
"context" "context"
"errors" "errors"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/smegmesh/pkg/rpc"
) )
// WgRpc: represents a WireGuard rpc call
type WgRpc struct { type WgRpc struct {
rpc.UnimplementedMeshCtrlServerServer rpc.UnimplementedMeshCtrlServerServer
Server *ctrlserver.MeshCtrlServer Server *ctrlserver.MeshCtrlServer
} }
// GetMesh: serialise the mesh network into bytes
func (m *WgRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.GetMeshReply, error) { func (m *WgRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.GetMeshReply, error) {
mesh := m.Server.MeshManager.GetMesh(request.MeshId) mesh := m.Server.MeshManager.GetMesh(request.MeshId)
@ -28,7 +30,3 @@ func (m *WgRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.
return &reply, nil return &reply, nil
} }
func (m *WgRpc) JoinMesh(ctx context.Context, request *rpc.JoinMeshRequest) (*rpc.JoinMeshReply, error) {
return &rpc.JoinMeshReply{Success: true}, nil
}

View File

@ -1 +0,0 @@
package robin

View File

@ -1,22 +1,35 @@
package route package route
import ( import (
"net" "github.com/tim-beatham/smegmesh/pkg/lib"
"os/exec" "golang.org/x/sys/unix"
logging "github.com/tim-beatham/wgmesh/pkg/log"
) )
// RouteInstaller: install the routes to the given interface
type RouteInstaller interface { type RouteInstaller interface {
InstallRoutes(devName string, routes ...*net.IPNet) error InstallRoutes(devName string, routes ...lib.Route) error
} }
type RouteInstallerImpl struct{} type RouteInstallerImpl struct{}
// InstallRoutes: installs a route into the routing table // InstallRoutes: installs a route into the routing table
func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...*net.IPNet) error { func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...lib.Route) error {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return err
}
defer rtnl.Close()
err = rtnl.DeleteRoutes(devName, unix.AF_INET6, routes...)
if err != nil {
return err
}
for _, route := range routes { for _, route := range routes {
err := r.installRoute(devName, route) err := rtnl.AddRoute(devName, route)
if err != nil { if err != nil {
return err return err
@ -26,22 +39,6 @@ func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...*net.IPNet)
return nil return nil
} }
// installRoute: installs a route into the linux table
func (r *RouteInstallerImpl) installRoute(devName string, route *net.IPNet) error {
// TODO: Find a library that automates this
cmd := exec.Command("/usr/bin/ip", "-6", "route", "add", route.String(), "dev", devName)
logging.Log.WriteInfof("%s %s", route.String(), devName)
if msg, err := cmd.CombinedOutput(); err != nil {
logging.Log.WriteErrorf(err.Error())
logging.Log.WriteErrorf(string(msg))
return err
}
return nil
}
func NewRouteInstaller() RouteInstaller { func NewRouteInstaller() RouteInstaller {
return &RouteInstallerImpl{} return &RouteInstallerImpl{}
} }

View File

@ -1,235 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.1
// protoc v3.21.12
// source: pkg/grpc/ctrlserver/authentication.proto
package rpc
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type JoinAuthMeshRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
MeshId string `protobuf:"bytes,1,opt,name=meshId,proto3" json:"meshId,omitempty"`
Alias string `protobuf:"bytes,2,opt,name=alias,proto3" json:"alias,omitempty"`
}
func (x *JoinAuthMeshRequest) Reset() {
*x = JoinAuthMeshRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *JoinAuthMeshRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JoinAuthMeshRequest) ProtoMessage() {}
func (x *JoinAuthMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JoinAuthMeshRequest.ProtoReflect.Descriptor instead.
func (*JoinAuthMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_authentication_proto_rawDescGZIP(), []int{0}
}
func (x *JoinAuthMeshRequest) GetMeshId() string {
if x != nil {
return x.MeshId
}
return ""
}
func (x *JoinAuthMeshRequest) GetAlias() string {
if x != nil {
return x.Alias
}
return ""
}
type JoinAuthMeshReply struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"`
Token *string `protobuf:"bytes,2,opt,name=token,proto3,oneof" json:"token,omitempty"`
}
func (x *JoinAuthMeshReply) Reset() {
*x = JoinAuthMeshReply{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *JoinAuthMeshReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JoinAuthMeshReply) ProtoMessage() {}
func (x *JoinAuthMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JoinAuthMeshReply.ProtoReflect.Descriptor instead.
func (*JoinAuthMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_authentication_proto_rawDescGZIP(), []int{1}
}
func (x *JoinAuthMeshReply) GetSuccess() bool {
if x != nil {
return x.Success
}
return false
}
func (x *JoinAuthMeshReply) GetToken() string {
if x != nil && x.Token != nil {
return *x.Token
}
return ""
}
var File_pkg_grpc_ctrlserver_authentication_proto protoreflect.FileDescriptor
var file_pkg_grpc_ctrlserver_authentication_proto_rawDesc = []byte{
0x0a, 0x28, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73,
0x65, 0x72, 0x76, 0x65, 0x72, 0x2f, 0x61, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61,
0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x72, 0x70, 0x63, 0x74,
0x79, 0x70, 0x65, 0x73, 0x22, 0x43, 0x0a, 0x13, 0x4a, 0x6f, 0x69, 0x6e, 0x41, 0x75, 0x74, 0x68,
0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x6d,
0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73,
0x68, 0x49, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x61, 0x6c, 0x69, 0x61, 0x73, 0x18, 0x02, 0x20, 0x01,
0x28, 0x09, 0x52, 0x05, 0x61, 0x6c, 0x69, 0x61, 0x73, 0x22, 0x52, 0x0a, 0x11, 0x4a, 0x6f, 0x69,
0x6e, 0x41, 0x75, 0x74, 0x68, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x18,
0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52,
0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x19, 0x0a, 0x05, 0x74, 0x6f, 0x6b, 0x65,
0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e,
0x88, 0x01, 0x01, 0x42, 0x08, 0x0a, 0x06, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x32, 0x5a, 0x0a,
0x0e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12,
0x48, 0x0a, 0x08, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x1d, 0x2e, 0x72, 0x70,
0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x41, 0x75, 0x74, 0x68, 0x4d,
0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x41, 0x75, 0x74, 0x68, 0x4d, 0x65,
0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x42, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67,
0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_pkg_grpc_ctrlserver_authentication_proto_rawDescOnce sync.Once
file_pkg_grpc_ctrlserver_authentication_proto_rawDescData = file_pkg_grpc_ctrlserver_authentication_proto_rawDesc
)
func file_pkg_grpc_ctrlserver_authentication_proto_rawDescGZIP() []byte {
file_pkg_grpc_ctrlserver_authentication_proto_rawDescOnce.Do(func() {
file_pkg_grpc_ctrlserver_authentication_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_grpc_ctrlserver_authentication_proto_rawDescData)
})
return file_pkg_grpc_ctrlserver_authentication_proto_rawDescData
}
var file_pkg_grpc_ctrlserver_authentication_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_pkg_grpc_ctrlserver_authentication_proto_goTypes = []interface{}{
(*JoinAuthMeshRequest)(nil), // 0: rpctypes.JoinAuthMeshRequest
(*JoinAuthMeshReply)(nil), // 1: rpctypes.JoinAuthMeshReply
}
var file_pkg_grpc_ctrlserver_authentication_proto_depIdxs = []int32{
0, // 0: rpctypes.Authentication.JoinMesh:input_type -> rpctypes.JoinAuthMeshRequest
1, // 1: rpctypes.Authentication.JoinMesh:output_type -> rpctypes.JoinAuthMeshReply
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_pkg_grpc_ctrlserver_authentication_proto_init() }
func file_pkg_grpc_ctrlserver_authentication_proto_init() {
if File_pkg_grpc_ctrlserver_authentication_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*JoinAuthMeshRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*JoinAuthMeshReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[1].OneofWrappers = []interface{}{}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_pkg_grpc_ctrlserver_authentication_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_pkg_grpc_ctrlserver_authentication_proto_goTypes,
DependencyIndexes: file_pkg_grpc_ctrlserver_authentication_proto_depIdxs,
MessageInfos: file_pkg_grpc_ctrlserver_authentication_proto_msgTypes,
}.Build()
File_pkg_grpc_ctrlserver_authentication_proto = out.File
file_pkg_grpc_ctrlserver_authentication_proto_rawDesc = nil
file_pkg_grpc_ctrlserver_authentication_proto_goTypes = nil
file_pkg_grpc_ctrlserver_authentication_proto_depIdxs = nil
}

View File

@ -1,105 +0,0 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.2.0
// - protoc v3.21.12
// source: pkg/grpc/ctrlserver/authentication.proto
package rpc
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// AuthenticationClient is the client API for Authentication service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type AuthenticationClient interface {
JoinMesh(ctx context.Context, in *JoinAuthMeshRequest, opts ...grpc.CallOption) (*JoinAuthMeshReply, error)
}
type authenticationClient struct {
cc grpc.ClientConnInterface
}
func NewAuthenticationClient(cc grpc.ClientConnInterface) AuthenticationClient {
return &authenticationClient{cc}
}
func (c *authenticationClient) JoinMesh(ctx context.Context, in *JoinAuthMeshRequest, opts ...grpc.CallOption) (*JoinAuthMeshReply, error) {
out := new(JoinAuthMeshReply)
err := c.cc.Invoke(ctx, "/rpctypes.Authentication/JoinMesh", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// AuthenticationServer is the server API for Authentication service.
// All implementations must embed UnimplementedAuthenticationServer
// for forward compatibility
type AuthenticationServer interface {
JoinMesh(context.Context, *JoinAuthMeshRequest) (*JoinAuthMeshReply, error)
mustEmbedUnimplementedAuthenticationServer()
}
// UnimplementedAuthenticationServer must be embedded to have forward compatible implementations.
type UnimplementedAuthenticationServer struct {
}
func (UnimplementedAuthenticationServer) JoinMesh(context.Context, *JoinAuthMeshRequest) (*JoinAuthMeshReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method JoinMesh not implemented")
}
func (UnimplementedAuthenticationServer) mustEmbedUnimplementedAuthenticationServer() {}
// UnsafeAuthenticationServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to AuthenticationServer will
// result in compilation errors.
type UnsafeAuthenticationServer interface {
mustEmbedUnimplementedAuthenticationServer()
}
func RegisterAuthenticationServer(s grpc.ServiceRegistrar, srv AuthenticationServer) {
s.RegisterService(&Authentication_ServiceDesc, srv)
}
func _Authentication_JoinMesh_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(JoinAuthMeshRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(AuthenticationServer).JoinMesh(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/rpctypes.Authentication/JoinMesh",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(AuthenticationServer).JoinMesh(ctx, req.(*JoinAuthMeshRequest))
}
return interceptor(ctx, in, info, handler)
}
// Authentication_ServiceDesc is the grpc.ServiceDesc for Authentication service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var Authentication_ServiceDesc = grpc.ServiceDesc{
ServiceName: "rpctypes.Authentication",
HandlerType: (*AuthenticationServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "JoinMesh",
Handler: _Authentication_JoinMesh_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "pkg/grpc/ctrlserver/authentication.proto",
}

View File

@ -20,77 +20,6 @@ const (
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
) )
type MeshNode struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
PublicKey string `protobuf:"bytes,1,opt,name=publicKey,proto3" json:"publicKey,omitempty"`
WgEndpoint string `protobuf:"bytes,2,opt,name=wgEndpoint,proto3" json:"wgEndpoint,omitempty"`
Endpoint string `protobuf:"bytes,3,opt,name=endpoint,proto3" json:"endpoint,omitempty"`
WgHost string `protobuf:"bytes,4,opt,name=wgHost,proto3" json:"wgHost,omitempty"`
}
func (x *MeshNode) Reset() {
*x = MeshNode{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *MeshNode) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*MeshNode) ProtoMessage() {}
func (x *MeshNode) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use MeshNode.ProtoReflect.Descriptor instead.
func (*MeshNode) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{0}
}
func (x *MeshNode) GetPublicKey() string {
if x != nil {
return x.PublicKey
}
return ""
}
func (x *MeshNode) GetWgEndpoint() string {
if x != nil {
return x.WgEndpoint
}
return ""
}
func (x *MeshNode) GetEndpoint() string {
if x != nil {
return x.Endpoint
}
return ""
}
func (x *MeshNode) GetWgHost() string {
if x != nil {
return x.WgHost
}
return ""
}
type GetMeshRequest struct { type GetMeshRequest struct {
state protoimpl.MessageState state protoimpl.MessageState
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
@ -102,7 +31,7 @@ type GetMeshRequest struct {
func (x *GetMeshRequest) Reset() { func (x *GetMeshRequest) Reset() {
*x = GetMeshRequest{} *x = GetMeshRequest{}
if protoimpl.UnsafeEnabled { if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1] mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi) ms.StoreMessageInfo(mi)
} }
@ -115,7 +44,7 @@ func (x *GetMeshRequest) String() string {
func (*GetMeshRequest) ProtoMessage() {} func (*GetMeshRequest) ProtoMessage() {}
func (x *GetMeshRequest) ProtoReflect() protoreflect.Message { func (x *GetMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1] mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil { if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil { if ms.LoadMessageInfo() == nil {
@ -128,7 +57,7 @@ func (x *GetMeshRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetMeshRequest.ProtoReflect.Descriptor instead. // Deprecated: Use GetMeshRequest.ProtoReflect.Descriptor instead.
func (*GetMeshRequest) Descriptor() ([]byte, []int) { func (*GetMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{1} return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{0}
} }
func (x *GetMeshRequest) GetMeshId() string { func (x *GetMeshRequest) GetMeshId() string {
@ -149,7 +78,7 @@ type GetMeshReply struct {
func (x *GetMeshReply) Reset() { func (x *GetMeshReply) Reset() {
*x = GetMeshReply{} *x = GetMeshReply{}
if protoimpl.UnsafeEnabled { if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2] mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi) ms.StoreMessageInfo(mi)
} }
@ -162,7 +91,7 @@ func (x *GetMeshReply) String() string {
func (*GetMeshReply) ProtoMessage() {} func (*GetMeshReply) ProtoMessage() {}
func (x *GetMeshReply) ProtoReflect() protoreflect.Message { func (x *GetMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2] mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil { if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil { if ms.LoadMessageInfo() == nil {
@ -175,7 +104,7 @@ func (x *GetMeshReply) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetMeshReply.ProtoReflect.Descriptor instead. // Deprecated: Use GetMeshReply.ProtoReflect.Descriptor instead.
func (*GetMeshReply) Descriptor() ([]byte, []int) { func (*GetMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{2} return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{1}
} }
func (x *GetMeshReply) GetMesh() []byte { func (x *GetMeshReply) GetMesh() []byte {
@ -185,145 +114,24 @@ func (x *GetMeshReply) GetMesh() []byte {
return nil return nil
} }
type JoinMeshRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Changes []byte `protobuf:"bytes,1,opt,name=changes,proto3" json:"changes,omitempty"`
MeshId string `protobuf:"bytes,2,opt,name=meshId,proto3" json:"meshId,omitempty"`
}
func (x *JoinMeshRequest) Reset() {
*x = JoinMeshRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *JoinMeshRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JoinMeshRequest) ProtoMessage() {}
func (x *JoinMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JoinMeshRequest.ProtoReflect.Descriptor instead.
func (*JoinMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{3}
}
func (x *JoinMeshRequest) GetChanges() []byte {
if x != nil {
return x.Changes
}
return nil
}
func (x *JoinMeshRequest) GetMeshId() string {
if x != nil {
return x.MeshId
}
return ""
}
type JoinMeshReply struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"`
}
func (x *JoinMeshReply) Reset() {
*x = JoinMeshReply{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *JoinMeshReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JoinMeshReply) ProtoMessage() {}
func (x *JoinMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JoinMeshReply.ProtoReflect.Descriptor instead.
func (*JoinMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{4}
}
func (x *JoinMeshReply) GetSuccess() bool {
if x != nil {
return x.Success
}
return false
}
var File_pkg_grpc_ctrlserver_ctrlserver_proto protoreflect.FileDescriptor var File_pkg_grpc_ctrlserver_ctrlserver_proto protoreflect.FileDescriptor
var file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc = []byte{ var file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc = []byte{
0x0a, 0x24, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73, 0x0a, 0x24, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73,
0x65, 0x72, 0x76, 0x65, 0x72, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x65, 0x72, 0x76, 0x65, 0x72, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72,
0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73,
0x22, 0x7c, 0x0a, 0x08, 0x4d, 0x65, 0x73, 0x68, 0x4e, 0x6f, 0x64, 0x65, 0x12, 0x1c, 0x0a, 0x09, 0x22, 0x28, 0x0a, 0x0e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65,
0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01,
0x09, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x77, 0x67, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x22, 0x22, 0x0a, 0x0c, 0x47, 0x65,
0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x6d, 0x65,
0x77, 0x67, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e, 0x73, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x32, 0x4f,
0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x65, 0x6e, 0x0a, 0x0e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x74, 0x72, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72,
0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x77, 0x67, 0x48, 0x6f, 0x73, 0x74, 0x12, 0x3d, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x18, 0x2e, 0x72, 0x70,
0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x77, 0x67, 0x48, 0x6f, 0x73, 0x74, 0x22, 0x28, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65,
0x0a, 0x0e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73,
0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x42,
0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x22, 0x22, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x4d, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x6f, 0x33,
0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x22, 0x43, 0x0a, 0x0f,
0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12,
0x18, 0x0a, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c,
0x52, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73,
0x68, 0x49, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49,
0x64, 0x22, 0x29, 0x0a, 0x0d, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70,
0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20,
0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x32, 0x91, 0x01, 0x0a,
0x0e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x74, 0x72, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12,
0x3d, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x18, 0x2e, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e,
0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x40,
0x0a, 0x08, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x19, 0x2e, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73,
0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00,
0x42, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x33,
} }
var ( var (
@ -338,21 +146,16 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP() []byte {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescData return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescData
} }
var file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes = make([]protoimpl.MessageInfo, 5) var file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_pkg_grpc_ctrlserver_ctrlserver_proto_goTypes = []interface{}{ var file_pkg_grpc_ctrlserver_ctrlserver_proto_goTypes = []interface{}{
(*MeshNode)(nil), // 0: rpctypes.MeshNode (*GetMeshRequest)(nil), // 0: rpctypes.GetMeshRequest
(*GetMeshRequest)(nil), // 1: rpctypes.GetMeshRequest (*GetMeshReply)(nil), // 1: rpctypes.GetMeshReply
(*GetMeshReply)(nil), // 2: rpctypes.GetMeshReply
(*JoinMeshRequest)(nil), // 3: rpctypes.JoinMeshRequest
(*JoinMeshReply)(nil), // 4: rpctypes.JoinMeshReply
} }
var file_pkg_grpc_ctrlserver_ctrlserver_proto_depIdxs = []int32{ var file_pkg_grpc_ctrlserver_ctrlserver_proto_depIdxs = []int32{
1, // 0: rpctypes.MeshCtrlServer.GetMesh:input_type -> rpctypes.GetMeshRequest 0, // 0: rpctypes.MeshCtrlServer.GetMesh:input_type -> rpctypes.GetMeshRequest
3, // 1: rpctypes.MeshCtrlServer.JoinMesh:input_type -> rpctypes.JoinMeshRequest 1, // 1: rpctypes.MeshCtrlServer.GetMesh:output_type -> rpctypes.GetMeshReply
2, // 2: rpctypes.MeshCtrlServer.GetMesh:output_type -> rpctypes.GetMeshReply 1, // [1:2] is the sub-list for method output_type
4, // 3: rpctypes.MeshCtrlServer.JoinMesh:output_type -> rpctypes.JoinMeshReply 0, // [0:1] is the sub-list for method input_type
2, // [2:4] is the sub-list for method output_type
0, // [0:2] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name 0, // [0:0] is the sub-list for field type_name
@ -365,18 +168,6 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
} }
if !protoimpl.UnsafeEnabled { if !protoimpl.UnsafeEnabled {
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*MeshNode); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetMeshRequest); i { switch v := v.(*GetMeshRequest); i {
case 0: case 0:
return &v.state return &v.state
@ -388,7 +179,7 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
return nil return nil
} }
} }
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetMeshReply); i { switch v := v.(*GetMeshReply); i {
case 0: case 0:
return &v.state return &v.state
@ -400,30 +191,6 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
return nil return nil
} }
} }
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*JoinMeshRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*JoinMeshReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
} }
type x struct{} type x struct{}
out := protoimpl.TypeBuilder{ out := protoimpl.TypeBuilder{
@ -431,7 +198,7 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(), GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc, RawDescriptor: file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc,
NumEnums: 0, NumEnums: 0,
NumMessages: 5, NumMessages: 2,
NumExtensions: 0, NumExtensions: 0,
NumServices: 1, NumServices: 1,
}, },

View File

@ -23,7 +23,6 @@ const _ = grpc.SupportPackageIsVersion7
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type MeshCtrlServerClient interface { type MeshCtrlServerClient interface {
GetMesh(ctx context.Context, in *GetMeshRequest, opts ...grpc.CallOption) (*GetMeshReply, error) GetMesh(ctx context.Context, in *GetMeshRequest, opts ...grpc.CallOption) (*GetMeshReply, error)
JoinMesh(ctx context.Context, in *JoinMeshRequest, opts ...grpc.CallOption) (*JoinMeshReply, error)
} }
type meshCtrlServerClient struct { type meshCtrlServerClient struct {
@ -43,21 +42,11 @@ func (c *meshCtrlServerClient) GetMesh(ctx context.Context, in *GetMeshRequest,
return out, nil return out, nil
} }
func (c *meshCtrlServerClient) JoinMesh(ctx context.Context, in *JoinMeshRequest, opts ...grpc.CallOption) (*JoinMeshReply, error) {
out := new(JoinMeshReply)
err := c.cc.Invoke(ctx, "/rpctypes.MeshCtrlServer/JoinMesh", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// MeshCtrlServerServer is the server API for MeshCtrlServer service. // MeshCtrlServerServer is the server API for MeshCtrlServer service.
// All implementations must embed UnimplementedMeshCtrlServerServer // All implementations must embed UnimplementedMeshCtrlServerServer
// for forward compatibility // for forward compatibility
type MeshCtrlServerServer interface { type MeshCtrlServerServer interface {
GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error) GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error)
JoinMesh(context.Context, *JoinMeshRequest) (*JoinMeshReply, error)
mustEmbedUnimplementedMeshCtrlServerServer() mustEmbedUnimplementedMeshCtrlServerServer()
} }
@ -68,9 +57,6 @@ type UnimplementedMeshCtrlServerServer struct {
func (UnimplementedMeshCtrlServerServer) GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error) { func (UnimplementedMeshCtrlServerServer) GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetMesh not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetMesh not implemented")
} }
func (UnimplementedMeshCtrlServerServer) JoinMesh(context.Context, *JoinMeshRequest) (*JoinMeshReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method JoinMesh not implemented")
}
func (UnimplementedMeshCtrlServerServer) mustEmbedUnimplementedMeshCtrlServerServer() {} func (UnimplementedMeshCtrlServerServer) mustEmbedUnimplementedMeshCtrlServerServer() {}
// UnsafeMeshCtrlServerServer may be embedded to opt out of forward compatibility for this service. // UnsafeMeshCtrlServerServer may be embedded to opt out of forward compatibility for this service.
@ -102,24 +88,6 @@ func _MeshCtrlServer_GetMesh_Handler(srv interface{}, ctx context.Context, dec f
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
func _MeshCtrlServer_JoinMesh_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(JoinMeshRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(MeshCtrlServerServer).JoinMesh(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/rpctypes.MeshCtrlServer/JoinMesh",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(MeshCtrlServerServer).JoinMesh(ctx, req.(*JoinMeshRequest))
}
return interceptor(ctx, in, info, handler)
}
// MeshCtrlServer_ServiceDesc is the grpc.ServiceDesc for MeshCtrlServer service. // MeshCtrlServer_ServiceDesc is the grpc.ServiceDesc for MeshCtrlServer service.
// It's only intended for direct use with grpc.RegisterService, // It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy) // and not to be introspected or modified (even as a copy)
@ -131,10 +99,6 @@ var MeshCtrlServer_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetMesh", MethodName: "GetMesh",
Handler: _MeshCtrlServer_GetMesh_Handler, Handler: _MeshCtrlServer_GetMesh_Handler,
}, },
{
MethodName: "JoinMesh",
Handler: _MeshCtrlServer_JoinMesh_Handler,
},
}, },
Streams: []grpc.StreamDesc{}, Streams: []grpc.StreamDesc{},
Metadata: "pkg/grpc/ctrlserver/ctrlserver.proto", Metadata: "pkg/grpc/ctrlserver/ctrlserver.proto",

View File

@ -1,139 +1,279 @@
package sync package sync
import ( import (
"errors" "fmt"
"io"
"math/rand" "math/rand"
"sync" "sync"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn" "github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
) )
// Syncer: picks random nodes from the mesh // Syncer: picks random nodes from the meshs
type Syncer interface { type Syncer interface {
Sync(meshId string) error Sync(theMesh mesh.MeshProvider) (bool, error)
SyncMeshes() error SyncMeshes() error
} }
// SyncerImpl: implementation of a syncer to sync meshes
type SyncerImpl struct { type SyncerImpl struct {
manager mesh.MeshManager meshManager mesh.MeshManager
requester SyncRequester requester SyncRequester
infectionCount int infectionCount int
syncCount int syncCount int
cluster conn.ConnCluster cluster conn.ConnCluster
conf *conf.WgMeshConfiguration configuration *conf.DaemonConfiguration
lastSync map[string]int64
lastPoll map[string]int64
lastSyncLock sync.RWMutex
lastPollLock sync.RWMutex
} }
// Sync: Sync random nodes // Sync: Sync with random nodes. Returns true if there was changes false otherwise
func (s *SyncerImpl) Sync(meshId string) error { func (s *SyncerImpl) Sync(correspondingMesh mesh.MeshProvider) (bool, error) {
logging.Log.WriteInfof("UPDATING WG CONF") if correspondingMesh == nil {
err := s.manager.ApplyConfig() return false, fmt.Errorf("mesh provided was nil cannot sync nil mesh")
if err != nil {
logging.Log.WriteInfof("Failed to update config %w", err)
} }
if !s.manager.HasChanges(meshId) && s.infectionCount == 0 { // Self can be nil if the node is removed
logging.Log.WriteInfof("No changes for %s", meshId) selfID := s.meshManager.GetPublicKey()
return nil self, _ := correspondingMesh.GetNode(selfID.String())
correspondingMesh.Prune()
if correspondingMesh.HasChanges() {
logging.Log.WriteInfof("meshes %s has changes", correspondingMesh.GetMeshId())
} }
theMesh := s.manager.GetMesh(meshId) // If removed sync with other nodes to gossip the node is removed
if self != nil && self.GetType() == conf.PEER_ROLE && !correspondingMesh.HasChanges() && s.infectionCount == 0 {
logging.Log.WriteInfof("no changes for %s", correspondingMesh.GetMeshId())
if theMesh == nil { // If not synchronised in certain time pull from random neighbour
return errors.New("the provided mesh does not exist") if s.configuration.PullInterval != 0 && time.Now().Unix()-s.lastSync[correspondingMesh.GetMeshId()] > int64(s.configuration.PullInterval) {
} return s.Pull(self, correspondingMesh)
}
snapshot, err := theMesh.GetMesh() return false, nil
if err != nil {
return err
}
nodes := snapshot.GetNodes()
if len(nodes) <= 1 {
return nil
}
self, err := s.manager.GetSelf(meshId)
if err != nil {
return err
}
excludedNodes := map[string]struct{}{
self.GetHostEndpoint(): {},
}
meshNodes := lib.MapValuesWithExclude(nodes, excludedNodes)
getNames := func(node mesh.MeshNode) string {
return node.GetHostEndpoint()
}
nodeNames := lib.Map(meshNodes, getNames)
neighbours := s.cluster.GetNeighbours(nodeNames, self.GetHostEndpoint())
randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
for _, node := range randomSubset {
logging.Log.WriteInfof("Random node: %s", node)
} }
before := time.Now() before := time.Now()
if len(meshNodes) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance { publicKey := s.meshManager.GetPublicKey()
logging.Log.WriteInfof("Sending to random cluster") nodeNames := correspondingMesh.GetPeers()
interCluster := s.cluster.GetInterCluster(nodeNames, self.GetHostEndpoint())
randomSubset = append(randomSubset, interCluster) nodeNames = lib.Filter(nodeNames, func(s string) bool {
// Filter our only public key out so we dont sync with ourself
return s != publicKey.String()
})
var gossipNodes []string
// Clients always pings its peer for configuration
if self != nil && self.GetType() == conf.CLIENT_ROLE && len(nodeNames) > 1 {
neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
if len(neighbours) == 0 {
return false, nil
}
// Peer with 2 nodes so that there is redundnacy in
// the situation the node leaves pre-emptively
redundancyLength := min(len(neighbours), 2)
gossipNodes = neighbours[:redundancyLength]
} else {
neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
gossipNodes = lib.RandomSubsetOfLength(neighbours, s.configuration.Branch)
if len(nodeNames) > s.configuration.ClusterSize && rand.Float64() < s.configuration.InterClusterChance {
gossipNodes[len(gossipNodes)-1] = s.cluster.GetInterCluster(nodeNames, publicKey.String())
}
} }
var waitGroup sync.WaitGroup var succeeded bool = false
for index := range randomSubset { var wait sync.WaitGroup
waitGroup.Add(1)
go func(i int) error { for index, node := range gossipNodes {
defer waitGroup.Done() wait.Add(1)
err := s.requester.SyncMesh(meshId, randomSubset[i])
return err syncNode := func(i int) {
}(index) correspondingPeer, err := correspondingMesh.GetNode(node)
defer wait.Done()
if correspondingPeer == nil || err != nil {
logging.Log.WriteErrorf("node %s does not exist", node)
return
}
err = s.requester.SyncMesh(correspondingMesh, correspondingPeer)
if err == nil || err == io.EOF {
succeeded = true
}
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
}
go syncNode(index)
} }
waitGroup.Wait() wait.Wait()
s.syncCount++ s.syncCount++
logging.Log.WriteInfof("SYNC TIME: %v", time.Now().Sub(before)) logging.Log.WriteInfof("sync time: %v", time.Since(before))
logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount) logging.Log.WriteInfof("number of syncs: %d", s.syncCount)
s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount) s.infectionCount = ((s.configuration.InfectionCount + s.infectionCount - 1) % s.configuration.InfectionCount)
return nil
if !succeeded {
s.infectionCount++
}
changes := correspondingMesh.HasChanges()
correspondingMesh.SaveChanges()
s.lastSyncLock.Lock()
s.lastSync[correspondingMesh.GetMeshId()] = time.Now().Unix()
s.lastSyncLock.Unlock()
return changes, nil
}
// Pull one node in the cluster, if there has not been message dissemination
// in a certain period of time pull a random node within the cluster
func (s *SyncerImpl) Pull(self mesh.MeshNode, mesh mesh.MeshProvider) (bool, error) {
peers := mesh.GetPeers()
pubKey, _ := self.GetPublicKey()
neighbours := s.cluster.GetNeighbours(peers, pubKey.String())
neighbour := lib.RandomSubsetOfLength(neighbours, 1)
if len(neighbour) == 0 {
logging.Log.WriteInfof("no neighbours")
return false, nil
}
logging.Log.WriteInfof("pulling from node %s", neighbour[0])
pullNode, err := mesh.GetNode(neighbour[0])
if err != nil || pullNode == nil {
return false, fmt.Errorf("node %s does not exist in the mesh", neighbour[0])
}
err = s.requester.SyncMesh(mesh, pullNode)
if err == nil || err == io.EOF {
s.lastSync[mesh.GetMeshId()] = time.Now().Unix()
} else {
return false, err
}
s.syncCount++
changes := mesh.HasChanges()
return changes, nil
} }
// SyncMeshes: Sync all meshes // SyncMeshes: Sync all meshes
func (s *SyncerImpl) SyncMeshes() error { func (s *SyncerImpl) SyncMeshes() error {
for meshId, _ := range s.manager.GetMeshes() { var wg sync.WaitGroup
err := s.Sync(meshId)
if err != nil { meshes := s.meshManager.GetMeshes()
return err
s.lastPollLock.Lock()
meshesToSync := lib.Filter(lib.MapValues(meshes), func(mesh mesh.MeshProvider) bool {
return time.Now().Unix()-s.lastPoll[mesh.GetMeshId()] >= int64(s.configuration.SyncInterval)
})
s.lastPollLock.Unlock()
changes := make(chan bool, len(meshesToSync))
for i := 0; i < len(meshesToSync); {
wg.Add(1)
sync := func(index int) {
defer wg.Done()
var hasChanges bool = false
mesh := meshesToSync[index]
hasChanges, err := s.Sync(mesh)
changes <- hasChanges
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
s.lastPollLock.Lock()
s.lastPoll[mesh.GetMeshId()] = time.Now().Unix()
s.lastPollLock.Unlock()
}
go sync(i)
i++
}
wg.Wait()
hasChanges := false
for i := 0; i < len(changes); i++ {
if <-changes {
hasChanges = true
} }
} }
return nil var err error
if hasChanges {
logging.Log.WriteInfof("updating the WireGuard configuration")
err = s.meshManager.ApplyConfig()
if err != nil {
logging.Log.WriteErrorf("failed to update config %s", err.Error())
}
err = s.meshManager.GetRouteManager().UpdateRoutes()
if err != nil {
logging.Log.WriteErrorf("update routes failed %s", err.Error())
}
}
return err
} }
func NewSyncer(m mesh.MeshManager, conf *conf.WgMeshConfiguration, r SyncRequester) Syncer { type NewSyncerParams struct {
cluster, _ := conn.NewConnCluster(conf.ClusterSize) MeshManager mesh.MeshManager
ConnectionManager conn.ConnectionManager
Configuration *conf.DaemonConfiguration
Requester SyncRequester
}
func NewSyncer(params *NewSyncerParams) Syncer {
cluster, _ := conn.NewConnCluster(params.Configuration.ClusterSize)
syncRequester := NewSyncRequester(NewSyncRequesterParams{
MeshManager: params.MeshManager,
ConnectionManager: params.ConnectionManager,
Configuration: params.Configuration,
})
return &SyncerImpl{ return &SyncerImpl{
manager: m, meshManager: params.MeshManager,
conf: conf, configuration: params.Configuration,
requester: r, requester: syncRequester,
infectionCount: 0, infectionCount: 0,
syncCount: 0, syncCount: 0,
cluster: cluster} cluster: cluster,
lastSync: make(map[string]int64),
lastPoll: make(map[string]int64)}
} }

View File

@ -1,52 +1,60 @@
package sync package sync
import ( import (
logging "github.com/tim-beatham/wgmesh/pkg/log" "github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/mesh" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
// SyncErrorHandler: Handles errors when attempting to sync // SyncErrorHandler: Handles errors when attempting to sync
type SyncErrorHandler interface { type SyncErrorHandler interface {
Handle(meshId string, endpoint string, err error) bool Handle(mesh mesh.MeshProvider, endpoint string, err error) bool
} }
// SyncErrorHandlerImpl Is an implementation of the SyncErrorHandler // SyncErrorHandlerImpl Is an implementation of the SyncErrorHandler
type SyncErrorHandlerImpl struct { type SyncErrorHandlerImpl struct {
meshManager mesh.MeshManager meshManager mesh.MeshManager
connManager conn.ConnectionManager
} }
func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint string) bool { func (s *SyncErrorHandlerImpl) handleFailed(mesh mesh.MeshProvider, nodeId string) bool {
mesh := s.meshManager.GetMesh(meshId) mesh.Mark(nodeId)
node, err := mesh.GetNode(nodeId)
if mesh == nil { if err != nil {
return false s.connManager.RemoveConnection(node.GetHostEndpoint())
} }
// self, err := s.meshManager.GetSelf(meshId)
// if err != nil {
// return false
// }
// mesh.DecrementHealth(endpoint, self.GetHostEndpoint())
return true return true
} }
func (s *SyncErrorHandlerImpl) Handle(meshId string, endpoint string, err error) bool { func (s *SyncErrorHandlerImpl) handleDeadlineExceeded(mesh mesh.MeshProvider, nodeId string) bool {
node, err := mesh.GetNode(nodeId)
if err != nil {
return false
}
s.connManager.RemoveConnection(node.GetHostEndpoint())
return true
}
func (s *SyncErrorHandlerImpl) Handle(mesh mesh.MeshProvider, nodeId string, err error) bool {
errStatus, _ := status.FromError(err) errStatus, _ := status.FromError(err)
logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message()) logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message())
switch errStatus.Code() { switch errStatus.Code() {
case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound: case codes.Unavailable, codes.Unknown, codes.Internal, codes.NotFound:
return s.incrementFailedCount(meshId, endpoint) return s.handleFailed(mesh, nodeId)
case codes.DeadlineExceeded:
return s.handleDeadlineExceeded(mesh, nodeId)
} }
return false return false
} }
func NewSyncErrorHandler(m mesh.MeshManager) SyncErrorHandler { func NewSyncErrorHandler(m mesh.MeshManager, conn conn.ConnectionManager) SyncErrorHandler {
return &SyncErrorHandlerImpl{meshManager: m} return &SyncErrorHandlerImpl{meshManager: m, connManager: conn}
} }

View File

@ -2,74 +2,44 @@ package sync
import ( import (
"context" "context"
"errors"
"io" "io"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/conf"
logging "github.com/tim-beatham/wgmesh/pkg/log" "github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/mesh" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/rpc"
) )
// SyncRequester: coordinates the syncing of meshes // SyncRequester: coordinates the syncing of meshes
type SyncRequester interface { type SyncRequester interface {
GetMesh(meshId string, ifName string, port int, endPoint string) error SyncMesh(mesh mesh.MeshProvider, meshNode mesh.MeshNode) error
SyncMesh(meshid string, endPoint string) error
} }
type SyncRequesterImpl struct { type SyncRequesterImpl struct {
server *ctrlserver.MeshCtrlServer manager mesh.MeshManager
errorHdlr SyncErrorHandler connectionManager conn.ConnectionManager
configuration *conf.DaemonConfiguration
errorHdlr SyncErrorHandler
} }
// GetMesh: Retrieves the local state of the mesh at the endpoint // handleErr: handleGrpc errors
func (s *SyncRequesterImpl) GetMesh(meshId string, ifName string, port int, endPoint string) error { func (s *SyncRequesterImpl) handleErr(mesh mesh.MeshProvider, pubKey string, err error) error {
peerConnection, err := s.server.ConnectionManager.GetConnection(endPoint) ok := s.errorHdlr.Handle(mesh, pubKey, err)
if err != nil {
return err
}
client, err := peerConnection.GetClient()
if err != nil {
return err
}
c := rpc.NewSyncServiceClient(client)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
reply, err := c.GetConf(ctx, &rpc.GetConfRequest{MeshId: meshId})
if err != nil {
return err
}
err = s.server.MeshManager.AddMesh(&mesh.AddMeshParams{
MeshId: meshId,
DevName: ifName,
WgPort: port,
MeshBytes: reply.Mesh,
})
return err
}
func (s *SyncRequesterImpl) handleErr(meshId, endpoint string, err error) error {
ok := s.errorHdlr.Handle(meshId, endpoint, err)
if ok { if ok {
return nil return nil
} }
return err return err
} }
// SyncMesh: Proactively send a sync request to the other mesh // SyncMesh: Proactively send a sync request to the other mesh
func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error { func (s *SyncRequesterImpl) SyncMesh(mesh mesh.MeshProvider, meshNode mesh.MeshNode) error {
peerConnection, err := s.server.ConnectionManager.GetConnection(endpoint) endpoint := meshNode.GetHostEndpoint()
pubKey, _ := meshNode.GetPublicKey()
peerConnection, err := s.connectionManager.GetConnection(endpoint)
if err != nil { if err != nil {
return err return err
@ -81,15 +51,9 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
return err return err
} }
mesh := s.server.MeshManager.GetMesh(meshId)
if mesh == nil {
return errors.New("mesh does not exist")
}
c := rpc.NewSyncServiceClient(client) c := rpc.NewSyncServiceClient(client)
syncTimeOut := s.server.Conf.SyncRate * float64(time.Second) syncTimeOut := float64(s.configuration.SyncInterval) * float64(time.Second)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(syncTimeOut)) ctx, cancel := context.WithTimeout(context.Background(), time.Duration(syncTimeOut))
defer cancel() defer cancel()
@ -97,11 +61,11 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
err = s.syncMesh(mesh, ctx, c) err = s.syncMesh(mesh, ctx, c)
if err != nil { if err != nil {
return s.handleErr(meshId, endpoint, err) s.handleErr(mesh, pubKey.String(), err)
} }
logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId) logging.Log.WriteInfof("synced with node: %s meshId: %s\n", endpoint, mesh.GetMeshId())
return nil return err
} }
func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context, client rpc.SyncServiceClient) error { func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context, client rpc.SyncServiceClient) error {
@ -125,7 +89,7 @@ func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context
in, err := stream.Recv() in, err := stream.Recv()
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
logging.Log.WriteInfof("Stream recv error: %s\n", err.Error()) logging.Log.WriteInfof("stream recv error: %s\n", err.Error())
return err return err
} }
@ -134,7 +98,7 @@ func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context
} }
if err != nil { if err != nil {
logging.Log.WriteInfof("Syncer recv error: %s\n", err.Error()) logging.Log.WriteInfof("syncer recv error: %s\n", err.Error())
return err return err
} }
@ -148,7 +112,17 @@ func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context
return nil return nil
} }
func NewSyncRequester(s *ctrlserver.MeshCtrlServer) SyncRequester { type NewSyncRequesterParams struct {
errorHdlr := NewSyncErrorHandler(s.MeshManager) MeshManager mesh.MeshManager
return &SyncRequesterImpl{server: s, errorHdlr: errorHdlr} ConnectionManager conn.ConnectionManager
Configuration *conf.DaemonConfiguration
}
func NewSyncRequester(params NewSyncRequesterParams) SyncRequester {
errorHdlr := NewSyncErrorHandler(params.MeshManager, params.ConnectionManager)
return &SyncRequesterImpl{manager: params.MeshManager,
connectionManager: params.ConnectionManager,
configuration: params.Configuration,
errorHdlr: errorHdlr,
}
} }

View File

@ -1,32 +0,0 @@
package sync
import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
// SyncScheduler: Loops through all nodes in the mesh and runs a schedule to
// sync each event
type SyncScheduler interface {
Run() error
Stop() error
}
// SyncSchedulerImpl scheduler for sync scheduling
type SyncSchedulerImpl struct {
quit chan struct{}
server *ctrlserver.MeshCtrlServer
syncer Syncer
}
// Run implements SyncScheduler.
func syncFunction(syncer Syncer) lib.TimerFunc {
return func() error {
return syncer.SyncMeshes()
}
}
func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester) *lib.Timer {
syncer := NewSyncer(s.MeshManager, s.Conf, syncRequester)
return lib.NewTimer(syncFunction(syncer), int(s.Conf.SyncRate))
}

View File

@ -6,19 +6,18 @@ import (
"errors" "errors"
"io" "io"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/rpc"
"github.com/tim-beatham/wgmesh/pkg/rpc"
) )
type SyncServiceImpl struct { type SyncServiceImpl struct {
rpc.UnimplementedSyncServiceServer rpc.UnimplementedSyncServiceServer
Server *ctrlserver.MeshCtrlServer MeshManager mesh.MeshManager
} }
// GetMesh: Gets a nodes local mesh configuration as a CRDT // GetMesh: Gets a nodes local mesh configuration as a CRDT
func (s *SyncServiceImpl) GetConf(context context.Context, request *rpc.GetConfRequest) (*rpc.GetConfReply, error) { func (s *SyncServiceImpl) GetConf(context context.Context, request *rpc.GetConfRequest) (*rpc.GetConfReply, error) {
mesh := s.Server.MeshManager.GetMesh(request.MeshId) mesh := s.MeshManager.GetMesh(request.MeshId)
if mesh == nil { if mesh == nil {
return nil, errors.New("mesh does not exist") return nil, errors.New("mesh does not exist")
@ -56,7 +55,7 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error
if len(meshId) == 0 { if len(meshId) == 0 {
meshId = in.MeshId meshId = in.MeshId
mesh := s.Server.MeshManager.GetMesh(meshId) mesh := s.MeshManager.GetMesh(meshId)
if mesh == nil { if mesh == nil {
return errors.New("mesh does not exist") return errors.New("mesh does not exist")
@ -64,11 +63,11 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error
syncer = mesh.GetSyncer() syncer = mesh.GetSyncer()
} else if meshId != in.MeshId { } else if meshId != in.MeshId {
return errors.New("Differing MeshIDs") return errors.New("differing meshids")
} }
if syncer == nil { if syncer == nil {
return errors.New("Syncer should not be nil") return errors.New("syncer should not be nil")
} }
msg, moreMessages := syncer.GenerateMessage() msg, moreMessages := syncer.GenerateMessage()
@ -92,7 +91,3 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error
} }
} }
} }
func NewSyncService(server *ctrlserver.MeshCtrlServer) *SyncServiceImpl {
return &SyncServiceImpl{Server: server}
}

View File

@ -1,14 +0,0 @@
package timestamp
import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
timerFunc := func() error {
return ctrlServer.MeshManager.UpdateTimeStamp()
}
return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime)
}

View File

@ -1,15 +1,20 @@
package wg package wg
import "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
type WgInterfaceManipulatorStub struct{} type WgInterfaceManipulatorStub struct{}
func (i *WgInterfaceManipulatorStub) CreateInterface(params *CreateInterfaceParams) error { // CreateInterface creates a WireGuard interface
func (w *WgInterfaceManipulatorStub) CreateInterface(port int, privateKey *wgtypes.Key) (string, error) {
return "aninterface", nil
}
// AddAddress adds an address to the given interface name
func (w *WgInterfaceManipulatorStub) AddAddress(ifName string, addr string) error {
return nil return nil
} }
func (i *WgInterfaceManipulatorStub) AddAddress(ifName string, addr string) error { // RemoveInterface removes the specified interface
return nil func (w *WgInterfaceManipulatorStub) RemoveInterface(ifName string) error {
}
func (i *WgInterfaceManipulatorStub) RemoveInterface(ifName string) error {
return nil return nil
} }

View File

@ -1,5 +1,16 @@
package wg package wg
import "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
type WgInterfaceManipulator interface {
// CreateInterface creates a WireGuard interface
CreateInterface(port int, privateKey *wgtypes.Key) (string, error)
// AddAddress adds an address to the given interface name
AddAddress(ifName string, addr string) error
// RemoveInterface removes the specified interface
RemoveInterface(ifName string) error
}
type WgError struct { type WgError struct {
msg string msg string
} }
@ -7,17 +18,3 @@ type WgError struct {
func (m *WgError) Error() string { func (m *WgError) Error() string {
return m.msg return m.msg
} }
type CreateInterfaceParams struct {
IfName string
Port int
}
type WgInterfaceManipulator interface {
// CreateInterface creates a WireGuard interface
CreateInterface(params *CreateInterfaceParams) error
// AddAddress adds an address to the given interface name
AddAddress(ifName string, addr string) error
// RemoveInterface removes the specified interface
RemoveInterface(ifName string) error
}

View File

@ -1,10 +1,12 @@
package wg package wg
import ( import (
"crypto"
"crypto/rand"
"fmt" "fmt"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@ -13,40 +15,47 @@ type WgInterfaceManipulatorImpl struct {
client *wgctrl.Client client *wgctrl.Client
} }
const hashLength = 6
// CreateInterface creates a WireGuard interface // CreateInterface creates a WireGuard interface
func (m *WgInterfaceManipulatorImpl) CreateInterface(params *CreateInterfaceParams) error { func (m *WgInterfaceManipulatorImpl) CreateInterface(port int, privKey *wgtypes.Key) (string, error) {
rtnl, err := lib.NewRtNetlinkConfig() rtnl, err := lib.NewRtNetlinkConfig()
if err != nil { if err != nil {
return fmt.Errorf("failed to access link: %w", err) return "", fmt.Errorf("failed to access link: %w", err)
} }
defer rtnl.Close() defer rtnl.Close()
err = rtnl.CreateLink(params.IfName) randomBuf := make([]byte, 32)
_, err = rand.Read(randomBuf)
if err != nil { if err != nil {
return fmt.Errorf("failed to create link: %w", err) return "", err
} }
privateKey, err := wgtypes.GeneratePrivateKey() md5 := crypto.MD5.New().Sum(randomBuf)
md5Str := fmt.Sprintf("wg%x", md5)[:hashLength]
err = rtnl.CreateLink(md5Str)
if err != nil { if err != nil {
return fmt.Errorf("failed to create private key: %w", err) return "", fmt.Errorf("failed to create link: %w", err)
} }
var cfg wgtypes.Config = wgtypes.Config{ var cfg wgtypes.Config = wgtypes.Config{
PrivateKey: &privateKey, PrivateKey: privKey,
ListenPort: &params.Port, ListenPort: &port,
} }
err = m.client.ConfigureDevice(params.IfName, cfg) err = m.client.ConfigureDevice(md5Str, cfg)
if err != nil { if err != nil {
return fmt.Errorf("failed to configure dev: %w", err) m.RemoveInterface(md5Str)
return "", fmt.Errorf("failed to configure dev: %w", err)
} }
logging.Log.WriteInfof("ip link set up dev %s type wireguard", params.IfName) logging.Log.WriteInfof("ip link set up dev %s type wireguard", md5Str)
return nil return md5Str, nil
} }
// Add an address to the given interface // Add an address to the given interface

View File

@ -0,0 +1,92 @@
// Package to convert an IPV6 addres into 8 words
package what8words
import (
"bufio"
"bytes"
"fmt"
"net"
"os"
"strings"
)
type What8Words struct {
words []string
}
// Convert implements What8Words.
func (w *What8Words) Convert(ipStr string) (string, error) {
ip, ipNet, err := net.ParseCIDR(ipStr)
if err != nil {
return "", err
}
ip16 := ip.To16()
if ip16 == nil {
return "", fmt.Errorf("cannot convert ip to 16 representation")
}
representation := make([]string, 7)
for i := 2; i <= net.IPv6len-2; i += 2 {
word1 := w.words[ip16[i]]
word2 := w.words[ip16[i+1]]
representation[i/2-1] = fmt.Sprintf("%s-%s", word1, word2)
}
prefixSize, _ := ipNet.Mask.Size()
return strings.Join(representation[:prefixSize/16-1], "."), nil
}
// Convert implements What8Words.
func (w *What8Words) ConvertIdentifier(ipStr string) (string, error) {
ip, err := w.Convert(ipStr)
if err != nil {
return "", err
}
constituents := strings.Split(ip, ".")
return strings.Join(constituents[3:], "."), nil
}
func NewWhat8Words(pathToWords string) (*What8Words, error) {
words, err := ReadWords(pathToWords)
if err != nil {
return nil, err
}
return &What8Words{words: words}, nil
}
// ReadWords reads the what 8 words txt file
func ReadWords(wordFile string) ([]string, error) {
f, err := os.ReadFile(wordFile)
if err != nil {
return nil, err
}
words := make([]string, 257)
reader := bufio.NewScanner(bytes.NewReader(f))
counter := 0
for reader.Scan() && counter <= len(words) {
text := reader.Text()
words[counter] = text
counter++
if reader.Err() != nil {
return nil, reader.Err()
}
}
return words, nil
}

1
smegmesh-web Submodule

Submodule smegmesh-web added at c1128bcd98