1
0
forked from extern/smegmesh

Compare commits

...

51 Commits

Author SHA1 Message Date
1e263cc6a8 51-bugfix-routes-not-removing-when-withdrawn
- Routes are not being removed despite being withdrawn from the
configuration.
- Best path routes are not shared across interfaces
- Bug in consistent hashing wrong parameter passed caused by
refactorings.
2023-12-10 15:10:36 +00:00
dae9cd31a1 Merge pull request #50 from tim-beatham/50-give-client-ability-to-bridge-meshes
50-give-client-ability-to-bridge-meshes
2023-12-08 23:58:32 +00:00
f855f53fbf 50-give-client-ability-to-bridge-meshes
Client can act as a route bridging meshes. Cient send keepalives
to all of it's peers in the different meshes act as a bridge between
the meshes
2023-12-08 23:56:07 +00:00
52feb5767b Merge pull request #48 from tim-beatham/47-default-routing
47 default routing
2023-12-08 20:03:45 +00:00
815c4484ee 47-default-routing
Implemented default routing and improved size of gossip. Using 64 bit
hash funciton to identify vector.
2023-12-08 20:02:57 +00:00
0058c9f4c9 47-default-routing
Implementing default routing so that all traffic goes out of an
exit point.
2023-12-08 11:49:24 +00:00
92c0805275 Merge pull request #46 from tim-beatham/45-use-statistical-testing
45 use statistical testing
2023-12-07 18:20:25 +00:00
661fb0d54c 45-use-statistical-testing
Keepalive is based on per mesh and not per node.
Using total ordering mechanism similar to paxos to elect a leader
if leader doesn't update it's timestamp within 3 * keepAlive then
give the leader a gravestone and elect the next leader.
Leader is bassed on lexicographically ordered public key.
2023-12-07 18:18:13 +00:00
64885f1055 45-use-statistical-testing
Using statistical testing to test whether the node has failed.
2023-12-07 01:44:54 +00:00
2169f7796f Merge pull request #44 from tim-beatham/43-gravestones
43-use-gravestones
2023-12-06 22:46:05 +00:00
a3ceff019d 43-use-gravestones
Change of approach from keepalive to a noiseless protocol
2023-12-06 22:45:04 +00:00
b78d96986c Merge pull request #42 from tim-beatham/41-bugfix-fluctuating-ips
41 bugfix fluctuating ips
2023-12-06 14:37:14 +00:00
1b18d89c9f 41-bugfix-fluctuating-ips
Fluctuating ips creating hub and spoke.
2023-12-05 02:00:16 +00:00
245a2c5f58 41-bugfix-fluctuating-ips
If the node is a peer then add the client in the WG
configuration.
2023-12-04 17:40:24 +00:00
c40f7510b8 41-bugfix-fluctuating-ips
IPs of clients fluctuating because there isn't a strict order on
clients. Client's need to be processed before the peers.
2023-12-04 17:32:50 +00:00
78d748770c BUGIX Hash client by public key 2023-12-04 17:13:51 +00:00
0ff2a8eef9 BUGFIX: Allowed IPs fluctuating 2023-12-04 17:11:37 +00:00
fd7bd80485 BUGFIX
Don't get device each time it is an expensive operation.
2023-12-04 16:40:15 +00:00
3ef1b68ba5 BUGFIX: Hashing datastore to work out changes
Changed hashing implementation to work out if there are changes
in the data store
2023-11-30 15:58:26 +00:00
b9ba836ae3 Merge pull request #40 from tim-beatham/39-implement-two-phase-map
39-implement-two-phase-map
2023-11-30 02:03:36 +00:00
650901aba1 39-implement-two-phase-map
Implemented my own two phase map based on vector clocks
2023-11-30 02:02:38 +00:00
a82eab0686 Bugfix
Added replace peers so that deleted nodes are automatically removed
2023-11-28 14:43:55 +00:00
32e7e4c7df main
Bugfix. Fixed issue where consistent hashing was not working.
2023-11-28 14:42:09 +00:00
1fae0a6c2c Merge pull request #37 from tim-beatham/36-add-route-path-into-route-object
36-add-route-path-into-route-object
2023-11-27 21:03:56 +00:00
d8e156f13f 36-add-route-path-into-route-object
Added the route path into the route object so that we can
see what meshes packets are routed across.
2023-11-27 18:55:41 +00:00
3fca49a1c9 Merge pull request #35 from tim-beatham/34-fix-routing
34 fix routing
2023-11-27 16:05:06 +00:00
a2517a1e72 34-fix-routing
- Added mesh-to-mesh routing of hop count > 1
- If there is a tie-breaker with respect to the hop-count use consistent
hashing to determine the route to take based on the public key.
2023-11-27 15:56:30 +00:00
aef8b59f22 32-fix-routing
Flooding routes into other meshes a bit like BGP.
2023-11-25 03:15:58 +00:00
4030d17b41 Fixed routing issue 2023-11-24 17:49:06 +00:00
73db65660b Merge pull request #33 from tim-beatham/32-incorporate-dns
32-incorporate-dns
2023-11-24 15:05:40 +00:00
d1a74a7b95 32-incorporate-dns
Incorporated a DNS server. A DNS server can be run to resolve host
names.
2023-11-24 15:04:07 +00:00
f28ed8260d Merge pull request #30 from tim-beatham/29-only-ping-clients-who-have-updated-their-config
29-only-ping-clients-who-have-updated-their-config
2023-11-24 12:39:14 +00:00
2c406718df 29-only-ping-clients-who-have-updated-their-config
Only consider clients who have updated their config when synchronising
with peers. Consider a dead time where we don't have a handshake and
a prune time when we remove them from the WireGuard configuration.
2023-11-24 12:37:54 +00:00
11b003b549 Merge pull request #28 from tim-beatham/27-remove-client-grpc-endpoint
27-remove-client-grpc-endpoint
2023-11-24 12:08:42 +00:00
7be11dbaa3 27-remove-client-grpc-endpoint
Removed a client's grpc endpoint value. Client's aren't publicly
available so there is no need for a client's gRPC endpoint.
Also changed a node ID's to their public key. A node id's public
address is an issue for mobility of clients as their endpoint
is subject to change
2023-11-24 12:07:03 +00:00
e7ac8c5542 Only updating WireGuard config if node exists 2023-11-22 13:08:02 +00:00
09c64c4628 Fixed container file 2023-11-22 12:45:01 +00:00
2c4f18f52b Merge pull request #26 from tim-beatham/25-modify-code-to-use-public-api
25-modify-code-to-use-public-api
2023-11-22 10:42:48 +00:00
4c54022f63 25-modify-code-to-use-public-api
Modify the code to use a public IP address by default if none is
specified
2023-11-22 10:41:54 +00:00
bf0724f6e5 Merge pull request #24 from tim-beatham/24-keepalive-holepunch
24 keepalive holepunch
2023-11-21 21:28:16 +00:00
624bd6e921 24-keepalive
Persistent keep alive working
2023-11-21 21:26:31 +00:00
7b939e0468 24-keepalive-holepunch
Added the ability to hole punch NAT
2023-11-21 20:42:43 +00:00
6e201ebaf5 24-keepalive-holepunch
Nodes acting as peers and nodes acting as clients
2023-11-21 16:42:49 +00:00
06542da03c main
Fixed problems with timestamp not updating
2023-11-21 13:31:34 +00:00
0d63cd6624 main
Adding words.txt for what words
2023-11-20 18:12:58 +00:00
f13319cfc1 Merge pull request #22 from tim-beatham/21-phonetic-words-ipv6
21 phonetic words ipv6
2023-11-20 18:08:49 +00:00
95f4495b0b 21-phonetic-words-ipv6
Simple what 8 words implementation
2023-11-20 18:07:52 +00:00
330fa74ef4 IPv6 What 8 Words
what 8 words for ipv6 started
2023-11-20 15:22:32 +00:00
3e5b57e41f Merge pull request #20 from tim-beatham/19-hash-wg-interface
Hashing the WireGuard interface
2023-11-20 13:04:19 +00:00
b179cd3cf4 Hashing the WireGuard interface
Hashing the interface and using ephmeral ports so that the admin doesn't
choose an interface and port combination. An administrator can alteranatively
decide to provide port but this isn't critical.
2023-11-20 13:03:42 +00:00
8f211aa116 Merge pull request #18 from tim-beatham/26-performance-testing
Stubbing out WireGuard components
2023-11-20 11:29:37 +00:00
55 changed files with 3333 additions and 970 deletions

View File

@ -8,4 +8,5 @@ RUN apt-get update && apt-get install -y \
tmux \ tmux \
vim vim
WORKDIR /wgmesh WORKDIR /wgmesh
RUN go mod tidy
RUN go build -o /usr/local/bin ./... RUN go build -o /usr/local/bin ./...

View File

@ -7,11 +7,13 @@ import (
) )
func main() { func main() {
apiServer, err := api.NewSmegServer() apiServer, err := api.NewSmegServer(api.ApiServerConf{
WordsFile: "./cmd/api/words.txt",
})
if err != nil { if err != nil {
log.Fatal(err.Error()) log.Fatal(err.Error())
} }
apiServer.Run(":40000") apiServer.Run(":8080")
} }

257
cmd/api/words.txt Normal file
View File

@ -0,0 +1,257 @@
be
to
of
it
in
we
do
he
on
go
at
if
or
up
by
hi
the
and
you
not
for
but
say
get
she
one
all
can
out
who
now
see
way
how
lot
yes
use
any
day
try
put
let
why
new
off
big
too
ask
man
bit
end
may
own
run
pay
job
old
kid
bad
few
ago
far
buy
set
guy
car
sit
war
win
yet
top
law
cut
low
die
eat
age
hit
air
add
boy
act
tax
oil
eye
son
key
fun
dad
dog
arm
fly
box
gas
lie
hot
gun
per
art
red
fit
bed
fan
mix
mom
sex
bus
fix
bar
lay
ice
bet
bag
due
aid
tie
leg
ban
odd
cup
dry
cry
rid
pop
sir
cat
map
sad
sea
aim
sun
fat
row
egg
tea
god
wed
tip
ear
hat
net
ill
dig
fee
mad
gap
nor
bid
era
toy
sky
bin
owe
wet
tap
pro
ski
cow
pen
van
web
pot
sum
cap
log
pub
pig
joy
raw
rat
via
lip
two
six
ten
lab
ton
mid
bat
hip
gut
sin
non
rub
sub
par
pre
ray
cue
dye
fin
ion
neo
hey
wow
mum
bye
aye
jet
sue
pet
flu
cop
ooh
rip
spy
pie
bug
gum
wan
rap
nut
beg
pin
pit
jam
tag
fax
vet
fry
pad
lad
mud
bay
con
pan
gee
toe
dip
shy
gym
zoo
fox
bow
tin
hop
wee
kit
opt
vow
sew
cab
bee
rob
rig
yep
ego
rib
nod
hug
lap
ash
hum
dam
bum
yen
jar

18
cmd/dns/main.go Normal file
View File

@ -0,0 +1,18 @@
package main
import (
"log"
smegdns "github.com/tim-beatham/wgmesh/pkg/dns"
)
func main() {
server, err := smegdns.NewDns(53)
if err != nil {
log.Fatal(err.Error())
}
defer server.Close()
server.Listen()
}

View File

@ -8,7 +8,9 @@ import (
"time" "time"
"github.com/akamensky/argparse" "github.com/akamensky/argparse"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/wgmesh/pkg/log"
) )
@ -16,7 +18,6 @@ const SockAddr = "/tmp/wgmesh_ipc.sock"
type CreateMeshParams struct { type CreateMeshParams struct {
Client *ipcRpc.Client Client *ipcRpc.Client
IfName string
WgPort int WgPort int
Endpoint string Endpoint string
} }
@ -24,7 +25,6 @@ type CreateMeshParams struct {
func createMesh(args *CreateMeshParams) string { func createMesh(args *CreateMeshParams) string {
var reply string var reply string
newMeshParams := ipc.NewMeshArgs{ newMeshParams := ipc.NewMeshArgs{
IfName: args.IfName,
WgPort: args.WgPort, WgPort: args.WgPort,
Endpoint: args.Endpoint, Endpoint: args.Endpoint,
} }
@ -68,7 +68,6 @@ func joinMesh(params *JoinMeshParams) string {
args := ipc.JoinMeshArgs{ args := ipc.JoinMeshArgs{
MeshId: params.MeshId, MeshId: params.MeshId,
IpAdress: params.IpAddress, IpAdress: params.IpAddress,
IfName: params.IfName,
Port: params.WgPort, Port: params.WgPort,
} }
@ -96,9 +95,13 @@ func getMesh(client *ipcRpc.Client, meshId string) {
fmt.Println("Control Endpoint: " + node.HostEndpoint) fmt.Println("Control Endpoint: " + node.HostEndpoint)
fmt.Println("WireGuard Endpoint: " + node.WgEndpoint) fmt.Println("WireGuard Endpoint: " + node.WgEndpoint)
fmt.Println("Wg IP: " + node.WgHost) fmt.Println("Wg IP: " + node.WgHost)
fmt.Println(fmt.Sprintf("Timestamp: %s", time.Unix(node.Timestamp, 0).String())) fmt.Printf("Timestamp: %s", time.Unix(node.Timestamp, 0).String())
advertiseRoutes := strings.Join(node.Routes, ",") mapFunc := func(r ctrlserver.MeshRoute) string {
return r.Destination
}
advertiseRoutes := strings.Join(lib.Map(node.Routes, mapFunc), ",")
fmt.Printf("Routes: %s\n", advertiseRoutes) fmt.Printf("Routes: %s\n", advertiseRoutes)
fmt.Println("---") fmt.Println("---")
@ -240,7 +243,6 @@ func main() {
newMeshCmd := parser.NewCommand("new-mesh", "Create a new mesh") newMeshCmd := parser.NewCommand("new-mesh", "Create a new mesh")
listMeshCmd := parser.NewCommand("list-meshes", "List meshes the node is connected to") listMeshCmd := parser.NewCommand("list-meshes", "List meshes the node is connected to")
joinMeshCmd := parser.NewCommand("join-mesh", "Join a mesh network") joinMeshCmd := parser.NewCommand("join-mesh", "Join a mesh network")
// getMeshCmd := parser.NewCommand("get-mesh", "Get a mesh network")
enableInterfaceCmd := parser.NewCommand("enable-interface", "Enable A Specific Mesh Interface") enableInterfaceCmd := parser.NewCommand("enable-interface", "Enable A Specific Mesh Interface")
getGraphCmd := parser.NewCommand("get-graph", "Convert a mesh into DOT format") getGraphCmd := parser.NewCommand("get-graph", "Convert a mesh into DOT format")
leaveMeshCmd := parser.NewCommand("leave-mesh", "Leave a mesh network") leaveMeshCmd := parser.NewCommand("leave-mesh", "Leave a mesh network")
@ -251,14 +253,12 @@ func main() {
deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements") deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements")
getNodeCmd := parser.NewCommand("get-node", "Get a specific node from the mesh") getNodeCmd := parser.NewCommand("get-node", "Get a specific node from the mesh")
var newMeshIfName *string = newMeshCmd.String("f", "ifname", &argparse.Options{Required: true}) var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{})
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{}) var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{})
var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true}) var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true}) var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true})
var joinMeshIfName *string = joinMeshCmd.String("f", "ifname", &argparse.Options{Required: true}) var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{})
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{}) var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{})
var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true}) var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true})
@ -298,7 +298,6 @@ func main() {
if newMeshCmd.Happened() { if newMeshCmd.Happened() {
fmt.Println(createMesh(&CreateMeshParams{ fmt.Println(createMesh(&CreateMeshParams{
Client: client, Client: client,
IfName: *newMeshIfName,
WgPort: *newMeshPort, WgPort: *newMeshPort,
Endpoint: *newMeshEndpoint, Endpoint: *newMeshEndpoint,
})) }))
@ -311,7 +310,6 @@ func main() {
if joinMeshCmd.Happened() { if joinMeshCmd.Happened() {
fmt.Println(joinMesh(&JoinMeshParams{ fmt.Println(joinMesh(&JoinMeshParams{
Client: client, Client: client,
IfName: *joinMeshIfName,
WgPort: *joinMeshPort, WgPort: *joinMeshPort,
IpAddress: *joinMeshIpAddress, IpAddress: *joinMeshIpAddress,
MeshId: *joinMeshId, MeshId: *joinMeshId,

View File

@ -13,7 +13,7 @@ import (
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/robin" "github.com/tim-beatham/wgmesh/pkg/robin"
"github.com/tim-beatham/wgmesh/pkg/sync" "github.com/tim-beatham/wgmesh/pkg/sync"
"github.com/tim-beatham/wgmesh/pkg/timestamp" timer "github.com/tim-beatham/wgmesh/pkg/timers"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
) )
@ -45,20 +45,25 @@ func main() {
var robinRpc robin.WgRpc var robinRpc robin.WgRpc
var robinIpc robin.IpcHandler var robinIpc robin.IpcHandler
var syncProvider sync.SyncServiceImpl var syncProvider sync.SyncServiceImpl
var syncRequester sync.SyncRequester
var syncer sync.Syncer
ctrlServerParams := ctrlserver.NewCtrlServerParams{ ctrlServerParams := ctrlserver.NewCtrlServerParams{
Conf: conf, Conf: conf,
CtrlProvider: &robinRpc, CtrlProvider: &robinRpc,
SyncProvider: &syncProvider, SyncProvider: &syncProvider,
Client: client, Client: client,
OnDelete: func(mp mesh.MeshProvider) {
syncer.SyncMeshes()
},
} }
ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams) ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams)
syncProvider.Server = ctrlServer syncProvider.Server = ctrlServer
syncRequester := sync.NewSyncRequester(ctrlServer) syncRequester = sync.NewSyncRequester(ctrlServer)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester) syncer = sync.NewSyncer(ctrlServer.MeshManager, conf, syncRequester)
timestampScheduler := timestamp.NewTimestampScheduler(ctrlServer) syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, syncer)
pruneScheduler := mesh.NewPruner(ctrlServer.MeshManager, *conf) keepAlive := timer.NewTimestampScheduler(ctrlServer)
robinIpcParams := robin.RobinIpcParams{ robinIpcParams := robin.RobinIpcParams{
CtrlServer: ctrlServer, CtrlServer: ctrlServer,
@ -76,13 +81,12 @@ func main() {
go ipc.RunIpcHandler(&robinIpc) go ipc.RunIpcHandler(&robinIpc)
go syncScheduler.Run() go syncScheduler.Run()
go timestampScheduler.Run() go keepAlive.Run()
go pruneScheduler.Run()
closeResources := func() { closeResources := func() {
logging.Log.WriteInfof("Closing resources") logging.Log.WriteInfof("Closing resources")
syncScheduler.Stop() syncScheduler.Stop()
timestampScheduler.Stop() keepAlive.Stop()
ctrlServer.Close() ctrlServer.Close()
client.Close() client.Close()
} }

View File

@ -10,6 +10,7 @@ import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/wgmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/what8words"
) )
const SockAddr = "/tmp/wgmesh_ipc.sock" const SockAddr = "/tmp/wgmesh_ipc.sock"
@ -22,11 +23,36 @@ type ApiServer interface {
type SmegServer struct { type SmegServer struct {
router *gin.Engine router *gin.Engine
client *ipcRpc.Client client *ipcRpc.Client
words *what8words.What8Words
} }
func meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNode { func (s *SmegServer) routeToApiRoute(meshNode ctrlserver.MeshNode) []Route {
routes := make([]Route, len(meshNode.Routes))
for index, route := range meshNode.Routes {
if route.Path == nil {
route.Path = make([]string, 0)
}
routes[index] = Route{
Prefix: route.Destination,
Path: route.Path,
}
}
return routes
}
func (s *SmegServer) meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNode {
if meshNode.Routes == nil { if meshNode.Routes == nil {
meshNode.Routes = make([]string, 0) meshNode.Routes = make([]ctrlserver.MeshRoute, 0)
}
alias := meshNode.Alias
if alias == "" {
alias, _ = s.words.ConvertIdentifier(meshNode.WgHost)
} }
return &SmegNode{ return &SmegNode{
@ -35,20 +61,20 @@ func meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNode {
Endpoint: meshNode.HostEndpoint, Endpoint: meshNode.HostEndpoint,
Timestamp: int(meshNode.Timestamp), Timestamp: int(meshNode.Timestamp),
Description: meshNode.Description, Description: meshNode.Description,
Routes: meshNode.Routes, Routes: s.routeToApiRoute(meshNode),
PublicKey: meshNode.PublicKey, PublicKey: meshNode.PublicKey,
Alias: meshNode.Alias, Alias: alias,
Services: meshNode.Services, Services: meshNode.Services,
} }
} }
func meshToAPIMesh(meshId string, nodes []ctrlserver.MeshNode) SmegMesh { func (s *SmegServer) meshToAPIMesh(meshId string, nodes []ctrlserver.MeshNode) SmegMesh {
var smegMesh SmegMesh var smegMesh SmegMesh
smegMesh.MeshId = meshId smegMesh.MeshId = meshId
smegMesh.Nodes = make(map[string]SmegNode) smegMesh.Nodes = make(map[string]SmegNode)
for _, node := range nodes { for _, node := range nodes {
smegMesh.Nodes[node.WgHost] = *meshNodeToAPIMeshNode(node) smegMesh.Nodes[node.WgHost] = *s.meshNodeToAPIMeshNode(node)
} }
return smegMesh return smegMesh
@ -62,11 +88,11 @@ func (s *SmegServer) CreateMesh(c *gin.Context) {
c.JSON(http.StatusBadRequest, &gin.H{ c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(), "error": err.Error(),
}) })
return return
} }
ipcRequest := ipc.NewMeshArgs{ ipcRequest := ipc.NewMeshArgs{
IfName: createMesh.IfName,
WgPort: createMesh.WgPort, WgPort: createMesh.WgPort,
} }
@ -100,7 +126,6 @@ func (s *SmegServer) JoinMesh(c *gin.Context) {
ipcRequest := ipc.JoinMeshArgs{ ipcRequest := ipc.JoinMeshArgs{
MeshId: joinMesh.MeshId, MeshId: joinMesh.MeshId,
IpAdress: joinMesh.Bootstrap, IpAdress: joinMesh.Bootstrap,
IfName: joinMesh.IfName,
Port: joinMesh.WgPort, Port: joinMesh.WgPort,
} }
@ -139,7 +164,7 @@ func (s *SmegServer) GetMesh(c *gin.Context) {
return return
} }
mesh := meshToAPIMesh(meshidParam, getMeshReply.Nodes) mesh := s.meshToAPIMesh(meshidParam, getMeshReply.Nodes)
c.JSON(http.StatusOK, mesh) c.JSON(http.StatusOK, mesh)
} }
@ -168,7 +193,7 @@ func (s *SmegServer) GetMeshes(c *gin.Context) {
return return
} }
meshes = append(meshes, meshToAPIMesh(mesh, getMeshReply.Nodes)) meshes = append(meshes, s.meshToAPIMesh(mesh, getMeshReply.Nodes))
} }
c.JSON(http.StatusOK, meshes) c.JSON(http.StatusOK, meshes)
@ -179,13 +204,19 @@ func (s *SmegServer) Run(addr string) error {
return s.router.Run(addr) return s.router.Run(addr)
} }
func NewSmegServer() (ApiServer, error) { func NewSmegServer(conf ApiServerConf) (ApiServer, error) {
client, err := ipcRpc.DialHTTP("unix", SockAddr) client, err := ipcRpc.DialHTTP("unix", SockAddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
words, err := what8words.NewWhat8Words(conf.WordsFile)
if err != nil {
return nil, err
}
router := gin.Default() router := gin.Default()
router.Use(gin.LoggerWithConfig(gin.LoggerConfig{ router.Use(gin.LoggerWithConfig(gin.LoggerConfig{
@ -195,6 +226,7 @@ func NewSmegServer() (ApiServer, error) {
smegServer := &SmegServer{ smegServer := &SmegServer{
router: router, router: router,
client: client, client: client,
words: words,
} }
router.GET("/meshes", smegServer.GetMeshes) router.GET("/meshes", smegServer.GetMeshes)

View File

@ -1,5 +1,10 @@
package api package api
type Route struct {
Prefix string `json:"prefix"`
Path []string `json:"path"`
}
type SmegNode struct { type SmegNode struct {
Alias string `json:"alias"` Alias string `json:"alias"`
WgHost string `json:"wgHost"` WgHost string `json:"wgHost"`
@ -8,7 +13,7 @@ type SmegNode struct {
Timestamp int `json:"timestamp"` Timestamp int `json:"timestamp"`
Description string `json:"description"` Description string `json:"description"`
PublicKey string `json:"publicKey"` PublicKey string `json:"publicKey"`
Routes []string `json:"routes"` Routes []Route `json:"routes"`
Services map[string]string `json:"services"` Services map[string]string `json:"services"`
} }
@ -18,13 +23,15 @@ type SmegMesh struct {
} }
type CreateMeshRequest struct { type CreateMeshRequest struct {
IfName string `json:"ifName" binding:"required"` WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"`
WgPort int `json:"port" binding:"required,gte=1024,lt=65535"`
} }
type JoinMeshRequest struct { type JoinMeshRequest struct {
IfName string `json:"ifName" binding:"required"` WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"`
WgPort int `json:"port" binding:"required,gte=1024,lt=65535"`
Bootstrap string `json:"bootstrap" binding:"required"` Bootstrap string `json:"bootstrap" binding:"required"`
MeshId string `json:"meshid" binding:"required"` MeshId string `json:"meshid" binding:"required"`
} }
type ApiServerConf struct {
WordsFile string
}

View File

@ -1,9 +1,10 @@
package crdt package automerge
import ( import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"slices"
"strings" "strings"
"time" "time"
@ -35,15 +36,55 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
panic("node must be of type *MeshNodeCrdt") panic("node must be of type *MeshNodeCrdt")
} }
crdt.Routes = make(map[string]interface{}) crdt.Routes = make(map[string]Route)
crdt.Services = make(map[string]string) crdt.Services = make(map[string]string)
crdt.Timestamp = time.Now().Unix() crdt.Timestamp = time.Now().Unix()
c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt) c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt)
} }
func (c *CrdtMeshManager) GetNodeIds() []string { func (c *CrdtMeshManager) isPeer(nodeId string) bool {
node, err := c.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return false
}
nodeType, err := node.Map().Get("type")
if err != nil || nodeType.Kind() != automerge.KindStr {
return false
}
return nodeType.Str() == string(conf.PEER_ROLE)
}
// isAlive: checks that the node's configuration has been updated
// since the rquired keep alive time
func (c *CrdtMeshManager) isAlive(nodeId string) bool {
node, err := c.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return false
}
timestamp, err := node.Map().Get("timestamp")
if err != nil || timestamp.Kind() != automerge.KindInt64 {
return false
}
keepAliveTime := timestamp.Int64()
return (time.Now().Unix() - keepAliveTime) < int64(c.conf.DeadTime)
}
func (c *CrdtMeshManager) GetPeers() []string {
keys, _ := c.doc.Path("nodes").Map().Keys() keys, _ := c.doc.Path("nodes").Map().Keys()
keys = lib.Filter(keys, func(publicKey string) bool {
return c.isPeer(publicKey) && c.isAlive(publicKey)
})
return keys return keys
} }
@ -55,7 +96,7 @@ func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
return nil, err return nil, err
} }
if c.cache == nil || len(changes) > 3 { if c.cache == nil || len(changes) > 0 {
c.lastCacheHash = c.LastHash c.lastCacheHash = c.LastHash
cache, err := automerge.As[*MeshCrdt](c.doc.Root()) cache, err := automerge.As[*MeshCrdt](c.doc.Root())
@ -113,7 +154,7 @@ func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, erro
// NodeExists: returns true if the node exists. Returns false // NodeExists: returns true if the node exists. Returns false
func (m *CrdtMeshManager) NodeExists(key string) bool { func (m *CrdtMeshManager) NodeExists(key string) bool {
node, err := m.doc.Path("nodes").Map().Get(key) node, err := m.doc.Path("nodes").Map().Get(key)
return node.Kind() == automerge.KindMap && err != nil return node.Kind() == automerge.KindMap && err == nil
} }
func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) { func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
@ -249,7 +290,8 @@ func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
return fmt.Errorf("AddService: services property does not exist in node") return fmt.Errorf("AddService: services property does not exist in node")
} }
return service.Map().Set(key, value) err = service.Map().Set(key, value)
return err
} }
func (m *CrdtMeshManager) RemoveService(nodeId, key string) error { func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
@ -279,7 +321,7 @@ func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
} }
// AddRoutes: adds routes to the specific nodeId // AddRoutes: adds routes to the specific nodeId
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error { func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
logging.Log.WriteInfof("Adding route to %s", nodeId) logging.Log.WriteInfof("Adding route to %s", nodeId)
@ -298,7 +340,32 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
} }
for _, route := range routes { for _, route := range routes {
err = routeMap.Map().Set(route, struct{}{}) prevRoute, err := routeMap.Map().Get(route.GetDestination().String())
if prevRoute.Kind() == automerge.KindVoid && err != nil {
path, err := prevRoute.Map().Get("path")
if err != nil {
return err
}
if path.Kind() != automerge.KindList {
return fmt.Errorf("path is not a list")
}
pathStr, err := automerge.As[[]string](path)
if err != nil {
return err
}
slices.Equal(route.GetPath(), pathStr)
}
err = routeMap.Map().Set(route.GetDestination().String(), Route{
Destination: route.GetDestination().String(),
Path: route.GetPath(),
})
if err != nil { if err != nil {
return err return err
@ -307,8 +374,82 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
return nil return nil
} }
func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
return nil, err
}
if nodeVal.Kind() != automerge.KindMap {
return nil, fmt.Errorf("node does not exist")
}
routeMap, err := nodeVal.Map().Get("routes")
if err != nil {
return nil, err
}
if routeMap.Kind() != automerge.KindMap {
return nil, fmt.Errorf("node %s is not a map", nodeId)
}
routes, err := automerge.As[map[string]Route](routeMap)
return lib.MapValues(routes), err
}
func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, error) {
node, err := m.GetNode(targetNode)
if err != nil {
return nil, err
}
routes := make(map[string]mesh.Route)
// Add routes that the node directly has
for _, route := range node.GetRoutes() {
routes[route.GetDestination().String()] = route
}
// Work out the other routes in the mesh
for _, node := range m.GetPeers() {
nodeRoutes, err := m.getRoutes(node)
if err != nil {
return nil, err
}
for _, route := range nodeRoutes {
otherRoute, ok := routes[route.GetDestination().String()]
hopCount := route.GetHopCount()
if node != targetNode {
hopCount += 1
}
if !ok || route.GetHopCount()+1 < otherRoute.GetHopCount() {
routes[route.GetDestination().String()] = &Route{
Destination: route.GetDestination().String(),
Path: append(route.Path, m.GetMeshId()),
}
}
}
}
return routes, nil
}
func (m *CrdtMeshManager) RemoveNode(nodeId string) error {
err := m.doc.Path("nodes").Map().Delete(nodeId)
return err
}
// DeleteRoutes deletes the specified routes // DeleteRoutes deletes the specified routes
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error { func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil { if err != nil {
@ -326,7 +467,7 @@ func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
} }
for _, route := range routes { for _, route := range routes {
err = routeMap.Map().Delete(route) err = routeMap.Map().Delete(route.GetDestination().String())
} }
return err return err
@ -336,54 +477,54 @@ func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer {
return NewAutomergeSync(m) return NewAutomergeSync(m)
} }
func (m *CrdtMeshManager) Prune(pruneTime int) error { func (m *CrdtMeshManager) Prune() error {
nodes, err := m.doc.Path("nodes").Get() // nodes, err := m.doc.Path("nodes").Get()
if err != nil { // if err != nil {
return err // return err
} // }
if nodes.Kind() != automerge.KindMap { // if nodes.Kind() != automerge.KindMap {
return errors.New("node must be a map") // return errors.New("node must be a map")
} // }
values, err := nodes.Map().Values() // values, err := nodes.Map().Values()
if err != nil { // if err != nil {
return err // return err
} // }
deletionNodes := make([]string, 0) // deletionNodes := make([]string, 0)
for nodeId, node := range values { // for nodeId, node := range values {
if node.Kind() != automerge.KindMap { // if node.Kind() != automerge.KindMap {
return errors.New("node must be a map") // return errors.New("node must be a map")
} // }
nodeMap := node.Map() // nodeMap := node.Map()
timeStamp, err := nodeMap.Get("timestamp") // timeStamp, err := nodeMap.Get("timestamp")
if err != nil { // if err != nil {
return err // return err
} // }
if timeStamp.Kind() != automerge.KindInt64 { // if timeStamp.Kind() != automerge.KindInt64 {
return errors.New("timestamp is not int64") // return errors.New("timestamp is not int64")
} // }
timeValue := timeStamp.Int64() // timeValue := timeStamp.Int64()
nowValue := time.Now().Unix() // nowValue := time.Now().Unix()
if nowValue-timeValue >= int64(pruneTime) { // if nowValue-timeValue >= int64(pruneTime) {
deletionNodes = append(deletionNodes, nodeId) // deletionNodes = append(deletionNodes, nodeId)
} // }
} // }
for _, node := range deletionNodes { // for _, node := range deletionNodes {
logging.Log.WriteInfof("Pruning %s", node) // logging.Log.WriteInfof("Pruning %s", node)
nodes.Map().Delete(node) // nodes.Map().Delete(node)
} // }
return nil return nil
} }
@ -408,7 +549,6 @@ func (m *MeshNodeCrdt) GetWgHost() *net.IPNet {
_, ipnet, err := net.ParseCIDR(m.WgHost) _, ipnet, err := net.ParseCIDR(m.WgHost)
if err != nil { if err != nil {
logging.Log.WriteErrorf("Cannot parse WgHost %s", err.Error())
return nil return nil
} }
@ -419,8 +559,13 @@ func (m *MeshNodeCrdt) GetTimeStamp() int64 {
return m.Timestamp return m.Timestamp
} }
func (m *MeshNodeCrdt) GetRoutes() []string { func (m *MeshNodeCrdt) GetRoutes() []mesh.Route {
return lib.MapKeys(m.Routes) return lib.Map(lib.MapValues(m.Routes), func(r Route) mesh.Route {
return &Route{
Destination: r.Destination,
Path: r.Path,
}
})
} }
func (m *MeshNodeCrdt) GetDescription() string { func (m *MeshNodeCrdt) GetDescription() string {
@ -431,7 +576,6 @@ func (m *MeshNodeCrdt) GetIdentifier() string {
ipv6 := m.WgHost[:len(m.WgHost)-4] ipv6 := m.WgHost[:len(m.WgHost)-4]
constituents := strings.Split(ipv6, ":") constituents := strings.Split(ipv6, ":")
logging.Log.WriteInfof(ipv6)
constituents = constituents[4:] constituents = constituents[4:]
return strings.Join(constituents, ":") return strings.Join(constituents, ":")
} }
@ -450,6 +594,12 @@ func (m *MeshNodeCrdt) GetServices() map[string]string {
return services return services
} }
// GetType refers to the type of the node. Peer means that the node is globally accessible
// Client means the node is only accessible through another peer
func (n *MeshNodeCrdt) GetType() conf.NodeType {
return conf.NodeType(n.Type)
}
func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode { func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
nodes := make(map[string]mesh.MeshNode) nodes := make(map[string]mesh.MeshNode)
@ -464,8 +614,22 @@ func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
Description: node.Description, Description: node.Description,
Alias: node.Alias, Alias: node.Alias,
Services: node.GetServices(), Services: node.GetServices(),
Type: node.Type,
} }
} }
return nodes return nodes
} }
func (r *Route) GetDestination() *net.IPNet {
_, ipnet, _ := net.ParseCIDR(r.Destination)
return ipnet
}
func (r *Route) GetHopCount() int {
return len(r.Path)
}
func (r *Route) GetPath() []string {
return r.Path
}

View File

@ -1,4 +1,4 @@
package crdt package automerge
import ( import (
"github.com/automerge/automerge-go" "github.com/automerge/automerge-go"

View File

@ -1,4 +1,4 @@
package crdt package automerge
import ( import (
"slices" "slices"

View File

@ -1,4 +1,4 @@
package crdt package automerge
import ( import (
"fmt" "fmt"
@ -28,16 +28,23 @@ type MeshNodeFactory struct {
func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode { func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode {
hostName := f.getAddress(params) hostName := f.getAddress(params)
grpcEndpoint := fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort)
if f.Config.Role == conf.CLIENT_ROLE {
grpcEndpoint = "-"
}
return &MeshNodeCrdt{ return &MeshNodeCrdt{
HostEndpoint: fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort), HostEndpoint: grpcEndpoint,
PublicKey: params.PublicKey.String(), PublicKey: params.PublicKey.String(),
WgEndpoint: fmt.Sprintf("%s:%d", hostName, params.WgPort), WgEndpoint: fmt.Sprintf("%s:%d", hostName, params.WgPort),
WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()), WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()),
// Always set the routes as empty. // Always set the routes as empty.
// Routes handled by external component // Routes handled by external component
Routes: map[string]interface{}{}, Routes: make(map[string]Route),
Description: "", Description: "",
Alias: "", Alias: "",
Type: string(f.Config.Role),
} }
} }
@ -50,7 +57,19 @@ func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string
} else if len(f.Config.Endpoint) != 0 { } else if len(f.Config.Endpoint) != 0 {
hostName = f.Config.Endpoint hostName = f.Config.Endpoint
} else { } else {
hostName = lib.GetOutboundIP().String() ipFunc := lib.GetPublicIP
if f.Config.IPDiscovery == conf.DNS_IP_DISCOVERY {
ipFunc = lib.GetOutboundIP
}
ip, err := ipFunc()
if err != nil {
return ""
}
hostName = ip.String()
} }
return hostName return hostName

View File

@ -1,16 +1,23 @@
package crdt package automerge
// Route: Represents a CRDT of the given route
type Route struct {
Destination string `automerge:"destination"`
Path []string `automerge:"path"`
}
// MeshNodeCrdt: Represents a CRDT for a mesh nodes // MeshNodeCrdt: Represents a CRDT for a mesh nodes
type MeshNodeCrdt struct { type MeshNodeCrdt struct {
HostEndpoint string `automerge:"hostEndpoint"` HostEndpoint string `automerge:"hostEndpoint"`
WgEndpoint string `automerge:"wgEndpoint"` WgEndpoint string `automerge:"wgEndpoint"`
PublicKey string `automerge:"publicKey"` PublicKey string `automerge:"publicKey"`
WgHost string `automerge:"wgHost"` WgHost string `automerge:"wgHost"`
Timestamp int64 `automerge:"timestamp"` Timestamp int64 `automerge:"timestamp"`
Routes map[string]interface{} `automerge:"routes"` Routes map[string]Route `automerge:"routes"`
Alias string `automerge:"alias"` Alias string `automerge:"alias"`
Description string `automerge:"description"` Description string `automerge:"description"`
Services map[string]string `automerge:"services"` Services map[string]string `automerge:"services"`
Type string `automerge:"type"`
} }
// MeshCrdt: Represents the mesh network as a whole // MeshCrdt: Represents the mesh network as a whole

View File

@ -16,6 +16,20 @@ func (m *WgMeshConfigurationError) Error() string {
return m.msg return m.msg
} }
type NodeType string
const (
PEER_ROLE NodeType = "peer"
CLIENT_ROLE NodeType = "client"
)
type IPDiscovery string
const (
PUBLIC_IP_DISCOVERY = "public"
DNS_IP_DISCOVERY = "dns"
)
type WgMeshConfiguration struct { type WgMeshConfiguration struct {
// CertificatePath is the path to the certificate to use in mTLS // CertificatePath is the path to the certificate to use in mTLS
CertificatePath string `yaml:"certificatePath"` CertificatePath string `yaml:"certificatePath"`
@ -28,8 +42,13 @@ type WgMeshConfiguration struct {
SkipCertVerification bool `yaml:"skipCertVerification"` SkipCertVerification bool `yaml:"skipCertVerification"`
// Port to run the GrpcServer on // Port to run the GrpcServer on
GrpcPort string `yaml:"gRPCPort"` GrpcPort string `yaml:"gRPCPort"`
// IPDIscovery: how to discover your IP if not specified. Use DNS server 8.8.8.8 or
// use public IP discovery library
IPDiscovery IPDiscovery `yaml:"ipDiscovery"`
// AdvertiseRoutes advertises other meshes if the node is in multiple meshes // AdvertiseRoutes advertises other meshes if the node is in multiple meshes
AdvertiseRoutes bool `yaml:"advertiseRoutes"` AdvertiseRoutes bool `yaml:"advertiseRoutes"`
// AdvertiseDefaultRoute advertises a default route out of the mesh.
AdvertiseDefaultRoute bool `yaml:"advertiseDefaults"`
// Endpoint is the IP in which this computer is publicly reachable. // Endpoint is the IP in which this computer is publicly reachable.
// usecase is when the node has multiple IP addresses // usecase is when the node has multiple IP addresses
Endpoint string `yaml:"publicEndpoint"` Endpoint string `yaml:"publicEndpoint"`
@ -47,12 +66,21 @@ type WgMeshConfiguration struct {
KeepAliveTime int `yaml:"keepAliveTime"` KeepAliveTime int `yaml:"keepAliveTime"`
// Timeout number of seconds before we consider the node as dead // Timeout number of seconds before we consider the node as dead
Timeout int `yaml:"timeout"` Timeout int `yaml:"timeout"`
// PruneTime number of seconds before we consider the 'node' as dead // PruneTime number of seconds before we remove nodes that are likely to be dead
PruneTime int `yaml:"pruneTime"` PruneTime int `yaml:"pruneTime"`
// DeadTime: number of seconds before we consider the node as dead and stop considering it
// when picking a random peer
DeadTime int `yaml:"deadTime"`
// Profile whether or not to include a http server that profiles the code // Profile whether or not to include a http server that profiles the code
Profile bool `yaml:"profile"` Profile bool `yaml:"profile"`
// StubWg whether or not to stub the WireGuard types // StubWg whether or not to stub the WireGuard types
StubWg bool `yaml:"stubWg"` StubWg bool `yaml:"stubWg"`
// Role specifies whether or not the user is globally accessible.
// If the user is globaly accessible they specify themselves as a client.
Role NodeType `yaml:"role"`
// KeepAliveWg configures the implementation so that we send keep alive packets to peers.
// KeepAlive can only be set if role is type client
KeepAliveWg int `yaml:"keepAliveWg"`
} }
func ValidateConfiguration(c *WgMeshConfiguration) error { func ValidateConfiguration(c *WgMeshConfiguration) error {
@ -122,9 +150,15 @@ func ValidateConfiguration(c *WgMeshConfiguration) error {
} }
} }
if c.PruneTime <= 1 { if c.PruneTime < 1 {
return &WgMeshConfigurationError{ return &WgMeshConfigurationError{
msg: "Prune time cannot be <= 1", msg: "Prune time cannot be < 1",
}
}
if c.DeadTime < 1 {
return &WgMeshConfigurationError{
msg: "Dead time cannot be < 1",
} }
} }
@ -134,6 +168,14 @@ func ValidateConfiguration(c *WgMeshConfiguration) error {
} }
} }
if c.Role == "" {
c.Role = PEER_ROLE
}
if c.IPDiscovery == "" {
c.IPDiscovery = PUBLIC_IP_DISCOVERY
}
return nil return nil
} }

508
pkg/crdt/datastore.go Normal file
View File

@ -0,0 +1,508 @@
package crdt
import (
"bytes"
"encoding/gob"
"fmt"
"net"
"slices"
"strings"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type Route struct {
Destination string
Path []string
}
// GetDestination implements mesh.Route.
func (r *Route) GetDestination() *net.IPNet {
_, ipnet, _ := net.ParseCIDR(r.Destination)
return ipnet
}
// GetHopCount implements mesh.Route.
func (r *Route) GetHopCount() int {
return len(r.Path)
}
// GetPath implements mesh.Route.
func (r *Route) GetPath() []string {
return r.Path
}
type MeshNode struct {
HostEndpoint string
WgEndpoint string
PublicKey string
WgHost string
Timestamp int64
Routes map[string]Route
Alias string
Description string
Services map[string]string
Type string
Tombstone bool
}
// Mark: marks the node is unreachable. This is not broadcast on
// syncrhonisation
func (m *TwoPhaseStoreMeshManager) Mark(nodeId string) {
m.store.Mark(nodeId)
}
// GetHostEndpoint: gets the gRPC endpoint of the node
func (n *MeshNode) GetHostEndpoint() string {
return n.HostEndpoint
}
// GetPublicKey: gets the public key of the node
func (n *MeshNode) GetPublicKey() (wgtypes.Key, error) {
return wgtypes.ParseKey(n.PublicKey)
}
// GetWgEndpoint(): get IP and port of the wireguard endpoint
func (n *MeshNode) GetWgEndpoint() string {
return n.WgEndpoint
}
// GetWgHost: get the IP address of the WireGuard node
func (n *MeshNode) GetWgHost() *net.IPNet {
_, ipnet, _ := net.ParseCIDR(n.WgHost)
return ipnet
}
// GetTimestamp: get the UNIX time stamp of the ndoe
func (n *MeshNode) GetTimeStamp() int64 {
return n.Timestamp
}
// GetRoutes: returns the routes that the nodes provides
func (n *MeshNode) GetRoutes() []mesh.Route {
routes := make([]mesh.Route, len(n.Routes))
for index, route := range lib.MapValues(n.Routes) {
routes[index] = &Route{
Destination: route.Destination,
Path: route.Path,
}
}
return routes
}
// GetIdentifier: returns the identifier of the node
func (m *MeshNode) GetIdentifier() string {
ipv6 := m.WgHost[:len(m.WgHost)-4]
constituents := strings.Split(ipv6, ":")
constituents = constituents[4:]
return strings.Join(constituents, ":")
}
// GetDescription: returns the description for this node
func (n *MeshNode) GetDescription() string {
return n.Description
}
// GetAlias: associates the node with an alias. Potentially used
// for DNS and so forth.
func (n *MeshNode) GetAlias() string {
return n.Alias
}
// GetServices: returns a list of services offered by the node
func (n *MeshNode) GetServices() map[string]string {
return n.Services
}
func (n *MeshNode) GetType() conf.NodeType {
return conf.NodeType(n.Type)
}
type MeshSnapshot struct {
Nodes map[string]MeshNode
}
// GetNodes() returns the nodes in the mesh
func (m *MeshSnapshot) GetNodes() map[string]mesh.MeshNode {
newMap := make(map[string]mesh.MeshNode)
for key, value := range m.Nodes {
newMap[key] = &MeshNode{
HostEndpoint: value.HostEndpoint,
PublicKey: value.PublicKey,
WgHost: value.WgHost,
WgEndpoint: value.WgEndpoint,
Timestamp: value.Timestamp,
Routes: value.Routes,
Alias: value.Alias,
Description: value.Description,
Services: value.Services,
Type: value.Type,
}
}
return newMap
}
type TwoPhaseStoreMeshManager struct {
MeshId string
IfName string
Client *wgctrl.Client
LastClock uint64
conf *conf.WgMeshConfiguration
store *TwoPhaseMap[string, MeshNode]
}
// AddNode() adds a node to the mesh
func (m *TwoPhaseStoreMeshManager) AddNode(node mesh.MeshNode) {
crdt, ok := node.(*MeshNode)
if !ok {
panic("node must be of type mesh node")
}
crdt.Routes = make(map[string]Route)
crdt.Services = make(map[string]string)
crdt.Timestamp = time.Now().Unix()
m.store.Put(crdt.PublicKey, *crdt)
}
// GetMesh() returns a snapshot of the mesh provided by the mesh provider.
func (m *TwoPhaseStoreMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
nodes := m.store.AsList()
snapshot := make(map[string]MeshNode)
for _, node := range nodes {
snapshot[node.PublicKey] = node
}
return &MeshSnapshot{
Nodes: snapshot,
}, nil
}
// GetMeshId() returns the ID of the mesh network
func (m *TwoPhaseStoreMeshManager) GetMeshId() string {
return m.MeshId
}
// Save() saves the mesh network
func (m *TwoPhaseStoreMeshManager) Save() []byte {
snapshot := m.store.Snapshot()
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
err := enc.Encode(*snapshot)
if err != nil {
logging.Log.WriteInfof(err.Error())
}
return buf.Bytes()
}
// Load() loads a mesh network
func (m *TwoPhaseStoreMeshManager) Load(bs []byte) error {
buf := bytes.NewBuffer(bs)
dec := gob.NewDecoder(buf)
var snapshot TwoPhaseMapSnapshot[string, MeshNode]
err := dec.Decode(&snapshot)
m.store.Merge(snapshot)
return err
}
// GetDevice() get the device corresponding with the mesh
func (m *TwoPhaseStoreMeshManager) GetDevice() (*wgtypes.Device, error) {
dev, err := m.Client.Device(m.IfName)
if err != nil {
return nil, err
}
return dev, nil
}
// HasChanges returns true if we have changes since last time we synced
func (m *TwoPhaseStoreMeshManager) HasChanges() bool {
clockValue := m.store.GetHash()
return clockValue != m.LastClock
}
// Record that we have changes and save the corresponding changes
func (m *TwoPhaseStoreMeshManager) SaveChanges() {
clockValue := m.store.GetHash()
m.LastClock = clockValue
}
// UpdateTimeStamp: update the timestamp of the given node
func (m *TwoPhaseStoreMeshManager) UpdateTimeStamp(nodeId string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
// Sort nodes by their public key
peers := m.GetPeers()
slices.Sort(peers)
if len(peers) == 0 {
return nil
}
peerToUpdate := peers[0]
if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.conf.KeepAliveTime) {
m.store.Mark(peerToUpdate)
if len(peers) < 2 {
return nil
}
peerToUpdate = peers[1]
}
if peerToUpdate != nodeId {
return nil
}
// Refresh causing node to update it's time stamp
node := m.store.Get(nodeId)
node.Timestamp = time.Now().Unix()
m.store.Put(nodeId, node)
return nil
}
// AddRoutes: adds routes to the given node
func (m *TwoPhaseStoreMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
if len(routes) == 0 {
return nil
}
node := m.store.Get(nodeId)
changes := false
for _, route := range routes {
prevRoute, ok := node.Routes[route.GetDestination().String()]
if !ok || route.GetHopCount() < prevRoute.GetHopCount() {
changes = true
node.Routes[route.GetDestination().String()] = Route{
Destination: route.GetDestination().String(),
Path: route.GetPath(),
}
}
}
if changes {
m.store.Put(nodeId, node)
}
return nil
}
// DeleteRoutes: deletes the routes from the node
func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
if len(routes) == 0 {
return nil
}
node := m.store.Get(nodeId)
changes := false
for _, route := range routes {
changes = true
delete(node.Routes, route.GetDestination().String())
}
if changes {
m.store.Put(nodeId, node)
}
return nil
}
// GetSyncer: returns the automerge syncer for sync
func (m *TwoPhaseStoreMeshManager) GetSyncer() mesh.MeshSyncer {
return NewTwoPhaseSyncer(m)
}
// GetNode get a particular not within the mesh
func (m *TwoPhaseStoreMeshManager) GetNode(nodeId string) (mesh.MeshNode, error) {
if !m.store.Contains(nodeId) {
return nil, fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
node := m.store.Get(nodeId)
return &node, nil
}
// NodeExists: returns true if a particular node exists false otherwise
func (m *TwoPhaseStoreMeshManager) NodeExists(nodeId string) bool {
return m.store.Contains(nodeId)
}
// SetDescription: sets the description of this automerge data type
func (m *TwoPhaseStoreMeshManager) SetDescription(nodeId string, description string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
node := m.store.Get(nodeId)
node.Description = description
m.store.Put(nodeId, node)
return nil
}
// SetAlias: set the alias of the nodeId
func (m *TwoPhaseStoreMeshManager) SetAlias(nodeId string, alias string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
node := m.store.Get(nodeId)
node.Description = alias
m.store.Put(nodeId, node)
return nil
}
// AddService: adds the service to the given node
func (m *TwoPhaseStoreMeshManager) AddService(nodeId string, key string, value string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
node := m.store.Get(nodeId)
node.Services[key] = value
m.store.Put(nodeId, node)
return nil
}
// RemoveService: removes the service form the node. throws an error if the service does not exist
func (m *TwoPhaseStoreMeshManager) RemoveService(nodeId string, key string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
node := m.store.Get(nodeId)
delete(node.Services, key)
m.store.Put(nodeId, node)
return nil
}
// Prune: prunes all nodes that have not updated their timestamp in
func (m *TwoPhaseStoreMeshManager) Prune() error {
m.store.Prune()
return nil
}
// GetPeers: get a list of contactable peers
func (m *TwoPhaseStoreMeshManager) GetPeers() []string {
nodes := m.store.AsList()
nodes = lib.Filter(nodes, func(mn MeshNode) bool {
if mn.Type != string(conf.PEER_ROLE) {
return false
}
// If the node is marked as unreachable don't consider it a peer.
// this help to optimize convergence time for unreachable nodes.
// However advertising it to other nodes could result in flapping.
if m.store.IsMarked(mn.PublicKey) {
return false
}
return true
})
return lib.Map(nodes, func(mn MeshNode) string {
return mn.PublicKey
})
}
func (m *TwoPhaseStoreMeshManager) getRoutes(targetNode string) (map[string]Route, error) {
if !m.store.Contains(targetNode) {
return nil, fmt.Errorf("getRoute: cannot get route %s does not exist", targetNode)
}
node := m.store.Get(targetNode)
return node.Routes, nil
}
// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen
func (m *TwoPhaseStoreMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, error) {
node, err := m.GetNode(targetNode)
if err != nil {
return nil, err
}
routes := make(map[string]mesh.Route)
// Add routes that the node directly has
for _, route := range node.GetRoutes() {
routes[route.GetDestination().String()] = route
}
// Work out the other routes in the mesh
for _, node := range m.GetPeers() {
nodeRoutes, err := m.getRoutes(node)
if err != nil {
return nil, err
}
for _, route := range nodeRoutes {
otherRoute, ok := routes[route.GetDestination().String()]
hopCount := route.GetHopCount()
if node != targetNode {
hopCount += 1
}
if !ok || route.GetHopCount()+1 < otherRoute.GetHopCount() {
routes[route.GetDestination().String()] = &Route{
Destination: route.GetDestination().String(),
Path: append(route.GetPath(), m.GetMeshId()),
}
}
}
}
return routes, nil
}
// RemoveNode(): remove the node from the mesh
func (m *TwoPhaseStoreMeshManager) RemoveNode(nodeId string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
m.store.Remove(nodeId)
return nil
}

78
pkg/crdt/factory.go Normal file
View File

@ -0,0 +1,78 @@
package crdt
import (
"fmt"
"hash/fnv"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
)
type TwoPhaseMapFactory struct{}
func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) {
return &TwoPhaseStoreMeshManager{
MeshId: params.MeshId,
IfName: params.DevName,
Client: params.Client,
conf: params.Conf,
store: NewTwoPhaseMap[string, MeshNode](params.NodeID, func(s string) uint64 {
h := fnv.New64a()
h.Write([]byte(s))
return h.Sum64()
}, uint64(3*params.Conf.KeepAliveTime)),
}, nil
}
type MeshNodeFactory struct {
Config conf.WgMeshConfiguration
}
func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode {
hostName := f.getAddress(params)
grpcEndpoint := fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort)
if f.Config.Role == conf.CLIENT_ROLE {
grpcEndpoint = "-"
}
return &MeshNode{
HostEndpoint: grpcEndpoint,
PublicKey: params.PublicKey.String(),
WgEndpoint: fmt.Sprintf("%s:%d", hostName, params.WgPort),
WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()),
Routes: make(map[string]Route),
Description: "",
Alias: "",
Type: string(f.Config.Role),
}
}
// getAddress returns the routable address of the machine.
func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string {
var hostName string = ""
if params.Endpoint != "" {
hostName = params.Endpoint
} else if len(f.Config.Endpoint) != 0 {
hostName = f.Config.Endpoint
} else {
ipFunc := lib.GetPublicIP
if f.Config.IPDiscovery == conf.DNS_IP_DISCOVERY {
ipFunc = lib.GetOutboundIP
}
ip, err := ipFunc()
if err != nil {
return ""
}
hostName = ip.String()
}
return hostName
}

176
pkg/crdt/g_map.go Normal file
View File

@ -0,0 +1,176 @@
// crdt is a golang implementation of a crdt
package crdt
import (
"cmp"
"sync"
)
type Bucket[D any] struct {
Vector uint64
Contents D
Gravestone bool
}
// GMap is a set that can only grow in size
type GMap[K cmp.Ordered, D any] struct {
lock sync.RWMutex
contents map[uint64]Bucket[D]
clock *VectorClock[K]
}
func (g *GMap[K, D]) Put(key K, value D) {
g.lock.Lock()
clock := g.clock.IncrementClock()
g.contents[g.clock.hashFunc(key)] = Bucket[D]{
Vector: clock,
Contents: value,
}
g.lock.Unlock()
}
func (g *GMap[K, D]) Contains(key K) bool {
return g.contains(g.clock.hashFunc(key))
}
func (g *GMap[K, D]) contains(key uint64) bool {
g.lock.RLock()
_, ok := g.contents[key]
g.lock.RUnlock()
return ok
}
func (g *GMap[K, D]) put(key uint64, b Bucket[D]) {
g.lock.Lock()
if g.contents[key].Vector < b.Vector {
g.contents[key] = b
}
g.lock.Unlock()
}
func (g *GMap[K, D]) get(key uint64) Bucket[D] {
g.lock.RLock()
bucket := g.contents[key]
g.lock.RUnlock()
return bucket
}
func (g *GMap[K, D]) Get(key K) D {
return g.get(g.clock.hashFunc(key)).Contents
}
func (g *GMap[K, D]) Mark(key K) {
g.lock.Lock()
bucket := g.contents[g.clock.hashFunc(key)]
bucket.Gravestone = true
g.contents[g.clock.hashFunc(key)] = bucket
g.lock.Unlock()
}
// IsMarked: returns true if the node is marked
func (g *GMap[K, D]) IsMarked(key K) bool {
marked := false
g.lock.RLock()
bucket, ok := g.contents[g.clock.hashFunc(key)]
if ok {
marked = bucket.Gravestone
}
g.lock.RUnlock()
return marked
}
func (g *GMap[K, D]) Keys() []uint64 {
g.lock.RLock()
contents := make([]uint64, len(g.contents))
index := 0
for key := range g.contents {
contents[index] = key
index++
}
g.lock.RUnlock()
return contents
}
func (g *GMap[K, D]) Save() map[uint64]Bucket[D] {
buckets := make(map[uint64]Bucket[D])
g.lock.RLock()
for key, value := range g.contents {
buckets[key] = value
}
g.lock.RUnlock()
return buckets
}
func (g *GMap[K, D]) SaveWithKeys(keys []uint64) map[uint64]Bucket[D] {
buckets := make(map[uint64]Bucket[D])
g.lock.RLock()
for _, key := range keys {
buckets[key] = g.contents[key]
}
g.lock.RUnlock()
return buckets
}
func (g *GMap[K, D]) GetClock() map[uint64]uint64 {
clock := make(map[uint64]uint64)
g.lock.RLock()
for key, bucket := range g.contents {
clock[key] = bucket.Vector
}
g.lock.RUnlock()
return clock
}
func (g *GMap[K, D]) GetHash() uint64 {
hash := uint64(0)
g.lock.RLock()
for _, value := range g.contents {
hash += value.Vector
}
g.lock.RUnlock()
return hash
}
func (g *GMap[K, D]) Prune() {
stale := g.clock.getStale()
g.lock.Lock()
for _, outlier := range stale {
delete(g.contents, outlier)
}
g.lock.Unlock()
}
func NewGMap[K cmp.Ordered, D any](clock *VectorClock[K]) *GMap[K, D] {
return &GMap[K, D]{
contents: make(map[uint64]Bucket[D]),
clock: clock,
}
}

211
pkg/crdt/two_phase_map.go Normal file
View File

@ -0,0 +1,211 @@
package crdt
import (
"cmp"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
type TwoPhaseMap[K cmp.Ordered, D any] struct {
addMap *GMap[K, D]
removeMap *GMap[K, bool]
Clock *VectorClock[K]
processId K
}
type TwoPhaseMapSnapshot[K cmp.Ordered, D any] struct {
Add map[uint64]Bucket[D]
Remove map[uint64]Bucket[bool]
}
// Contains checks whether the value exists in the map
func (m *TwoPhaseMap[K, D]) Contains(key K) bool {
return m.contains(m.Clock.hashFunc(key))
}
// Contains checks whether the value exists in the map
func (m *TwoPhaseMap[K, D]) contains(key uint64) bool {
if !m.addMap.contains(key) {
return false
}
addValue := m.addMap.get(key)
if !m.removeMap.contains(key) {
return true
}
removeValue := m.removeMap.get(key)
return addValue.Vector >= removeValue.Vector
}
func (m *TwoPhaseMap[K, D]) Get(key K) D {
var result D
if !m.Contains(key) {
return result
}
return m.addMap.Get(key)
}
func (m *TwoPhaseMap[K, D]) get(key uint64) D {
var result D
if !m.contains(key) {
return result
}
return m.addMap.get(key).Contents
}
// Put places the key K in the map
func (m *TwoPhaseMap[K, D]) Put(key K, data D) {
msgSequence := m.Clock.IncrementClock()
m.Clock.Put(key, msgSequence)
m.addMap.Put(key, data)
}
func (m *TwoPhaseMap[K, D]) Mark(key K) {
m.addMap.Mark(key)
}
// Remove removes the value from the map
func (m *TwoPhaseMap[K, D]) Remove(key K) {
m.removeMap.Put(key, true)
}
func (m *TwoPhaseMap[K, D]) keys() []uint64 {
keys := make([]uint64, 0)
addKeys := m.addMap.Keys()
for _, key := range addKeys {
if !m.contains(key) {
continue
}
keys = append(keys, key)
}
return keys
}
func (m *TwoPhaseMap[K, D]) AsList() []D {
theList := make([]D, 0)
keys := m.keys()
for _, key := range keys {
theList = append(theList, m.get(key))
}
return theList
}
func (m *TwoPhaseMap[K, D]) Snapshot() *TwoPhaseMapSnapshot[K, D] {
return &TwoPhaseMapSnapshot[K, D]{
Add: m.addMap.Save(),
Remove: m.removeMap.Save(),
}
}
func (m *TwoPhaseMap[K, D]) SnapShotFromState(state *TwoPhaseMapState[K]) *TwoPhaseMapSnapshot[K, D] {
addKeys := lib.MapKeys(state.AddContents)
removeKeys := lib.MapKeys(state.RemoveContents)
return &TwoPhaseMapSnapshot[K, D]{
Add: m.addMap.SaveWithKeys(addKeys),
Remove: m.removeMap.SaveWithKeys(removeKeys),
}
}
type TwoPhaseMapState[K cmp.Ordered] struct {
Vectors map[uint64]uint64
AddContents map[uint64]uint64
RemoveContents map[uint64]uint64
}
func (m *TwoPhaseMap[K, D]) IsMarked(key K) bool {
return m.addMap.IsMarked(key)
}
// GetHash: Get the hash of the current state of the map
// Sums the current values of the vectors. Provides good approximation
// of increasing numbers
func (m *TwoPhaseMap[K, D]) GetHash() uint64 {
return (m.addMap.GetHash() + 1) * (m.removeMap.GetHash() + 1)
}
// GetState: get the current vector clock of the add and remove
// map
func (m *TwoPhaseMap[K, D]) GenerateMessage() *TwoPhaseMapState[K] {
addContents := m.addMap.GetClock()
removeContents := m.removeMap.GetClock()
return &TwoPhaseMapState[K]{
Vectors: m.Clock.GetClock(),
AddContents: addContents,
RemoveContents: removeContents,
}
}
func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMapState[K] {
mapState := &TwoPhaseMapState[K]{
AddContents: make(map[uint64]uint64),
RemoveContents: make(map[uint64]uint64),
}
for key, value := range state.AddContents {
otherValue, ok := m.AddContents[key]
if !ok || otherValue < value {
mapState.AddContents[key] = value
}
}
for key, value := range state.RemoveContents {
otherValue, ok := m.RemoveContents[key]
if !ok || otherValue < value {
mapState.RemoveContents[key] = value
}
}
return mapState
}
func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) {
for key, value := range snapshot.Add {
// Gravestone is local only to that node.
// Discover ourselves if the node is alive
m.addMap.put(key, value)
m.Clock.put(key, value.Vector)
}
for key, value := range snapshot.Remove {
m.removeMap.put(key, value)
m.Clock.put(key, value.Vector)
}
}
func (m *TwoPhaseMap[K, D]) Prune() {
m.addMap.Prune()
m.removeMap.Prune()
m.Clock.Prune()
}
// NewTwoPhaseMap: create a new two phase map. Consists of two maps
// a grow map and a remove map. If both timestamps equal then favour keeping
// it in the map
func NewTwoPhaseMap[K cmp.Ordered, D any](processId K, hashKey func(K) uint64, staleTime uint64) *TwoPhaseMap[K, D] {
m := TwoPhaseMap[K, D]{
processId: processId,
Clock: NewVectorClock(processId, hashKey, staleTime),
}
m.addMap = NewGMap[K, D](m.Clock)
m.removeMap = NewGMap[K, bool](m.Clock)
return &m
}

View File

@ -0,0 +1,187 @@
package crdt
import (
"bytes"
"encoding/gob"
logging "github.com/tim-beatham/wgmesh/pkg/log"
)
type SyncState int
const (
HASH SyncState = iota
PREPARE
PRESENT
EXCHANGE
MERGE
FINISHED
)
// TwoPhaseSyncer is a type to sync a TwoPhase data store
type TwoPhaseSyncer struct {
manager *TwoPhaseStoreMeshManager
generateMessageFSM SyncFSM
state SyncState
mapState *TwoPhaseMapState[string]
peerMsg []byte
}
type TwoPhaseHash struct {
Hash uint64
}
type SyncFSM map[SyncState]func(*TwoPhaseSyncer) ([]byte, bool)
func hash(syncer *TwoPhaseSyncer) ([]byte, bool) {
hash := TwoPhaseHash{
Hash: syncer.manager.store.Clock.GetHash(),
}
var buffer bytes.Buffer
enc := gob.NewEncoder(&buffer)
err := enc.Encode(hash)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
syncer.IncrementState()
return buffer.Bytes(), true
}
func prepare(syncer *TwoPhaseSyncer) ([]byte, bool) {
var recvBuffer = bytes.NewBuffer(syncer.peerMsg)
dec := gob.NewDecoder(recvBuffer)
var hash TwoPhaseHash
err := dec.Decode(&hash)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
// If vector clocks are equal then no need to merge state
// Helps to reduce bandwidth by detecting early
if hash.Hash == syncer.manager.store.Clock.GetHash() {
return nil, false
}
var buffer bytes.Buffer
enc := gob.NewEncoder(&buffer)
err = enc.Encode(*syncer.mapState)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
syncer.IncrementState()
return buffer.Bytes(), true
}
func present(syncer *TwoPhaseSyncer) ([]byte, bool) {
if syncer.peerMsg == nil {
panic("peer msg is nil")
}
var recvBuffer = bytes.NewBuffer(syncer.peerMsg)
dec := gob.NewDecoder(recvBuffer)
var mapState TwoPhaseMapState[string]
err := dec.Decode(&mapState)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
difference := syncer.mapState.Difference(&mapState)
syncer.manager.store.Clock.Merge(mapState.Vectors)
var sendBuffer bytes.Buffer
enc := gob.NewEncoder(&sendBuffer)
enc.Encode(*difference)
syncer.IncrementState()
return sendBuffer.Bytes(), true
}
func exchange(syncer *TwoPhaseSyncer) ([]byte, bool) {
if syncer.peerMsg == nil {
panic("peer msg is nil")
}
var recvBuffer = bytes.NewBuffer(syncer.peerMsg)
dec := gob.NewDecoder(recvBuffer)
var mapState TwoPhaseMapState[string]
dec.Decode(&mapState)
snapshot := syncer.manager.store.SnapShotFromState(&mapState)
var sendBuffer bytes.Buffer
enc := gob.NewEncoder(&sendBuffer)
enc.Encode(*snapshot)
syncer.IncrementState()
return sendBuffer.Bytes(), true
}
func merge(syncer *TwoPhaseSyncer) ([]byte, bool) {
if syncer.peerMsg == nil {
panic("peer msg is nil")
}
var recvBuffer = bytes.NewBuffer(syncer.peerMsg)
dec := gob.NewDecoder(recvBuffer)
var snapshot TwoPhaseMapSnapshot[string, MeshNode]
dec.Decode(&snapshot)
syncer.manager.store.Merge(snapshot)
return nil, false
}
func (t *TwoPhaseSyncer) IncrementState() {
t.state = min(t.state+1, FINISHED)
}
func (t *TwoPhaseSyncer) GenerateMessage() ([]byte, bool) {
fsmFunc, ok := t.generateMessageFSM[t.state]
if !ok {
panic("state not handled")
}
return fsmFunc(t)
}
func (t *TwoPhaseSyncer) RecvMessage(msg []byte) error {
t.peerMsg = msg
return nil
}
func (t *TwoPhaseSyncer) Complete() {
logging.Log.WriteInfof("SYNC COMPLETED")
if t.state >= MERGE {
t.manager.store.Clock.IncrementClock()
}
}
func NewTwoPhaseSyncer(manager *TwoPhaseStoreMeshManager) *TwoPhaseSyncer {
var generateMessageFsm SyncFSM = SyncFSM{
HASH: hash,
PREPARE: prepare,
PRESENT: present,
EXCHANGE: exchange,
MERGE: merge,
}
return &TwoPhaseSyncer{
manager: manager,
state: HASH,
mapState: manager.store.GenerateMessage(),
generateMessageFSM: generateMessageFsm,
}
}

154
pkg/crdt/vector_clock.go Normal file
View File

@ -0,0 +1,154 @@
package crdt
import (
"cmp"
"sync"
"time"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
type VectorBucket struct {
// clock current value of the node's clock
clock uint64
// lastUpdate we've seen
lastUpdate uint64
}
// Vector clock defines an abstract data type
// for a vector clock implementation
type VectorClock[K cmp.Ordered] struct {
vectors map[uint64]*VectorBucket
lock sync.RWMutex
processID K
staleTime uint64
hashFunc func(K) uint64
}
// IncrementClock: increments the node's value in the vector clock
func (m *VectorClock[K]) IncrementClock() uint64 {
maxClock := uint64(0)
m.lock.Lock()
for _, value := range m.vectors {
maxClock = max(maxClock, value.clock)
}
newBucket := VectorBucket{
clock: maxClock + 1,
lastUpdate: uint64(time.Now().Unix()),
}
m.vectors[m.hashFunc(m.processID)] = &newBucket
m.lock.Unlock()
return maxClock
}
// GetHash: gets the hash of the vector clock used to determine if there
// are any changes
func (m *VectorClock[K]) GetHash() uint64 {
m.lock.RLock()
hash := uint64(0)
for key, bucket := range m.vectors {
hash += key * (bucket.clock + 1)
}
m.lock.RUnlock()
return hash
}
func (m *VectorClock[K]) Merge(vectors map[uint64]uint64) {
for key, value := range vectors {
m.put(key, value)
}
}
// getStale: get all entries that are stale within the mesh
func (m *VectorClock[K]) getStale() []uint64 {
m.lock.RLock()
maxTimeStamp := lib.Reduce(0, lib.MapValues(m.vectors), func(i uint64, vb *VectorBucket) uint64 {
return max(i, vb.lastUpdate)
})
toRemove := make([]uint64, 0)
for key, bucket := range m.vectors {
if maxTimeStamp-bucket.lastUpdate > m.staleTime {
toRemove = append(toRemove, key)
}
}
m.lock.RUnlock()
return toRemove
}
func (m *VectorClock[K]) Prune() {
stale := m.getStale()
m.lock.Lock()
for _, key := range stale {
delete(m.vectors, key)
}
m.lock.Unlock()
}
func (m *VectorClock[K]) GetTimestamp(processId K) uint64 {
m.lock.RLock()
lastUpdate := m.vectors[m.hashFunc(m.processID)].lastUpdate
m.lock.RUnlock()
return lastUpdate
}
func (m *VectorClock[K]) Put(key K, value uint64) {
m.put(m.hashFunc(key), value)
}
func (m *VectorClock[K]) put(key uint64, value uint64) {
clockValue := uint64(0)
m.lock.Lock()
bucket, ok := m.vectors[key]
if ok {
clockValue = bucket.clock
}
if value > clockValue {
newBucket := VectorBucket{
clock: value,
lastUpdate: uint64(time.Now().Unix()),
}
m.vectors[key] = &newBucket
}
m.lock.Unlock()
}
func (m *VectorClock[K]) GetClock() map[uint64]uint64 {
clock := make(map[uint64]uint64)
m.lock.RLock()
for key, value := range m.vectors {
clock[key] = value.clock
}
m.lock.RUnlock()
return clock
}
func NewVectorClock[K cmp.Ordered](processID K, hashFunc func(K) uint64, staleTime uint64) *VectorClock[K] {
return &VectorClock[K]{
vectors: make(map[uint64]*VectorBucket),
processID: processID,
staleTime: staleTime,
hashFunc: hashFunc,
}
}

View File

@ -1,9 +1,9 @@
package ctrlserver package ctrlserver
import ( import (
crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn" "github.com/tim-beatham/wgmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/crdt"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/wgmesh/pkg/log"
@ -21,31 +21,33 @@ type NewCtrlServerParams struct {
CtrlProvider rpc.MeshCtrlServerServer CtrlProvider rpc.MeshCtrlServerServer
SyncProvider rpc.SyncServiceServer SyncProvider rpc.SyncServiceServer
Querier query.Querier Querier query.Querier
OnDelete func(mesh.MeshProvider)
} }
// Create a new instance of the MeshCtrlServer or error if the // Create a new instance of the MeshCtrlServer or error if the
// operation failed // operation failed
func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) { func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
ctrlServer := new(MeshCtrlServer) ctrlServer := new(MeshCtrlServer)
meshFactory := crdt.CrdtProviderFactory{} meshFactory := &crdt.TwoPhaseMapFactory{}
nodeFactory := crdt.MeshNodeFactory{ nodeFactory := &crdt.MeshNodeFactory{
Config: *params.Conf, Config: *params.Conf,
} }
idGenerator := &lib.UUIDGenerator{} idGenerator := &lib.IDNameGenerator{}
ipAllocator := &ip.ULABuilder{} ipAllocator := &ip.ULABuilder{}
interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client) interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client)
configApplyer := mesh.NewWgMeshConfigApplyer() configApplyer := mesh.NewWgMeshConfigApplyer(params.Conf)
meshManagerParams := &mesh.NewMeshManagerParams{ meshManagerParams := &mesh.NewMeshManagerParams{
Conf: *params.Conf, Conf: *params.Conf,
Client: params.Client, Client: params.Client,
MeshProvider: &meshFactory, MeshProvider: meshFactory,
NodeFactory: &nodeFactory, NodeFactory: nodeFactory,
IdGenerator: idGenerator, IdGenerator: idGenerator,
IPAllocator: ipAllocator, IPAllocator: ipAllocator,
InterfaceManipulator: interfaceManipulator, InterfaceManipulator: interfaceManipulator,
ConfigApplyer: configApplyer, ConfigApplyer: configApplyer,
OnDelete: params.OnDelete,
} }
ctrlServer.MeshManager = mesh.NewMeshManager(meshManagerParams) ctrlServer.MeshManager = mesh.NewMeshManager(meshManagerParams)

View File

@ -9,6 +9,11 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
type MeshRoute struct {
Destination string
Path []string
}
// Represents a WireGuard MeshNode // Represents a WireGuard MeshNode
type MeshNode struct { type MeshNode struct {
HostEndpoint string HostEndpoint string
@ -16,7 +21,7 @@ type MeshNode struct {
PublicKey string PublicKey string
WgHost string WgHost string
Timestamp int64 Timestamp int64
Routes []string Routes []MeshRoute
Description string Description string
Alias string Alias string
Services map[string]string Services map[string]string

114
pkg/dns/dns.go Normal file
View File

@ -0,0 +1,114 @@
package smegdns
import (
"encoding/json"
"fmt"
"net"
"net/rpc"
"github.com/miekg/dns"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/query"
)
const SockAddr = "/tmp/wgmesh_ipc.sock"
const MeshRegularExpression = `(?P<meshId>.+)\.(?P<alias>.+)\.smeg\.`
type DNSHandler struct {
client *rpc.Client
server *dns.Server
}
// queryMesh: queries the mesh network for the given meshId and node
// with alias
func (d *DNSHandler) queryMesh(meshId, alias string) net.IP {
var reply string
err := d.client.Call("IpcHandler.Query", &ipc.QueryMesh{
MeshId: meshId,
Query: fmt.Sprintf("[?alias == '%s'] | [0]", alias),
}, &reply)
if err != nil {
return nil
}
var node *query.QueryNode
err = json.Unmarshal([]byte(reply), &node)
if err != nil || node == nil {
return nil
}
ip, _, _ := net.ParseCIDR(node.WgHost)
return ip
}
func (d *DNSHandler) handleQuery(m *dns.Msg) {
for _, q := range m.Question {
switch q.Qtype {
case dns.TypeAAAA:
logging.Log.WriteInfof("Query for %s", q.Name)
groups := lib.MatchCaptureGroup(MeshRegularExpression, q.Name)
if len(groups) == 0 {
continue
}
ip := d.queryMesh(groups["meshId"], groups["alias"])
rr, err := dns.NewRR(fmt.Sprintf("%s AAAA %s", q.Name, ip))
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
if err == nil {
m.Answer = append(m.Answer, rr)
}
}
}
}
func (h *DNSHandler) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
msg := new(dns.Msg)
msg.SetReply(r)
msg.Authoritative = true
switch r.Opcode {
case dns.OpcodeQuery:
h.handleQuery(msg)
}
w.WriteMsg(msg)
}
func (h *DNSHandler) Listen() error {
return h.server.ListenAndServe()
}
func (h *DNSHandler) Close() error {
return h.server.Shutdown()
}
func NewDns(udpPort int) (*DNSHandler, error) {
client, err := rpc.DialHTTP("unix", SockAddr)
if err != nil {
return nil, err
}
dnsHander := DNSHandler{
client: client,
}
dns.HandleFunc("smeg.", dnsHander.handleDnsRequest)
dnsHander.server = &dns.Server{Addr: fmt.Sprintf(":%d", udpPort), Net: "udp"}
return &dnsHander, nil
}

View File

@ -4,13 +4,13 @@ package rpctypes;
option go_package = "pkg/rpc"; option go_package = "pkg/rpc";
service MeshCtrlServer { service MeshCtrlServer {
rpc JoinMesh(JoinMeshRequest) returns (JoinMeshReply) {} rpc GetMesh(GetMeshRequest) returns (GetMeshReply) {}
} }
message JoinMeshRequest { message GetMeshRequest {
string meshId = 2; string meshId = 1;
} }
message JoinMeshReply { message GetMeshReply {
bool success = 1; bytes mesh = 1;
} }

View File

@ -11,8 +11,6 @@ import (
) )
type NewMeshArgs struct { type NewMeshArgs struct {
// IfName is the interface that the mesh instance will run on
IfName string
// WgPort is the WireGuard port to expose // WgPort is the WireGuard port to expose
WgPort int WgPort int
// Endpoint is the routable alias of the machine. Can be an IP // Endpoint is the routable alias of the machine. Can be an IP
@ -25,13 +23,14 @@ type JoinMeshArgs struct {
MeshId string MeshId string
// IpAddress is a routable IP in another mesh // IpAddress is a routable IP in another mesh
IpAdress string IpAdress string
// IfName is the interface name of the mesh
IfName string
// Port is the WireGuard port to expose // Port is the WireGuard port to expose
Port int Port int
// Endpoint is the routable address of this machine. If not provided // Endpoint is the routable address of this machine. If not provided
// defaults to the default address // defaults to the default address
Endpoint string Endpoint string
// Client specifies whether we should join as a client of the peer
// we are connecting to
Client bool
} }
type PutServiceArgs struct { type PutServiceArgs struct {
@ -63,7 +62,6 @@ type MeshIpc interface {
JoinMesh(args JoinMeshArgs, reply *string) error JoinMesh(args JoinMeshArgs, reply *string) error
LeaveMesh(meshId string, reply *string) error LeaveMesh(meshId string, reply *string) error
GetMesh(meshId string, reply *GetMeshReply) error GetMesh(meshId string, reply *GetMeshReply) error
EnableInterface(meshId string, reply *string) error
GetDOT(meshId string, reply *string) error GetDOT(meshId string, reply *string) error
Query(query QueryMesh, reply *string) error Query(query QueryMesh, reply *string) error
PutDescription(description string, reply *string) error PutDescription(description string, reply *string) error

View File

@ -1,11 +1,34 @@
package lib package lib
import "cmp"
// MapToSlice converts a map to a slice in go // MapToSlice converts a map to a slice in go
func MapValues[K comparable, V any](m map[K]V) []V { func MapValues[K cmp.Ordered, V any](m map[K]V) []V {
return MapValuesWithExclude(m, map[K]struct{}{}) return MapValuesWithExclude(m, map[K]struct{}{})
} }
func MapValuesWithExclude[K comparable, V any](m map[K]V, exclude map[K]struct{}) []V { type MapItemsEntry[K cmp.Ordered, V any] struct {
Key K
Value V
}
func MapItems[K cmp.Ordered, V any](m map[K]V) []MapItemsEntry[K, V] {
keys := MapKeys(m)
values := MapValues(m)
vs := make([]MapItemsEntry[K, V], len(keys))
for index, _ := range keys {
vs[index] = MapItemsEntry[K, V]{
Key: keys[index],
Value: values[index],
}
}
return vs
}
func MapValuesWithExclude[K cmp.Ordered, V any](m map[K]V, exclude map[K]struct{}) []V {
values := make([]V, len(m)-len(exclude)) values := make([]V, len(m)-len(exclude))
i := 0 i := 0
@ -26,7 +49,7 @@ func MapValuesWithExclude[K comparable, V any](m map[K]V, exclude map[K]struct{}
return values return values
} }
func MapKeys[K comparable, V any](m map[K]V) []K { func MapKeys[K cmp.Ordered, V any](m map[K]V) []K {
values := make([]K, len(m)) values := make([]K, len(m))
i := 0 i := 0
@ -66,3 +89,23 @@ func Filter[V any](list []V, f filterFunc[V]) []V {
return newList return newList
} }
func Contains[V any](list []V, proposition func(V) bool) bool {
for _, elem := range list {
if proposition(elem) {
return true
}
}
return false
}
func Reduce[A any, V any](start A, values []V, reduce func(A, V) A) A {
accum := start
for _, elem := range values {
accum = reduce(accum, elem)
}
return accum
}

46
pkg/lib/hashing.go Normal file
View File

@ -0,0 +1,46 @@
package lib
import (
"hash/fnv"
"sort"
)
type consistentHashRecord[V any] struct {
record V
value int
}
func HashString(value string) int {
f := fnv.New32a()
f.Write([]byte(value))
return int(f.Sum32())
}
// ConsistentHash implementation. Traverse the values until we find a key
// less than ours.
func ConsistentHash[V any, K any](values []V, client K, bucketFunc func(V) int, keyFunc func(K) int) V {
if len(values) == 0 {
panic("values is empty")
}
vs := Map(values, func(v V) consistentHashRecord[V] {
return consistentHashRecord[V]{
v,
bucketFunc(v),
}
})
sort.SliceStable(vs, func(i, j int) bool {
return vs[i].value < vs[j].value
})
ourKey := keyFunc(client)
for _, record := range vs {
if ourKey < record.value {
return record.record
}
}
return vs[0].record
}

View File

@ -1,6 +1,9 @@
package lib package lib
import "github.com/google/uuid" import (
"github.com/anandvarma/namegen"
"github.com/google/uuid"
)
// IdGenerator generates unique ids // IdGenerator generates unique ids
type IdGenerator interface { type IdGenerator interface {
@ -15,3 +18,11 @@ func (g *UUIDGenerator) GetId() (string, error) {
id := uuid.New() id := uuid.New()
return id.String(), nil return id.String(), nil
} }
type IDNameGenerator struct {
}
func (i *IDNameGenerator) GetId() (string, error) {
name_schema := namegen.New()
return name_schema.Get(), nil
}

View File

@ -1,17 +1,61 @@
package lib package lib
import ( import (
"encoding/json"
"io"
"log" "log"
"net" "net"
"net/http"
) )
// GetOutboundIP: gets the oubound IP of this packet // GetOutboundIP: gets the oubound IP of this packet
func GetOutboundIP() net.IP { func GetOutboundIP() (net.IP, error) {
conn, err := net.Dial("udp", "8.8.8.8:80") conn, err := net.Dial("udp", "8.8.8.8:80")
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
defer conn.Close() defer conn.Close()
localAddr := conn.LocalAddr().(*net.UDPAddr) localAddr := conn.LocalAddr().(*net.UDPAddr)
return localAddr.IP return localAddr.IP, nil
}
const IP_SERVICE = "https://api.ipify.org?format=json"
type IpResponse struct {
Ip string `json:"ip"`
}
func (i *IpResponse) GetIP() net.IP {
return net.ParseIP(i.Ip)
}
// GetPublicIP: get the nodes public IP address. For when a node is behind NAT
func GetPublicIP() (net.IP, error) {
req, err := http.NewRequest(http.MethodGet, IP_SERVICE, nil)
if err != nil {
return nil, err
}
res, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
resBody, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
var jsonResponse IpResponse
err = json.Unmarshal([]byte(resBody), &jsonResponse)
if err != nil {
return nil, err
}
return jsonResponse.GetIP(), nil
} }

19
pkg/lib/regex.go Normal file
View File

@ -0,0 +1,19 @@
package lib
import "regexp"
func MatchCaptureGroup(pattern, payload string) map[string]string {
patterns := make(map[string]string)
expr := regexp.MustCompile(pattern)
match := expr.FindStringSubmatch(payload)
for i, name := range expr.SubexpNames() {
if i != 0 && name != "" {
patterns[name] = match[i]
}
}
return patterns
}

View File

@ -140,26 +140,38 @@ func (c *RtNetlinkConfig) AddRoute(ifName string, route Route) error {
family = unix.AF_INET family = unix.AF_INET
} }
attr := rtnetlink.RouteAttributes{ routes, err := c.listRoutes(ifName, family)
Dst: dst.IP,
OutIface: uint32(iface.Index),
Gateway: gw,
}
ones, _ := dst.Mask.Size()
err = c.conn.Route.Replace(&rtnetlink.RouteMessage{
Family: family,
Table: unix.RT_TABLE_MAIN,
Protocol: unix.RTPROT_BOOT,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
DstLength: uint8(ones),
Attributes: attr,
})
if err != nil { if err != nil {
return fmt.Errorf("failed to add route %w", err) return err
}
// If it already exists no need to add the route
if !Contains(routes, func(prevRoute rtnetlink.RouteMessage) bool {
return prevRoute.Attributes.Dst.Equal(route.Destination.IP) &&
prevRoute.Attributes.Gateway.Equal(route.Gateway)
}) {
attr := rtnetlink.RouteAttributes{
Dst: dst.IP,
OutIface: uint32(iface.Index),
Gateway: gw,
}
ones, _ := dst.Mask.Size()
err = c.conn.Route.Replace(&rtnetlink.RouteMessage{
Family: family,
Table: unix.RT_TABLE_MAIN,
Protocol: unix.RTPROT_BOOT,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
DstLength: uint8(ones),
Attributes: attr,
})
if err != nil {
return fmt.Errorf("failed to add route %w", err)
}
} }
return nil return nil
@ -201,7 +213,7 @@ func (c *RtNetlinkConfig) DeleteRoute(ifName string, route Route) error {
}) })
if err != nil { if err != nil {
return fmt.Errorf("failed to delete route %w", err) return fmt.Errorf("failed to delete route %s", dst.IP.String())
} }
return nil return nil
@ -219,22 +231,15 @@ func (r1 Route) equal(r2 Route) bool {
// DeleteRoutes deletes all routes not in exclude // DeleteRoutes deletes all routes not in exclude
func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...Route) error { func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...Route) error {
routes := make([]rtnetlink.RouteMessage, 0) routes, err := c.listRoutes(ifName, family)
if len(exclude) != 0 { if err != nil {
lRoutes, err := c.listRoutes(ifName, family, exclude[0].Gateway) return err
if err != nil {
return err
}
routes = lRoutes
} }
ifRoutes := make([]Route, 0) ifRoutes := make([]Route, 0)
for _, rtRoute := range routes { for _, rtRoute := range routes {
logging.Log.WriteInfof("Routes: %s", rtRoute.Attributes.Dst.String())
maskSize := 128 maskSize := 128
if family == unix.AF_INET { if family == unix.AF_INET {
@ -255,6 +260,14 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
if route.equal(r) { if route.equal(r) {
return false return false
} }
if family == unix.AF_INET && route.Destination.IP.To4() == nil {
return false
}
if family == unix.AF_INET6 && route.Destination.IP.To16() == nil {
return false
}
} }
return true return true
} }
@ -262,7 +275,7 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
toDelete := Filter(ifRoutes, shouldExclude) toDelete := Filter(ifRoutes, shouldExclude)
for _, route := range toDelete { for _, route := range toDelete {
logging.Log.WriteInfof("Deleting route %s", route.Destination.String()) logging.Log.WriteInfof("Deleting route: %s", route.Destination.String())
err := c.DeleteRoute(ifName, route) err := c.DeleteRoute(ifName, route)
if err != nil { if err != nil {
@ -274,7 +287,7 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
} }
// listRoutes lists all routes on the interface // listRoutes lists all routes on the interface
func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8, gateway net.IP) ([]rtnetlink.RouteMessage, error) { func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8) ([]rtnetlink.RouteMessage, error) {
iface, err := net.InterfaceByName(ifName) iface, err := net.InterfaceByName(ifName)
if err != nil { if err != nil {
@ -288,7 +301,7 @@ func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8, gateway net.IP
} }
filterFunc := func(r rtnetlink.RouteMessage) bool { filterFunc := func(r rtnetlink.RouteMessage) bool {
return r.Attributes.Gateway.Equal(gateway) && r.Attributes.OutIface == uint32(iface.Index) return r.Attributes.Gateway != nil && r.Attributes.OutIface == uint32(iface.Index)
} }
routes = Filter(routes, filterFunc) routes = Filter(routes, filterFunc)

40
pkg/lib/stats.go Normal file
View File

@ -0,0 +1,40 @@
// lib contains helper functions for the implementation
package lib
import (
"cmp"
"math"
"gonum.org/v1/gonum/stat"
"gonum.org/v1/gonum/stat/distuv"
)
// Modelling the distribution using a normal distribution get the count
// of the outliers
func GetOutliers[K cmp.Ordered](counts map[K]uint64, alpha float64) []K {
n := float64(len(counts))
keys := MapKeys(counts)
values := make([]float64, len(keys))
for index, key := range keys {
values[index] = float64(counts[key])
}
mean := stat.Mean(values, nil)
stdDev := stat.StdDev(values, nil)
moe := distuv.Normal{Mu: 0, Sigma: 1}.Quantile(1-alpha/2) * (stdDev / math.Sqrt(n))
lowerBound := mean - moe
var outliers []K
for i, count := range values {
if count < lowerBound {
outliers = append(outliers, keys[i])
}
}
return outliers
}

View File

@ -3,7 +3,14 @@ package mesh
import ( import (
"fmt" "fmt"
"net" "net"
"slices"
"strings"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/route"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@ -16,15 +23,21 @@ type MeshConfigApplyer interface {
// WgMeshConfigApplyer applies WireGuard configuration // WgMeshConfigApplyer applies WireGuard configuration
type WgMeshConfigApplyer struct { type WgMeshConfigApplyer struct {
meshManager MeshManager meshManager MeshManager
config *conf.WgMeshConfiguration
routeInstaller route.RouteInstaller
hashFunc func(MeshNode) int
} }
func convertMeshNode(node MeshNode) (*wgtypes.PeerConfig, error) { type routeNode struct {
endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint()) gateway string
route Route
}
if err != nil { func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, self MeshNode,
return nil, err device *wgtypes.Device,
} peerToClients map[string][]net.IPNet,
routes map[string][]routeNode) (*wgtypes.PeerConfig, error) {
pubKey, err := node.GetPublicKey() pubKey, err := node.GetPublicKey()
@ -35,59 +48,392 @@ func convertMeshNode(node MeshNode) (*wgtypes.PeerConfig, error) {
allowedips := make([]net.IPNet, 1) allowedips := make([]net.IPNet, 1)
allowedips[0] = *node.GetWgHost() allowedips[0] = *node.GetWgHost()
clients, ok := peerToClients[pubKey.String()]
if ok {
allowedips = append(allowedips, clients...)
}
for _, route := range node.GetRoutes() { for _, route := range node.GetRoutes() {
_, ipnet, _ := net.ParseCIDR(route) bestRoutes := routes[route.GetDestination().String()]
allowedips = append(allowedips, *ipnet) var pickedRoute routeNode
if len(bestRoutes) == 1 {
pickedRoute = bestRoutes[0]
} else if len(bestRoutes) > 1 {
bucketFunc := func(rn routeNode) int {
return lib.HashString(rn.gateway)
}
// Else there is more than one candidate so consistently hash
pickedRoute = lib.ConsistentHash(bestRoutes, self, bucketFunc, m.hashFunc)
}
if pickedRoute.gateway == pubKey.String() {
allowedips = append(allowedips, *pickedRoute.route.GetDestination())
}
}
keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second
existing := slices.IndexFunc(device.Peers, func(p wgtypes.Peer) bool {
pubKey, _ := node.GetPublicKey()
return p.PublicKey.String() == pubKey.String()
})
endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint())
if err != nil {
return nil, err
}
// Don't override the existing IP in case it already exists
if existing != -1 {
endpoint = device.Peers[existing].Endpoint
} }
peerConfig := wgtypes.PeerConfig{ peerConfig := wgtypes.PeerConfig{
PublicKey: pubKey, PublicKey: pubKey,
Endpoint: endpoint, Endpoint: endpoint,
AllowedIPs: allowedips, AllowedIPs: allowedips,
PersistentKeepaliveInterval: &keepAlive,
ReplaceAllowedIPs: true,
} }
return &peerConfig, nil return &peerConfig, nil
} }
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { // getRoutes: finds the routes with the least hop distance. If more than one route exists
snap, err := mesh.GetMesh() // consistently hash to evenly spread the distribution of traffic
func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]routeNode {
mesh, _ := meshProvider.GetMesh()
routes := make(map[string][]routeNode)
if err != nil { peers := lib.Filter(lib.MapValues(mesh.GetNodes()), func(p MeshNode) bool {
return err return p.GetType() == conf.PEER_ROLE
})
meshPrefixes := lib.Map(lib.MapValues(m.meshManager.GetMeshes()), func(mesh MeshProvider) *net.IPNet {
ula := &ip.ULABuilder{}
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
return ipNet
})
for _, node := range mesh.GetNodes() {
pubKey, _ := node.GetPublicKey()
for _, route := range node.GetRoutes() {
if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool {
v6Default, _, _ := net.ParseCIDR("::/0")
v4Default, _, _ := net.ParseCIDR("0.0.0.0/0")
if (prefix.IP.Equal(v6Default) || prefix.IP.Equal(v4Default)) && m.config.AdvertiseDefaultRoute {
return true
}
return prefix.Contains(route.GetDestination().IP)
}) {
continue
}
destination := route.GetDestination().String()
otherRoute, ok := routes[destination]
rn := routeNode{
gateway: pubKey.String(),
route: route,
}
// Client's only acessible by another peer
if node.GetType() == conf.CLIENT_ROLE {
peer := m.getCorrespondingPeer(peers, node)
self, _ := m.meshManager.GetSelf(meshProvider.GetMeshId())
// If the node isn't the self use that peer as the gateway
if !NodeEquals(peer, self) {
peerPub, _ := peer.GetPublicKey()
rn.gateway = peerPub.String()
rn.route = &RouteStub{
Destination: rn.route.GetDestination(),
HopCount: rn.route.GetHopCount() + 1,
// Append the path to this peer
Path: append(rn.route.GetPath(), peer.GetWgHost().IP.String()),
}
}
}
if !ok {
otherRoute = make([]routeNode, 1)
otherRoute[0] = rn
routes[destination] = otherRoute
} else if route.GetHopCount() < otherRoute[0].route.GetHopCount() {
otherRoute[0] = rn
} else if otherRoute[0].route.GetHopCount() == route.GetHopCount() {
routes[destination] = append(otherRoute, rn)
}
}
} }
nodes := snap.GetNodes() return routes
peerConfigs := make([]wgtypes.PeerConfig, len(nodes)) }
var count int = 0 // getCorrespondignPeer: gets the peer corresponding to the client
func (m *WgMeshConfigApplyer) getCorrespondingPeer(peers []MeshNode, client MeshNode) MeshNode {
peer := lib.ConsistentHash(peers, client, m.hashFunc, m.hashFunc)
return peer
}
for _, n := range nodes { func (m *WgMeshConfigApplyer) getPeerCfgsToRemove(dev *wgtypes.Device, newPeers []wgtypes.PeerConfig) []wgtypes.PeerConfig {
peer, err := convertMeshNode(n) peers := dev.Peers
peers = lib.Filter(peers, func(p1 wgtypes.Peer) bool {
return !lib.Contains(newPeers, func(p2 wgtypes.PeerConfig) bool {
return p1.PublicKey.String() == p2.PublicKey.String()
})
})
if err != nil { return lib.Map(peers, func(p wgtypes.Peer) wgtypes.PeerConfig {
return err return wgtypes.PeerConfig{
PublicKey: p.PublicKey,
Remove: true,
}
})
}
type GetConfigParams struct {
mesh MeshProvider
peers []MeshNode
clients []MeshNode
dev *wgtypes.Device
routes map[string][]routeNode
}
func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes.Config, error) {
self, err := m.meshManager.GetSelf(params.mesh.GetMeshId())
ula := &ip.ULABuilder{}
meshNet, _ := ula.GetIPNet(params.mesh.GetMeshId())
routesForMesh := lib.Map(lib.MapValues(params.routes), func(rns []routeNode) []routeNode {
return lib.Filter(rns, func(rn routeNode) bool {
ip, _, _ := net.ParseCIDR(rn.gateway)
return meshNet.Contains(ip)
})
})
routes := lib.Map(routesForMesh, func(rs []routeNode) net.IPNet {
return *rs[0].route.GetDestination()
})
routes = append(routes, *meshNet)
if err != nil {
return nil, err
}
peer := m.getCorrespondingPeer(params.peers, self)
pubKey, _ := peer.GetPublicKey()
keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second
endpoint, err := net.ResolveUDPAddr("udp", peer.GetWgEndpoint())
if err != nil {
return nil, err
}
peerCfgs := make([]wgtypes.PeerConfig, 1)
peerCfgs[0] = wgtypes.PeerConfig{
PublicKey: pubKey,
Endpoint: endpoint,
PersistentKeepaliveInterval: &keepAlive,
AllowedIPs: routes,
ReplaceAllowedIPs: true,
}
installedRoutes := make([]lib.Route, 0)
for _, route := range peerCfgs[0].AllowedIPs {
installedRoutes = append(installedRoutes, lib.Route{
Gateway: peer.GetWgHost().IP,
Destination: route,
})
}
cfg := wgtypes.Config{
Peers: peerCfgs,
}
m.routeInstaller.InstallRoutes(params.dev.Name, installedRoutes...)
return &cfg, err
}
func (m *WgMeshConfigApplyer) getRoutesToInstall(wgNode *wgtypes.PeerConfig, mesh MeshProvider, node MeshNode) []lib.Route {
routes := make([]lib.Route, 0)
for _, route := range wgNode.AllowedIPs {
ula := &ip.ULABuilder{}
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
_, defaultRoute, _ := net.ParseCIDR("::/0")
if !ipNet.Contains(route.IP) && !ipNet.IP.Equal(defaultRoute.IP) {
routes = append(routes, lib.Route{
Gateway: node.GetWgHost().IP,
Destination: route,
})
}
}
return routes
}
func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.Config, error) {
peerToClients := make(map[string][]net.IPNet)
installedRoutes := make([]lib.Route, 0)
peerConfigs := make([]wgtypes.PeerConfig, 0)
self, err := m.meshManager.GetSelf(params.mesh.GetMeshId())
if err != nil {
return nil, err
}
for _, n := range params.clients {
if len(params.peers) > 0 {
peer := m.getCorrespondingPeer(params.peers, n)
pubKey, _ := peer.GetPublicKey()
clients, ok := peerToClients[pubKey.String()]
if !ok {
clients = make([]net.IPNet, 0)
peerToClients[pubKey.String()] = clients
}
peerToClients[pubKey.String()] = append(clients, *n.GetWgHost())
if NodeEquals(self, peer) {
cfg, err := m.convertMeshNode(n, self, params.dev, peerToClients, params.routes)
if err != nil {
return nil, err
}
installedRoutes = append(installedRoutes, m.getRoutesToInstall(cfg, params.mesh, n)...)
peerConfigs = append(peerConfigs, *cfg)
}
}
}
for _, n := range params.peers {
if NodeEquals(n, self) {
continue
} }
peerConfigs[count] = *peer peer, err := m.convertMeshNode(n, self, params.dev, peerToClients, params.routes)
count++
if err != nil {
return nil, err
}
installedRoutes = append(installedRoutes, m.getRoutesToInstall(peer, params.mesh, n)...)
peerConfigs = append(peerConfigs, *peer)
} }
cfg := wgtypes.Config{ cfg := wgtypes.Config{
Peers: peerConfigs, Peers: peerConfigs,
} }
dev, err := mesh.GetDevice() err = m.routeInstaller.InstallRoutes(params.dev.Name, installedRoutes...)
return &cfg, err
}
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string][]routeNode) error {
snap, err := mesh.GetMesh()
if err != nil { if err != nil {
return err return err
} }
return m.meshManager.GetClient().ConfigureDevice(dev.Name, cfg) nodes := lib.MapValues(snap.GetNodes())
dev, _ := mesh.GetDevice()
slices.SortFunc(nodes, func(a, b MeshNode) int {
return strings.Compare(string(a.GetType()), string(b.GetType()))
})
peers := lib.Filter(nodes, func(mn MeshNode) bool {
return mn.GetType() == conf.PEER_ROLE
})
clients := lib.Filter(nodes, func(mn MeshNode) bool {
return mn.GetType() == conf.CLIENT_ROLE
})
self, err := m.meshManager.GetSelf(mesh.GetMeshId())
if err != nil {
return err
}
var cfg *wgtypes.Config = nil
configParams := &GetConfigParams{
mesh: mesh,
peers: peers,
clients: clients,
dev: dev,
routes: routes,
}
switch self.GetType() {
case conf.PEER_ROLE:
cfg, err = m.getPeerConfig(configParams)
case conf.CLIENT_ROLE:
cfg, err = m.getClientConfig(configParams)
}
if err != nil {
return err
}
toRemove := m.getPeerCfgsToRemove(dev, cfg.Peers)
cfg.Peers = append(cfg.Peers, toRemove...)
err = m.meshManager.GetClient().ConfigureDevice(dev.Name, *cfg)
if err != nil {
return err
}
return nil
}
func (m *WgMeshConfigApplyer) getAllRoutes() map[string][]routeNode {
allRoutes := make(map[string][]routeNode)
for _, mesh := range m.meshManager.GetMeshes() {
routes := m.getRoutes(mesh)
for destination, route := range routes {
_, ok := allRoutes[destination]
if !ok {
allRoutes[destination] = route
continue
}
if allRoutes[destination][0].route.GetHopCount() == route[0].route.GetHopCount() {
allRoutes[destination] = append(allRoutes[destination], route...)
} else if route[0].route.GetHopCount() < allRoutes[destination][0].route.GetHopCount() {
allRoutes[destination] = route
}
}
}
return allRoutes
} }
func (m *WgMeshConfigApplyer) ApplyConfig() error { func (m *WgMeshConfigApplyer) ApplyConfig() error {
allRoutes := m.getAllRoutes()
for _, mesh := range m.meshManager.GetMeshes() { for _, mesh := range m.meshManager.GetMeshes() {
err := m.updateWgConf(mesh) err := m.updateWgConf(mesh, allRoutes)
if err != nil { if err != nil {
return err return err
@ -111,8 +457,8 @@ func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error {
} }
m.meshManager.GetClient().ConfigureDevice(dev.Name, wgtypes.Config{ m.meshManager.GetClient().ConfigureDevice(dev.Name, wgtypes.Config{
Peers: make([]wgtypes.PeerConfig, 0),
ReplacePeers: true, ReplacePeers: true,
Peers: make([]wgtypes.PeerConfig, 1),
}) })
return nil return nil
@ -122,6 +468,13 @@ func (m *WgMeshConfigApplyer) SetMeshManager(manager MeshManager) {
m.meshManager = manager m.meshManager = manager
} }
func NewWgMeshConfigApplyer() MeshConfigApplyer { func NewWgMeshConfigApplyer(config *conf.WgMeshConfiguration) MeshConfigApplyer {
return &WgMeshConfigApplyer{} return &WgMeshConfigApplyer{
config: config,
routeInstaller: route.NewRouteInstaller(),
hashFunc: func(mn MeshNode) int {
pubKey, _ := mn.GetPublicKey()
return lib.HashString(pubKey.String())
},
}
} }

View File

@ -61,7 +61,7 @@ func (c *MeshDOTConverter) graphNode(g *graph.Graph, node MeshNode, meshId strin
self, _ := c.manager.GetSelf(meshId) self, _ := c.manager.GetSelf(meshId)
if node.GetHostEndpoint() == self.GetHostEndpoint() { if NodeEquals(self, node) {
return return
} }

View File

@ -3,23 +3,22 @@ package mesh
import ( import (
"errors" "errors"
"fmt" "fmt"
"sync"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/wg" "github.com/tim-beatham/wgmesh/pkg/wg"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
type MeshManager interface { type MeshManager interface {
CreateMesh(devName string, port int) (string, error) CreateMesh(port int) (string, error)
AddMesh(params *AddMeshParams) error AddMesh(params *AddMeshParams) error
HasChanges(meshid string) bool HasChanges(meshid string) bool
GetMesh(meshId string) MeshProvider GetMesh(meshId string) MeshProvider
EnableInterface(meshId string) error GetPublicKey() *wgtypes.Key
GetPublicKey(meshId string) (*wgtypes.Key, error)
AddSelf(params *AddSelfParams) error AddSelf(params *AddSelfParams) error
LeaveMesh(meshId string) error LeaveMesh(meshId string) error
GetSelf(meshId string) (MeshNode, error) GetSelf(meshId string) (MeshNode, error)
@ -35,9 +34,11 @@ type MeshManager interface {
Close() error Close() error
GetMonitor() MeshMonitor GetMonitor() MeshMonitor
GetNode(string, string) MeshNode GetNode(string, string) MeshNode
GetRouteManager() RouteManager
} }
type MeshManagerImpl struct { type MeshManagerImpl struct {
lock sync.RWMutex
Meshes map[string]MeshProvider Meshes map[string]MeshProvider
RouteManager RouteManager RouteManager RouteManager
Client *wgctrl.Client Client *wgctrl.Client
@ -52,12 +53,18 @@ type MeshManagerImpl struct {
ipAllocator ip.IPAllocator ipAllocator ip.IPAllocator
interfaceManipulator wg.WgInterfaceManipulator interfaceManipulator wg.WgInterfaceManipulator
Monitor MeshMonitor Monitor MeshMonitor
OnDelete func(MeshProvider)
}
// GetRouteManager implements MeshManager.
func (m *MeshManagerImpl) GetRouteManager() RouteManager {
return m.RouteManager
} }
// RemoveService implements MeshManager. // RemoveService implements MeshManager.
func (m *MeshManagerImpl) RemoveService(service string) error { func (m *MeshManagerImpl) RemoveService(service string) error {
for _, mesh := range m.Meshes { for _, mesh := range m.Meshes {
err := mesh.RemoveService(m.HostParameters.HostEndpoint, service) err := mesh.RemoveService(m.HostParameters.GetPublicKey(), service)
if err != nil { if err != nil {
return err return err
@ -70,7 +77,7 @@ func (m *MeshManagerImpl) RemoveService(service string) error {
// SetService implements MeshManager. // SetService implements MeshManager.
func (m *MeshManagerImpl) SetService(service string, value string) error { func (m *MeshManagerImpl) SetService(service string, value string) error {
for _, mesh := range m.Meshes { for _, mesh := range m.Meshes {
err := mesh.AddService(m.HostParameters.HostEndpoint, service, value) err := mesh.AddService(m.HostParameters.GetPublicKey(), service, value)
if err != nil { if err != nil {
return err return err
@ -104,7 +111,7 @@ func (m *MeshManagerImpl) GetMonitor() MeshMonitor {
// Prune implements MeshManager. // Prune implements MeshManager.
func (m *MeshManagerImpl) Prune() error { func (m *MeshManagerImpl) Prune() error {
for _, mesh := range m.Meshes { for _, mesh := range m.Meshes {
err := mesh.Prune(m.conf.PruneTime) err := mesh.Prune()
if err != nil { if err != nil {
return err return err
@ -115,55 +122,68 @@ func (m *MeshManagerImpl) Prune() error {
} }
// CreateMesh: Creates a new mesh, stores it and returns the mesh id // CreateMesh: Creates a new mesh, stores it and returns the mesh id
func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) { func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
meshId, err := m.idGenerator.GetId() meshId, err := m.idGenerator.GetId()
var ifName string = ""
if err != nil { if err != nil {
return "", err return "", err
} }
nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: devName,
Port: port,
Conf: m.conf,
Client: m.Client,
MeshId: meshId,
})
if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err)
}
if !m.conf.StubWg { if !m.conf.StubWg {
err = m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{ ifName, err = m.interfaceManipulator.CreateInterface(port, m.HostParameters.PrivateKey)
IfName: devName,
Port: port,
})
if err != nil { if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err) return "", fmt.Errorf("error creating mesh: %w", err)
} }
} }
nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: ifName,
Port: port,
Conf: m.conf,
Client: m.Client,
MeshId: meshId,
NodeID: m.HostParameters.GetPublicKey(),
})
if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err)
}
m.lock.Lock()
m.Meshes[meshId] = nodeManager m.Meshes[meshId] = nodeManager
m.lock.Unlock()
return meshId, nil return meshId, nil
} }
type AddMeshParams struct { type AddMeshParams struct {
MeshId string MeshId string
DevName string
WgPort int WgPort int
MeshBytes []byte MeshBytes []byte
} }
// AddMesh: Add the mesh to the list of meshes // AddMesh: Add the mesh to the list of meshes
func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error { func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
var ifName string
var err error
if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(params.WgPort, m.HostParameters.PrivateKey)
if err != nil {
return err
}
}
meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{ meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: params.DevName, DevName: ifName,
Port: params.WgPort, Port: params.WgPort,
Conf: m.conf, Conf: m.conf,
Client: m.Client, Client: m.Client,
MeshId: params.MeshId, MeshId: params.MeshId,
NodeID: m.HostParameters.GetPublicKey(),
}) })
if err != nil { if err != nil {
@ -176,15 +196,9 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
return err return err
} }
m.lock.Lock()
m.Meshes[params.MeshId] = meshProvider m.Meshes[params.MeshId] = meshProvider
m.lock.Unlock()
if !m.conf.StubWg {
return m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: params.DevName,
Port: params.WgPort,
})
}
return nil return nil
} }
@ -199,43 +213,10 @@ func (m *MeshManagerImpl) GetMesh(meshId string) MeshProvider {
return theMesh return theMesh
} }
// EnableInterface: Enables the given WireGuard interface.
func (s *MeshManagerImpl) EnableInterface(meshId string) error {
err := s.configApplyer.ApplyConfig()
if err != nil {
return err
}
err = s.RouteManager.InstallRoutes()
if err != nil {
return err
}
return nil
}
// GetPublicKey: Gets the public key of the WireGuard mesh // GetPublicKey: Gets the public key of the WireGuard mesh
func (s *MeshManagerImpl) GetPublicKey(meshId string) (*wgtypes.Key, error) { func (s *MeshManagerImpl) GetPublicKey() *wgtypes.Key {
if s.conf.StubWg { key := s.HostParameters.PrivateKey.PublicKey()
zeroedKey := make([]byte, wgtypes.KeyLen) return &key
return (*wgtypes.Key)(zeroedKey), nil
}
mesh, ok := s.Meshes[meshId]
if !ok {
return nil, errors.New("mesh does not exist")
}
dev, err := mesh.GetDevice()
if err != nil {
return nil, err
}
return &dev.PublicKey, nil
} }
type AddSelfParams struct { type AddSelfParams struct {
@ -255,20 +236,26 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
return fmt.Errorf("addself: mesh %s does not exist", params.MeshId) return fmt.Errorf("addself: mesh %s does not exist", params.MeshId)
} }
pubKey, err := s.GetPublicKey(params.MeshId) if params.WgPort == 0 && !s.conf.StubWg {
device, err := mesh.GetDevice()
if err != nil { if err != nil {
return err return err
}
params.WgPort = device.ListenPort
} }
nodeIP, err := s.ipAllocator.GetIP(*pubKey, params.MeshId) pubKey := s.HostParameters.PrivateKey.PublicKey()
nodeIP, err := s.ipAllocator.GetIP(pubKey, params.MeshId)
if err != nil { if err != nil {
return err return err
} }
node := s.nodeFactory.Build(&MeshNodeFactoryParams{ node := s.nodeFactory.Build(&MeshNodeFactoryParams{
PublicKey: pubKey, PublicKey: &pubKey,
NodeIP: nodeIP, NodeIP: nodeIP,
WgPort: params.WgPort, WgPort: params.WgPort,
Endpoint: params.Endpoint, Endpoint: params.Endpoint,
@ -289,34 +276,45 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
} }
s.Meshes[params.MeshId].AddNode(node) s.Meshes[params.MeshId].AddNode(node)
return s.RouteManager.UpdateRoutes() return nil
} }
// LeaveMesh leaves the mesh network // LeaveMesh leaves the mesh network
func (s *MeshManagerImpl) LeaveMesh(meshId string) error { func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
mesh, exists := s.Meshes[meshId] mesh := s.GetMesh(meshId)
if !exists { if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId) return fmt.Errorf("mesh %s does not exist", meshId)
} }
err := s.RouteManager.RemoveRoutes(meshId) err := mesh.RemoveNode(s.HostParameters.GetPublicKey())
if err != nil { if err != nil {
return err return err
} }
if !s.conf.StubWg { if s.OnDelete != nil {
device, e := mesh.GetDevice() s.OnDelete(mesh)
}
if e != nil { s.lock.Lock()
delete(s.Meshes, meshId)
s.lock.Unlock()
if !s.conf.StubWg {
device, err := mesh.GetDevice()
if err != nil {
return err return err
} }
err = s.interfaceManipulator.RemoveInterface(device.Name) err = s.interfaceManipulator.RemoveInterface(device.Name)
if err != nil {
return err
}
} }
delete(s.Meshes, meshId)
return err return err
} }
@ -327,7 +325,7 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
return nil, fmt.Errorf("mesh %s does not exist", meshId) return nil, fmt.Errorf("mesh %s does not exist", meshId)
} }
node, err := meshInstance.GetNode(s.HostParameters.HostEndpoint) node, err := meshInstance.GetNode(s.HostParameters.GetPublicKey())
if err != nil { if err != nil {
return nil, errors.New("the node doesn't exist in the mesh") return nil, errors.New("the node doesn't exist in the mesh")
@ -337,6 +335,10 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
} }
func (s *MeshManagerImpl) ApplyConfig() error { func (s *MeshManagerImpl) ApplyConfig() error {
if s.conf.StubWg {
return nil
}
err := s.configApplyer.ApplyConfig() err := s.configApplyer.ApplyConfig()
if err != nil { if err != nil {
@ -347,9 +349,10 @@ func (s *MeshManagerImpl) ApplyConfig() error {
} }
func (s *MeshManagerImpl) SetDescription(description string) error { func (s *MeshManagerImpl) SetDescription(description string) error {
for _, mesh := range s.Meshes { meshes := s.GetMeshes()
if mesh.NodeExists(s.HostParameters.HostEndpoint) { for _, mesh := range meshes {
err := mesh.SetDescription(s.HostParameters.HostEndpoint, description) if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetDescription(s.HostParameters.GetPublicKey(), description)
if err != nil { if err != nil {
return err return err
@ -362,9 +365,10 @@ func (s *MeshManagerImpl) SetDescription(description string) error {
// SetAlias implements MeshManager. // SetAlias implements MeshManager.
func (s *MeshManagerImpl) SetAlias(alias string) error { func (s *MeshManagerImpl) SetAlias(alias string) error {
for _, mesh := range s.Meshes { meshes := s.GetMeshes()
if mesh.NodeExists(s.HostParameters.HostEndpoint) { for _, mesh := range meshes {
err := mesh.SetAlias(s.HostParameters.HostEndpoint, alias) if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetAlias(s.HostParameters.GetPublicKey(), alias)
if err != nil { if err != nil {
return err return err
@ -376,9 +380,10 @@ func (s *MeshManagerImpl) SetAlias(alias string) error {
// UpdateTimeStamp updates the timestamp of this node in all meshes // UpdateTimeStamp updates the timestamp of this node in all meshes
func (s *MeshManagerImpl) UpdateTimeStamp() error { func (s *MeshManagerImpl) UpdateTimeStamp() error {
for _, mesh := range s.Meshes { meshes := s.GetMeshes()
if mesh.NodeExists(s.HostParameters.HostEndpoint) { for _, mesh := range meshes {
err := mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint) if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.UpdateTimeStamp(s.HostParameters.GetPublicKey())
if err != nil { if err != nil {
return err return err
@ -394,7 +399,16 @@ func (s *MeshManagerImpl) GetClient() *wgctrl.Client {
} }
func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider { func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider {
return s.Meshes meshes := make(map[string]MeshProvider)
s.lock.RLock()
for id, mesh := range s.Meshes {
meshes[id] = mesh
}
s.lock.RUnlock()
return meshes
} }
// Close the mesh manager // Close the mesh manager
@ -431,21 +445,16 @@ type NewMeshManagerParams struct {
InterfaceManipulator wg.WgInterfaceManipulator InterfaceManipulator wg.WgInterfaceManipulator
ConfigApplyer MeshConfigApplyer ConfigApplyer MeshConfigApplyer
RouteManager RouteManager RouteManager RouteManager
OnDelete func(MeshProvider)
} }
// Creates a new instance of a mesh manager with the given parameters // Creates a new instance of a mesh manager with the given parameters
func NewMeshManager(params *NewMeshManagerParams) MeshManager { func NewMeshManager(params *NewMeshManagerParams) MeshManager {
hostParams := HostParameters{} privateKey, _ := wgtypes.GeneratePrivateKey()
hostParams := HostParameters{
switch params.Conf.Endpoint { PrivateKey: &privateKey,
case "":
hostParams.HostEndpoint = fmt.Sprintf("%s:%s", lib.GetOutboundIP().String(), params.Conf.GrpcPort)
default:
hostParams.HostEndpoint = fmt.Sprintf("%s:%s", params.Conf.Endpoint, params.Conf.GrpcPort)
} }
logging.Log.WriteInfof("Endpoint %s", hostParams.HostEndpoint)
m := &MeshManagerImpl{ m := &MeshManagerImpl{
Meshes: make(map[string]MeshProvider), Meshes: make(map[string]MeshProvider),
HostParameters: &hostParams, HostParameters: &hostParams,
@ -459,7 +468,7 @@ func NewMeshManager(params *NewMeshManagerParams) MeshManager {
m.RouteManager = params.RouteManager m.RouteManager = params.RouteManager
if m.RouteManager == nil { if m.RouteManager == nil {
m.RouteManager = NewRouteManager(m) m.RouteManager = NewRouteManager(m, &params.Conf)
} }
m.idGenerator = params.IdGenerator m.idGenerator = params.IdGenerator
@ -471,5 +480,6 @@ func NewMeshManager(params *NewMeshManagerParams) MeshManager {
aliasManager := NewAliasManager() aliasManager := NewAliasManager()
m.Monitor.AddUpdateCallback(aliasManager.AddAliases) m.Monitor.AddUpdateCallback(aliasManager.AddAliases)
m.Monitor.AddRemoveCallback(aliasManager.RemoveAliases) m.Monitor.AddRemoveCallback(aliasManager.RemoveAliases)
m.OnDelete = params.OnDelete
return m return m
} }

View File

@ -64,7 +64,6 @@ func TestAddMeshAddsAMesh(t *testing.T) {
manager.AddMesh(&AddMeshParams{ manager.AddMesh(&AddMeshParams{
MeshId: meshId, MeshId: meshId,
DevName: "wg0",
WgPort: 6000, WgPort: 6000,
MeshBytes: make([]byte, 0), MeshBytes: make([]byte, 0),
}) })
@ -83,7 +82,6 @@ func TestAddMeshMeshAlreadyExistsReplacesIt(t *testing.T) {
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
err := manager.AddMesh(&AddMeshParams{ err := manager.AddMesh(&AddMeshParams{
MeshId: meshId, MeshId: meshId,
DevName: "wg0",
WgPort: 6000, WgPort: 6000,
MeshBytes: make([]byte, 0), MeshBytes: make([]byte, 0),
}) })
@ -106,7 +104,6 @@ func TestAddSelfAddsSelfToTheMesh(t *testing.T) {
err := manager.AddMesh(&AddMeshParams{ err := manager.AddMesh(&AddMeshParams{
MeshId: meshId, MeshId: meshId,
DevName: "wg0",
WgPort: 6000, WgPort: 6000,
MeshBytes: make([]byte, 0), MeshBytes: make([]byte, 0),
}) })
@ -175,7 +172,6 @@ func TestLeaveMeshDeletesMesh(t *testing.T) {
err := manager.AddMesh(&AddMeshParams{ err := manager.AddMesh(&AddMeshParams{
MeshId: meshId, MeshId: meshId,
DevName: "wg0",
WgPort: 6000, WgPort: 6000,
MeshBytes: make([]byte, 0), MeshBytes: make([]byte, 0),
}) })
@ -201,8 +197,8 @@ func TestSetDescription(t *testing.T) {
manager := getMeshManager() manager := getMeshManager()
description := "wooooo" description := "wooooo"
meshId1, _ := manager.CreateMesh("wg0", 5000) meshId1, _ := manager.CreateMesh(5000)
meshId2, _ := manager.CreateMesh("wg0", 5001) meshId2, _ := manager.CreateMesh(5001)
manager.AddSelf(&AddSelfParams{ manager.AddSelf(&AddSelfParams{
MeshId: meshId1, MeshId: meshId1,
@ -225,8 +221,8 @@ func TestSetDescription(t *testing.T) {
func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) { func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) {
manager := getMeshManager() manager := getMeshManager()
meshId1, _ := manager.CreateMesh("wg0", 5000) meshId1, _ := manager.CreateMesh(5000)
meshId2, _ := manager.CreateMesh("wg0", 5001) meshId2, _ := manager.CreateMesh(5001)
manager.AddSelf(&AddSelfParams{ manager.AddSelf(&AddSelfParams{
MeshId: meshId1, MeshId: meshId1,

View File

@ -1,184 +1,111 @@
package mesh package mesh
import ( import (
"fmt"
"net" "net"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/route"
"golang.org/x/sys/unix"
) )
type RouteManager interface { type RouteManager interface {
UpdateRoutes() error UpdateRoutes() error
InstallRoutes() error
RemoveRoutes(meshId string) error
} }
type RouteManagerImpl struct { type RouteManagerImpl struct {
meshManager MeshManager meshManager MeshManager
routeInstaller route.RouteInstaller conf *conf.WgMeshConfiguration
} }
func (r *RouteManagerImpl) UpdateRoutes() error { func (r *RouteManagerImpl) UpdateRoutes() error {
meshes := r.meshManager.GetMeshes() meshes := r.meshManager.GetMeshes()
ulaBuilder := new(ip.ULABuilder) routes := make(map[string][]Route)
for _, mesh1 := range meshes { for _, mesh1 := range meshes {
self, err := r.meshManager.GetSelf(mesh1.GetMeshId())
if err != nil {
return err
}
if _, ok := routes[mesh1.GetMeshId()]; !ok {
routes[mesh1.GetMeshId()] = make([]Route, 0)
}
routeMap, err := mesh1.GetRoutes(NodeID(self))
if err != nil {
return err
}
if r.conf.AdvertiseDefaultRoute {
_, ipv6Default, _ := net.ParseCIDR("::/0")
mesh1.AddRoutes(NodeID(self),
&RouteStub{
Destination: ipv6Default,
HopCount: 0,
Path: make([]string, 0),
})
}
for _, mesh2 := range meshes { for _, mesh2 := range meshes {
routeValues, ok := routes[mesh2.GetMeshId()]
if !ok {
routeValues = make([]Route, 0)
}
if mesh1 == mesh2 { if mesh1 == mesh2 {
continue continue
} }
ipNet, err := ulaBuilder.GetIPNet(mesh2.GetMeshId()) mesh1IpNet, _ := (&ip.ULABuilder{}).GetIPNet(mesh1.GetMeshId())
if err != nil { routeValues = append(routeValues, &RouteStub{
logging.Log.WriteErrorf(err.Error()) Destination: mesh1IpNet,
return err HopCount: 0,
} Path: []string{mesh1.GetMeshId()},
})
self, err := r.meshManager.GetSelf(mesh1.GetMeshId()) routeValues = append(routeValues, lib.MapValues(routeMap)...)
mesh2IpNet, _ := (&ip.ULABuilder{}).GetIPNet(mesh2.GetMeshId())
routeValues = lib.Filter(routeValues, func(r Route) bool {
pathNotMesh := func(s string) bool {
return s == mesh2.GetMeshId()
}
if err != nil { // Ensure that the route does not see it's own IP
return err return !r.GetDestination().IP.Equal(mesh2IpNet.IP) && !lib.Contains(r.GetPath()[1:], pathNotMesh)
} })
err = mesh1.AddRoutes(self.GetHostEndpoint(), ipNet.String()) routes[mesh2.GetMeshId()] = routeValues
}
}
if err != nil { // Calculate the set different of each, working out routes to remove and to keep.
return err for meshId, meshRoutes := range routes {
mesh := r.meshManager.GetMesh(meshId)
self, _ := r.meshManager.GetSelf(meshId)
toRemove := make([]Route, 0)
prevRoutes, _ := mesh.GetRoutes(NodeID(self))
for _, route := range prevRoutes {
if !lib.Contains(meshRoutes, func(r Route) bool {
return RouteEquals(r, route)
}) {
toRemove = append(toRemove, route)
} }
} }
mesh.RemoveRoutes(NodeID(self), toRemove...)
mesh.AddRoutes(NodeID(self), meshRoutes...)
} }
return nil return nil
} }
// removeRoutes: removes all meshes we are no longer a part of func NewRouteManager(m MeshManager, conf *conf.WgMeshConfiguration) RouteManager {
func (r *RouteManagerImpl) RemoveRoutes(meshId string) error { return &RouteManagerImpl{meshManager: m, conf: conf}
ulaBuilder := new(ip.ULABuilder)
meshes := r.meshManager.GetMeshes()
ipNet, err := ulaBuilder.GetIPNet(meshId)
if err != nil {
return err
}
for _, mesh1 := range meshes {
self, err := r.meshManager.GetSelf(meshId)
if err != nil {
return err
}
mesh1.RemoveRoutes(self.GetHostEndpoint(), ipNet.String())
}
return nil
}
// AddRoute adds a route to the given interface
func (m *RouteManagerImpl) addRoute(ifName string, meshPrefix string, routes ...lib.Route) error {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return fmt.Errorf("failed to create config: %w", err)
}
defer rtnl.Close()
// Delete any routes that may be vacant
err = rtnl.DeleteRoutes(ifName, unix.AF_INET6, routes...)
if err != nil {
return err
}
for _, route := range routes {
if route.Destination.String() == meshPrefix {
continue
}
err = rtnl.AddRoute(ifName, route)
if err != nil {
return err
}
}
return nil
}
func (m *RouteManagerImpl) installRoute(ifName string, meshid string, node MeshNode) error {
routeMapFunc := func(route string) lib.Route {
_, cidr, _ := net.ParseCIDR(route)
r := lib.Route{
Destination: *cidr,
Gateway: node.GetWgHost().IP,
}
return r
}
ipBuilder := &ip.ULABuilder{}
ipNet, err := ipBuilder.GetIPNet(meshid)
if err != nil {
return err
}
routes := lib.Map(append(node.GetRoutes(), ipNet.String()), routeMapFunc)
return m.addRoute(ifName, ipNet.String(), routes...)
}
func (m *RouteManagerImpl) installRoutes(meshProvider MeshProvider) error {
mesh, err := meshProvider.GetMesh()
if err != nil {
return err
}
dev, err := meshProvider.GetDevice()
if err != nil {
return err
}
self, err := m.meshManager.GetSelf(meshProvider.GetMeshId())
if err != nil {
return err
}
for _, node := range mesh.GetNodes() {
if self.GetHostEndpoint() == node.GetHostEndpoint() {
continue
}
err = m.installRoute(dev.Name, meshProvider.GetMeshId(), node)
if err != nil {
return err
}
}
return nil
}
// InstallRoutes installs all routes to the RIB
func (r *RouteManagerImpl) InstallRoutes() error {
for _, mesh := range r.meshManager.GetMeshes() {
err := r.installRoutes(mesh)
if err != nil {
return err
}
}
return nil
}
func NewRouteManager(m MeshManager) RouteManager {
return &RouteManagerImpl{meshManager: m, routeInstaller: route.NewRouteInstaller()}
} }

View File

@ -16,11 +16,16 @@ type MeshNodeStub struct {
wgEndpoint string wgEndpoint string
wgHost *net.IPNet wgHost *net.IPNet
timeStamp int64 timeStamp int64
routes []string routes []Route
identifier string identifier string
description string description string
} }
// GetType implements MeshNode.
func (*MeshNodeStub) GetType() conf.NodeType {
return conf.PEER_ROLE
}
// GetServices implements MeshNode. // GetServices implements MeshNode.
func (*MeshNodeStub) GetServices() map[string]string { func (*MeshNodeStub) GetServices() map[string]string {
return make(map[string]string) return make(map[string]string)
@ -51,7 +56,7 @@ func (m *MeshNodeStub) GetTimeStamp() int64 {
return m.timeStamp return m.timeStamp
} }
func (m *MeshNodeStub) GetRoutes() []string { func (m *MeshNodeStub) GetRoutes() []Route {
return m.routes return m.routes
} }
@ -76,43 +81,57 @@ type MeshProviderStub struct {
snapshot *MeshSnapshotStub snapshot *MeshSnapshotStub
} }
// GetNodeIds implements MeshProvider. // Mark implements MeshProvider.
func (*MeshProviderStub) GetNodeIds() []string { func (*MeshProviderStub) Mark(nodeId string) {
panic("unimplemented") panic("unimplemented")
} }
// RemoveNode implements MeshProvider.
func (*MeshProviderStub) RemoveNode(nodeId string) error {
panic("unimplemented")
}
func (*MeshProviderStub) GetRoutes(targetId string) (map[string]Route, error) {
return nil, nil
}
// GetNodeIds implements MeshProvider.
func (*MeshProviderStub) GetPeers() []string {
return make([]string, 0)
}
// GetNode implements MeshProvider. // GetNode implements MeshProvider.
func (*MeshProviderStub) GetNode(string) (MeshNode, error) { func (*MeshProviderStub) GetNode(string) (MeshNode, error) {
panic("unimplemented") return nil, nil
} }
// NodeExists implements MeshProvider. // NodeExists implements MeshProvider.
func (*MeshProviderStub) NodeExists(string) bool { func (*MeshProviderStub) NodeExists(string) bool {
panic("unimplemented") return false
} }
// AddService implements MeshProvider. // AddService implements MeshProvider.
func (*MeshProviderStub) AddService(nodeId string, key string, value string) error { func (*MeshProviderStub) AddService(nodeId string, key string, value string) error {
panic("unimplemented") return nil
} }
// RemoveService implements MeshProvider. // RemoveService implements MeshProvider.
func (*MeshProviderStub) RemoveService(nodeId string, key string) error { func (*MeshProviderStub) RemoveService(nodeId string, key string) error {
panic("unimplemented") return nil
} }
// SetAlias implements MeshProvider. // SetAlias implements MeshProvider.
func (*MeshProviderStub) SetAlias(nodeId string, alias string) error { func (*MeshProviderStub) SetAlias(nodeId string, alias string) error {
panic("unimplemented") return nil
} }
// RemoveRoutes implements MeshProvider. // RemoveRoutes implements MeshProvider.
func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error { func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...Route) error {
panic("unimplemented") return nil
} }
// Prune implements MeshProvider. // Prune implements MeshProvider.
func (*MeshProviderStub) Prune(pruneAmount int) error { func (*MeshProviderStub) Prune() error {
return nil return nil
} }
@ -154,7 +173,7 @@ func (s *MeshProviderStub) HasChanges() bool {
return false return false
} }
func (s *MeshProviderStub) AddRoutes(nodeId string, route ...string) error { func (s *MeshProviderStub) AddRoutes(nodeId string, route ...Route) error {
return nil return nil
} }
@ -188,7 +207,7 @@ func (s *StubNodeFactory) Build(params *MeshNodeFactoryParams) MeshNode {
wgEndpoint: fmt.Sprintf("%s:%s", params.Endpoint, s.Config.GrpcPort), wgEndpoint: fmt.Sprintf("%s:%s", params.Endpoint, s.Config.GrpcPort),
wgHost: wgHost, wgHost: wgHost,
timeStamp: time.Now().Unix(), timeStamp: time.Now().Unix(),
routes: make([]string, 0), routes: make([]Route, 0),
identifier: "abc", identifier: "abc",
description: "A Mesh Node Stub", description: "A Mesh Node Stub",
} }
@ -211,6 +230,11 @@ type MeshManagerStub struct {
meshes map[string]MeshProvider meshes map[string]MeshProvider
} }
// GetRouteManager implements MeshManager.
func (*MeshManagerStub) GetRouteManager() RouteManager {
panic("unimplemented")
}
// GetNode implements MeshManager. // GetNode implements MeshManager.
func (*MeshManagerStub) GetNode(string, string) MeshNode { func (*MeshManagerStub) GetNode(string, string) MeshNode {
panic("unimplemented") panic("unimplemented")
@ -250,7 +274,7 @@ func NewMeshManagerStub() MeshManager {
return &MeshManagerStub{meshes: make(map[string]MeshProvider)} return &MeshManagerStub{meshes: make(map[string]MeshProvider)}
} }
func (m *MeshManagerStub) CreateMesh(devName string, port int) (string, error) { func (m *MeshManagerStub) CreateMesh(port int) (string, error) {
return "tim123", nil return "tim123", nil
} }
@ -273,13 +297,9 @@ func (m *MeshManagerStub) GetMesh(meshId string) MeshProvider {
snapshot: &MeshSnapshotStub{nodes: make(map[string]MeshNode)}} snapshot: &MeshSnapshotStub{nodes: make(map[string]MeshNode)}}
} }
func (m *MeshManagerStub) EnableInterface(meshId string) error { func (m *MeshManagerStub) GetPublicKey() *wgtypes.Key {
return nil
}
func (m *MeshManagerStub) GetPublicKey(meshId string) (*wgtypes.Key, error) {
key, _ := wgtypes.GenerateKey() key, _ := wgtypes.GenerateKey()
return &key, nil return &key
} }
func (m *MeshManagerStub) AddSelf(params *AddSelfParams) error { func (m *MeshManagerStub) AddSelf(params *AddSelfParams) error {

View File

@ -11,6 +11,39 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
type Route interface {
// GetDestination: returns the destination of the route
GetDestination() *net.IPNet
// GetHopCount: get the total hopcount of the prefix
GetHopCount() int
// GetPath: get a list of AS paths to get to the destination
GetPath() []string
}
func RouteEquals(r1, r2 Route) bool {
return r1.GetDestination().String() == r2.GetDestination().String() &&
r1.GetHopCount() == r2.GetHopCount() &&
slices.Equal(r1.GetPath(), r2.GetPath())
}
type RouteStub struct {
Destination *net.IPNet
HopCount int
Path []string
}
func (r *RouteStub) GetDestination() *net.IPNet {
return r.Destination
}
func (r *RouteStub) GetHopCount() int {
return r.HopCount
}
func (r *RouteStub) GetPath() []string {
return r.Path
}
// MeshNode represents an implementation of a node in a mesh // MeshNode represents an implementation of a node in a mesh
type MeshNode interface { type MeshNode interface {
// GetHostEndpoint: gets the gRPC endpoint of the node // GetHostEndpoint: gets the gRPC endpoint of the node
@ -24,7 +57,7 @@ type MeshNode interface {
// GetTimestamp: get the UNIX time stamp of the ndoe // GetTimestamp: get the UNIX time stamp of the ndoe
GetTimeStamp() int64 GetTimeStamp() int64
// GetRoutes: returns the routes that the nodes provides // GetRoutes: returns the routes that the nodes provides
GetRoutes() []string GetRoutes() []Route
// GetIdentifier: returns the identifier of the node // GetIdentifier: returns the identifier of the node
GetIdentifier() string GetIdentifier() string
// GetDescription: returns the description for this node // GetDescription: returns the description for this node
@ -34,46 +67,20 @@ type MeshNode interface {
GetAlias() string GetAlias() string
// GetServices: returns a list of services offered by the node // GetServices: returns a list of services offered by the node
GetServices() map[string]string GetServices() map[string]string
GetType() conf.NodeType
} }
// NodeEquals: determines if two mesh nodes are equivalent to one another // NodeEquals: determines if two mesh nodes are equivalent to one another
func NodeEquals(node1, node2 MeshNode) bool { func NodeEquals(node1, node2 MeshNode) bool {
if node1.GetHostEndpoint() != node2.GetHostEndpoint() { key1, _ := node1.GetPublicKey()
return false key2, _ := node2.GetPublicKey()
}
node1Pub, _ := node1.GetPublicKey() return key1.String() == key2.String()
node2Pub, _ := node2.GetPublicKey() }
if node1Pub != node2Pub { func NodeID(node MeshNode) string {
return false key, _ := node.GetPublicKey()
} return key.String()
if node1.GetWgEndpoint() != node2.GetWgEndpoint() {
return false
}
if node1.GetWgHost() != node2.GetWgHost() {
return false
}
if !slices.Equal(node1.GetRoutes(), node2.GetRoutes()) {
return false
}
if node1.GetIdentifier() != node2.GetIdentifier() {
return false
}
if node1.GetDescription() != node2.GetDescription() {
return false
}
if node1.GetAlias() != node2.GetAlias() {
return false
}
return true
} }
type MeshSnapshot interface { type MeshSnapshot interface {
@ -109,9 +116,9 @@ type MeshProvider interface {
// UpdateTimeStamp: update the timestamp of the given node // UpdateTimeStamp: update the timestamp of the given node
UpdateTimeStamp(nodeId string) error UpdateTimeStamp(nodeId string) error
// AddRoutes: adds routes to the given node // AddRoutes: adds routes to the given node
AddRoutes(nodeId string, route ...string) error AddRoutes(nodeId string, route ...Route) error
// DeleteRoutes: deletes the routes from the node // DeleteRoutes: deletes the routes from the node
RemoveRoutes(nodeId string, route ...string) error RemoveRoutes(nodeId string, route ...Route) error
// GetSyncer: returns the automerge syncer for sync // GetSyncer: returns the automerge syncer for sync
GetSyncer() MeshSyncer GetSyncer() MeshSyncer
// GetNode get a particular not within the mesh // GetNode get a particular not within the mesh
@ -126,15 +133,28 @@ type MeshProvider interface {
AddService(nodeId, key, value string) error AddService(nodeId, key, value string) error
// RemoveService: removes the service form the node. throws an error if the service does not exist // RemoveService: removes the service form the node. throws an error if the service does not exist
RemoveService(nodeId, key string) error RemoveService(nodeId, key string) error
// Prune: prunes all nodes that have not updated their timestamp in // Prune: prunes all nodes that have not updated their
// pruneAmount seconds // vector clock
Prune(pruneAmount int) error Prune() error
GetNodeIds() []string // GetPeers: get a list of contactable peers
GetPeers() []string
// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen
GetRoutes(targetNode string) (map[string]Route, error)
// RemoveNode(): remove the node from the mesh
RemoveNode(nodeId string) error
// Mark: marks the node as unreachable. This is not broadcast to the entire
// this is not considered when syncing node state
Mark(nodeId string)
} }
// HostParameters contains the IDs of a node // HostParameters contains the IDs of a node
type HostParameters struct { type HostParameters struct {
HostEndpoint string PrivateKey *wgtypes.Key
}
// GetPublicKey: gets the public key of the node
func (h *HostParameters) GetPublicKey() string {
return h.PrivateKey.PublicKey().String()
} }
// MeshProviderFactoryParams parameters required to build a mesh provider // MeshProviderFactoryParams parameters required to build a mesh provider
@ -144,6 +164,7 @@ type MeshProviderFactoryParams struct {
Port int Port int
Conf *conf.WgMeshConfiguration Conf *conf.WgMeshConfiguration
Client *wgctrl.Client Client *wgctrl.Client
NodeID string
} }
// MeshProviderFactory creates an instance of a mesh provider // MeshProviderFactory creates an instance of a mesh provider

View File

@ -3,8 +3,10 @@ package query
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"strings"
"github.com/jmespath/go-jmespath" "github.com/jmespath/go-jmespath"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/wgmesh/pkg/mesh"
) )
@ -23,16 +25,23 @@ type QueryError struct {
msg string msg string
} }
type QueryRoute struct {
Destination string `json:"destination"`
HopCount int `json:"hopCount"`
Path string `json:"path"`
}
type QueryNode struct { type QueryNode struct {
HostEndpoint string `json:"hostEndpoint"` HostEndpoint string `json:"hostEndpoint"`
PublicKey string `json:"publicKey"` PublicKey string `json:"publicKey"`
WgEndpoint string `json:"wgEndpoint"` WgEndpoint string `json:"wgEndpoint"`
WgHost string `json:"wgHost"` WgHost string `json:"wgHost"`
Timestamp int64 `json:"timestmap"` Timestamp int64 `json:"timestamp"`
Description string `json:"description"` Description string `json:"description"`
Routes []string `json:"routes"` Routes []QueryRoute `json:"routes"`
Alias string `json:"alias"` Alias string `json:"alias"`
Services map[string]string `json:"services"` Services map[string]string `json:"services"`
Type conf.NodeType `json:"type"`
} }
func (m *QueryError) Error() string { func (m *QueryError) Error() string {
@ -76,10 +85,17 @@ func MeshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode.WgHost = node.GetWgHost().String() queryNode.WgHost = node.GetWgHost().String()
queryNode.Timestamp = node.GetTimeStamp() queryNode.Timestamp = node.GetTimeStamp()
queryNode.Routes = node.GetRoutes() queryNode.Routes = lib.Map(node.GetRoutes(), func(r mesh.Route) QueryRoute {
return QueryRoute{
Destination: r.GetDestination().String(),
HopCount: r.GetHopCount(),
Path: strings.Join(r.GetPath(), ","),
}
})
queryNode.Description = node.GetDescription() queryNode.Description = node.GetDescription()
queryNode.Alias = node.GetAlias() queryNode.Alias = node.GetAlias()
queryNode.Services = node.GetServices() queryNode.Services = node.GetServices()
queryNode.Type = node.GetType()
return queryNode return queryNode
} }

View File

@ -10,6 +10,7 @@ import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query" "github.com/tim-beatham/wgmesh/pkg/query"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/wgmesh/pkg/rpc"
@ -20,7 +21,7 @@ type IpcHandler struct {
} }
func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error { func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
meshId, err := n.Server.GetMeshManager().CreateMesh(args.IfName, args.WgPort) meshId, err := n.Server.GetMeshManager().CreateMesh(args.WgPort)
if err != nil { if err != nil {
return err return err
@ -72,7 +73,9 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
return err return err
} }
ctx, cancel := context.WithTimeout(context.Background(), time.Second) configuration := n.Server.GetConfiguration()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(configuration.Timeout))
defer cancel() defer cancel()
meshReply, err := c.GetMesh(ctx, &rpc.GetMeshRequest{MeshId: args.MeshId}) meshReply, err := c.GetMesh(ctx, &rpc.GetMeshRequest{MeshId: args.MeshId})
@ -83,7 +86,6 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{ err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{
MeshId: args.MeshId, MeshId: args.MeshId,
DevName: args.IfName,
WgPort: args.Port, WgPort: args.Port,
MeshBytes: meshReply.Mesh, MeshBytes: meshReply.Mesh,
}) })
@ -118,19 +120,19 @@ func (n *IpcHandler) LeaveMesh(meshId string, reply *string) error {
} }
func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error { func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
mesh := n.Server.GetMeshManager().GetMesh(meshId) theMesh := n.Server.GetMeshManager().GetMesh(meshId)
if mesh == nil { if theMesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId) return fmt.Errorf("mesh %s does not exist", meshId)
} }
meshSnapshot, err := mesh.GetMesh() meshSnapshot, err := theMesh.GetMesh()
if err != nil { if err != nil {
return err return err
} }
if mesh == nil { if theMesh == nil {
return errors.New("mesh does not exist") return errors.New("mesh does not exist")
} }
@ -150,10 +152,15 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
PublicKey: pubKey.String(), PublicKey: pubKey.String(),
WgHost: node.GetWgHost().String(), WgHost: node.GetWgHost().String(),
Timestamp: node.GetTimeStamp(), Timestamp: node.GetTimeStamp(),
Routes: node.GetRoutes(), Routes: lib.Map(node.GetRoutes(), func(r mesh.Route) ctrlserver.MeshRoute {
Description: node.GetDescription(), return ctrlserver.MeshRoute{
Alias: node.GetAlias(), Destination: r.GetDestination().String(),
Services: node.GetServices(), Path: r.GetPath(),
}
}),
Description: node.GetDescription(),
Alias: node.GetAlias(),
Services: node.GetServices(),
} }
nodes[i] = node nodes[i] = node
@ -164,18 +171,6 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
return nil return nil
} }
func (n *IpcHandler) EnableInterface(meshId string, reply *string) error {
err := n.Server.GetMeshManager().EnableInterface(meshId)
if err != nil {
*reply = err.Error()
return err
}
*reply = "up"
return nil
}
func (n *IpcHandler) GetDOT(meshId string, reply *string) error { func (n *IpcHandler) GetDOT(meshId string, reply *string) error {
g := mesh.NewMeshDotConverter(n.Server.GetMeshManager()) g := mesh.NewMeshDotConverter(n.Server.GetMeshManager())

View File

@ -28,7 +28,3 @@ func (m *WgRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.
return &reply, nil return &reply, nil
} }
func (m *WgRpc) JoinMesh(ctx context.Context, request *rpc.JoinMeshRequest) (*rpc.JoinMeshReply, error) {
return &rpc.JoinMeshReply{Success: true}, nil
}

View File

@ -1,22 +1,32 @@
package route package route
import ( import (
"net" "github.com/tim-beatham/wgmesh/pkg/lib"
"os/exec" "golang.org/x/sys/unix"
logging "github.com/tim-beatham/wgmesh/pkg/log"
) )
type RouteInstaller interface { type RouteInstaller interface {
InstallRoutes(devName string, routes ...*net.IPNet) error InstallRoutes(devName string, routes ...lib.Route) error
} }
type RouteInstallerImpl struct{} type RouteInstallerImpl struct{}
// InstallRoutes: installs a route into the routing table // InstallRoutes: installs a route into the routing table
func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...*net.IPNet) error { func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...lib.Route) error {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return err
}
err = rtnl.DeleteRoutes(devName, unix.AF_INET6, routes...)
if err != nil {
return err
}
for _, route := range routes { for _, route := range routes {
err := r.installRoute(devName, route) err := rtnl.AddRoute(devName, route)
if err != nil { if err != nil {
return err return err
@ -26,22 +36,6 @@ func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...*net.IPNet)
return nil return nil
} }
// installRoute: installs a route into the linux table
func (r *RouteInstallerImpl) installRoute(devName string, route *net.IPNet) error {
// TODO: Find a library that automates this
cmd := exec.Command("/usr/bin/ip", "-6", "route", "add", route.String(), "dev", devName)
logging.Log.WriteInfof("%s %s", route.String(), devName)
if msg, err := cmd.CombinedOutput(); err != nil {
logging.Log.WriteErrorf(err.Error())
logging.Log.WriteErrorf(string(msg))
return err
}
return nil
}
func NewRouteInstaller() RouteInstaller { func NewRouteInstaller() RouteInstaller {
return &RouteInstallerImpl{} return &RouteInstallerImpl{}
} }

View File

@ -20,77 +20,6 @@ const (
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
) )
type MeshNode struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
PublicKey string `protobuf:"bytes,1,opt,name=publicKey,proto3" json:"publicKey,omitempty"`
WgEndpoint string `protobuf:"bytes,2,opt,name=wgEndpoint,proto3" json:"wgEndpoint,omitempty"`
Endpoint string `protobuf:"bytes,3,opt,name=endpoint,proto3" json:"endpoint,omitempty"`
WgHost string `protobuf:"bytes,4,opt,name=wgHost,proto3" json:"wgHost,omitempty"`
}
func (x *MeshNode) Reset() {
*x = MeshNode{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *MeshNode) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*MeshNode) ProtoMessage() {}
func (x *MeshNode) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use MeshNode.ProtoReflect.Descriptor instead.
func (*MeshNode) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{0}
}
func (x *MeshNode) GetPublicKey() string {
if x != nil {
return x.PublicKey
}
return ""
}
func (x *MeshNode) GetWgEndpoint() string {
if x != nil {
return x.WgEndpoint
}
return ""
}
func (x *MeshNode) GetEndpoint() string {
if x != nil {
return x.Endpoint
}
return ""
}
func (x *MeshNode) GetWgHost() string {
if x != nil {
return x.WgHost
}
return ""
}
type GetMeshRequest struct { type GetMeshRequest struct {
state protoimpl.MessageState state protoimpl.MessageState
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
@ -102,7 +31,7 @@ type GetMeshRequest struct {
func (x *GetMeshRequest) Reset() { func (x *GetMeshRequest) Reset() {
*x = GetMeshRequest{} *x = GetMeshRequest{}
if protoimpl.UnsafeEnabled { if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1] mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi) ms.StoreMessageInfo(mi)
} }
@ -115,7 +44,7 @@ func (x *GetMeshRequest) String() string {
func (*GetMeshRequest) ProtoMessage() {} func (*GetMeshRequest) ProtoMessage() {}
func (x *GetMeshRequest) ProtoReflect() protoreflect.Message { func (x *GetMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1] mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil { if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil { if ms.LoadMessageInfo() == nil {
@ -128,7 +57,7 @@ func (x *GetMeshRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetMeshRequest.ProtoReflect.Descriptor instead. // Deprecated: Use GetMeshRequest.ProtoReflect.Descriptor instead.
func (*GetMeshRequest) Descriptor() ([]byte, []int) { func (*GetMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{1} return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{0}
} }
func (x *GetMeshRequest) GetMeshId() string { func (x *GetMeshRequest) GetMeshId() string {
@ -149,7 +78,7 @@ type GetMeshReply struct {
func (x *GetMeshReply) Reset() { func (x *GetMeshReply) Reset() {
*x = GetMeshReply{} *x = GetMeshReply{}
if protoimpl.UnsafeEnabled { if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2] mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi) ms.StoreMessageInfo(mi)
} }
@ -162,7 +91,7 @@ func (x *GetMeshReply) String() string {
func (*GetMeshReply) ProtoMessage() {} func (*GetMeshReply) ProtoMessage() {}
func (x *GetMeshReply) ProtoReflect() protoreflect.Message { func (x *GetMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2] mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil { if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil { if ms.LoadMessageInfo() == nil {
@ -175,7 +104,7 @@ func (x *GetMeshReply) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetMeshReply.ProtoReflect.Descriptor instead. // Deprecated: Use GetMeshReply.ProtoReflect.Descriptor instead.
func (*GetMeshReply) Descriptor() ([]byte, []int) { func (*GetMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{2} return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{1}
} }
func (x *GetMeshReply) GetMesh() []byte { func (x *GetMeshReply) GetMesh() []byte {
@ -185,145 +114,24 @@ func (x *GetMeshReply) GetMesh() []byte {
return nil return nil
} }
type JoinMeshRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Changes []byte `protobuf:"bytes,1,opt,name=changes,proto3" json:"changes,omitempty"`
MeshId string `protobuf:"bytes,2,opt,name=meshId,proto3" json:"meshId,omitempty"`
}
func (x *JoinMeshRequest) Reset() {
*x = JoinMeshRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *JoinMeshRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JoinMeshRequest) ProtoMessage() {}
func (x *JoinMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JoinMeshRequest.ProtoReflect.Descriptor instead.
func (*JoinMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{3}
}
func (x *JoinMeshRequest) GetChanges() []byte {
if x != nil {
return x.Changes
}
return nil
}
func (x *JoinMeshRequest) GetMeshId() string {
if x != nil {
return x.MeshId
}
return ""
}
type JoinMeshReply struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"`
}
func (x *JoinMeshReply) Reset() {
*x = JoinMeshReply{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *JoinMeshReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JoinMeshReply) ProtoMessage() {}
func (x *JoinMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JoinMeshReply.ProtoReflect.Descriptor instead.
func (*JoinMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{4}
}
func (x *JoinMeshReply) GetSuccess() bool {
if x != nil {
return x.Success
}
return false
}
var File_pkg_grpc_ctrlserver_ctrlserver_proto protoreflect.FileDescriptor var File_pkg_grpc_ctrlserver_ctrlserver_proto protoreflect.FileDescriptor
var file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc = []byte{ var file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc = []byte{
0x0a, 0x24, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73, 0x0a, 0x24, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73,
0x65, 0x72, 0x76, 0x65, 0x72, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x65, 0x72, 0x76, 0x65, 0x72, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72,
0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73,
0x22, 0x7c, 0x0a, 0x08, 0x4d, 0x65, 0x73, 0x68, 0x4e, 0x6f, 0x64, 0x65, 0x12, 0x1c, 0x0a, 0x09, 0x22, 0x28, 0x0a, 0x0e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65,
0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01,
0x09, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x77, 0x67, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x22, 0x22, 0x0a, 0x0c, 0x47, 0x65,
0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x6d, 0x65,
0x77, 0x67, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e, 0x73, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x32, 0x4f,
0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x65, 0x6e, 0x0a, 0x0e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x74, 0x72, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72,
0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x77, 0x67, 0x48, 0x6f, 0x73, 0x74, 0x12, 0x3d, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x18, 0x2e, 0x72, 0x70,
0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x77, 0x67, 0x48, 0x6f, 0x73, 0x74, 0x22, 0x28, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65,
0x0a, 0x0e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73,
0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x42,
0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x22, 0x22, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x4d, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x6f, 0x33,
0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x22, 0x43, 0x0a, 0x0f,
0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12,
0x18, 0x0a, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c,
0x52, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73,
0x68, 0x49, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49,
0x64, 0x22, 0x29, 0x0a, 0x0d, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70,
0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20,
0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x32, 0x91, 0x01, 0x0a,
0x0e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x74, 0x72, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12,
0x3d, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x18, 0x2e, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e,
0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x40,
0x0a, 0x08, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x19, 0x2e, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73,
0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00,
0x42, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x33,
} }
var ( var (
@ -338,21 +146,16 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP() []byte {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescData return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescData
} }
var file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes = make([]protoimpl.MessageInfo, 5) var file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_pkg_grpc_ctrlserver_ctrlserver_proto_goTypes = []interface{}{ var file_pkg_grpc_ctrlserver_ctrlserver_proto_goTypes = []interface{}{
(*MeshNode)(nil), // 0: rpctypes.MeshNode (*GetMeshRequest)(nil), // 0: rpctypes.GetMeshRequest
(*GetMeshRequest)(nil), // 1: rpctypes.GetMeshRequest (*GetMeshReply)(nil), // 1: rpctypes.GetMeshReply
(*GetMeshReply)(nil), // 2: rpctypes.GetMeshReply
(*JoinMeshRequest)(nil), // 3: rpctypes.JoinMeshRequest
(*JoinMeshReply)(nil), // 4: rpctypes.JoinMeshReply
} }
var file_pkg_grpc_ctrlserver_ctrlserver_proto_depIdxs = []int32{ var file_pkg_grpc_ctrlserver_ctrlserver_proto_depIdxs = []int32{
1, // 0: rpctypes.MeshCtrlServer.GetMesh:input_type -> rpctypes.GetMeshRequest 0, // 0: rpctypes.MeshCtrlServer.GetMesh:input_type -> rpctypes.GetMeshRequest
3, // 1: rpctypes.MeshCtrlServer.JoinMesh:input_type -> rpctypes.JoinMeshRequest 1, // 1: rpctypes.MeshCtrlServer.GetMesh:output_type -> rpctypes.GetMeshReply
2, // 2: rpctypes.MeshCtrlServer.GetMesh:output_type -> rpctypes.GetMeshReply 1, // [1:2] is the sub-list for method output_type
4, // 3: rpctypes.MeshCtrlServer.JoinMesh:output_type -> rpctypes.JoinMeshReply 0, // [0:1] is the sub-list for method input_type
2, // [2:4] is the sub-list for method output_type
0, // [0:2] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name 0, // [0:0] is the sub-list for field type_name
@ -365,18 +168,6 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
} }
if !protoimpl.UnsafeEnabled { if !protoimpl.UnsafeEnabled {
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*MeshNode); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetMeshRequest); i { switch v := v.(*GetMeshRequest); i {
case 0: case 0:
return &v.state return &v.state
@ -388,7 +179,7 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
return nil return nil
} }
} }
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetMeshReply); i { switch v := v.(*GetMeshReply); i {
case 0: case 0:
return &v.state return &v.state
@ -400,30 +191,6 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
return nil return nil
} }
} }
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*JoinMeshRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*JoinMeshReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
} }
type x struct{} type x struct{}
out := protoimpl.TypeBuilder{ out := protoimpl.TypeBuilder{
@ -431,7 +198,7 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(), GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc, RawDescriptor: file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc,
NumEnums: 0, NumEnums: 0,
NumMessages: 5, NumMessages: 2,
NumExtensions: 0, NumExtensions: 0,
NumServices: 1, NumServices: 1,
}, },

View File

@ -23,7 +23,6 @@ const _ = grpc.SupportPackageIsVersion7
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type MeshCtrlServerClient interface { type MeshCtrlServerClient interface {
GetMesh(ctx context.Context, in *GetMeshRequest, opts ...grpc.CallOption) (*GetMeshReply, error) GetMesh(ctx context.Context, in *GetMeshRequest, opts ...grpc.CallOption) (*GetMeshReply, error)
JoinMesh(ctx context.Context, in *JoinMeshRequest, opts ...grpc.CallOption) (*JoinMeshReply, error)
} }
type meshCtrlServerClient struct { type meshCtrlServerClient struct {
@ -43,21 +42,11 @@ func (c *meshCtrlServerClient) GetMesh(ctx context.Context, in *GetMeshRequest,
return out, nil return out, nil
} }
func (c *meshCtrlServerClient) JoinMesh(ctx context.Context, in *JoinMeshRequest, opts ...grpc.CallOption) (*JoinMeshReply, error) {
out := new(JoinMeshReply)
err := c.cc.Invoke(ctx, "/rpctypes.MeshCtrlServer/JoinMesh", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// MeshCtrlServerServer is the server API for MeshCtrlServer service. // MeshCtrlServerServer is the server API for MeshCtrlServer service.
// All implementations must embed UnimplementedMeshCtrlServerServer // All implementations must embed UnimplementedMeshCtrlServerServer
// for forward compatibility // for forward compatibility
type MeshCtrlServerServer interface { type MeshCtrlServerServer interface {
GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error) GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error)
JoinMesh(context.Context, *JoinMeshRequest) (*JoinMeshReply, error)
mustEmbedUnimplementedMeshCtrlServerServer() mustEmbedUnimplementedMeshCtrlServerServer()
} }
@ -68,9 +57,6 @@ type UnimplementedMeshCtrlServerServer struct {
func (UnimplementedMeshCtrlServerServer) GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error) { func (UnimplementedMeshCtrlServerServer) GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetMesh not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetMesh not implemented")
} }
func (UnimplementedMeshCtrlServerServer) JoinMesh(context.Context, *JoinMeshRequest) (*JoinMeshReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method JoinMesh not implemented")
}
func (UnimplementedMeshCtrlServerServer) mustEmbedUnimplementedMeshCtrlServerServer() {} func (UnimplementedMeshCtrlServerServer) mustEmbedUnimplementedMeshCtrlServerServer() {}
// UnsafeMeshCtrlServerServer may be embedded to opt out of forward compatibility for this service. // UnsafeMeshCtrlServerServer may be embedded to opt out of forward compatibility for this service.
@ -102,24 +88,6 @@ func _MeshCtrlServer_GetMesh_Handler(srv interface{}, ctx context.Context, dec f
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
func _MeshCtrlServer_JoinMesh_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(JoinMeshRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(MeshCtrlServerServer).JoinMesh(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/rpctypes.MeshCtrlServer/JoinMesh",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(MeshCtrlServerServer).JoinMesh(ctx, req.(*JoinMeshRequest))
}
return interceptor(ctx, in, info, handler)
}
// MeshCtrlServer_ServiceDesc is the grpc.ServiceDesc for MeshCtrlServer service. // MeshCtrlServer_ServiceDesc is the grpc.ServiceDesc for MeshCtrlServer service.
// It's only intended for direct use with grpc.RegisterService, // It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy) // and not to be introspected or modified (even as a copy)
@ -131,10 +99,6 @@ var MeshCtrlServer_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetMesh", MethodName: "GetMesh",
Handler: _MeshCtrlServer_GetMesh_Handler, Handler: _MeshCtrlServer_GetMesh_Handler,
}, },
{
MethodName: "JoinMesh",
Handler: _MeshCtrlServer_JoinMesh_Handler,
},
}, },
Streams: []grpc.StreamDesc{}, Streams: []grpc.StreamDesc{},
Metadata: "pkg/grpc/ctrlserver/ctrlserver.proto", Metadata: "pkg/grpc/ctrlserver/ctrlserver.proto",

View File

@ -1,8 +1,8 @@
package sync package sync
import ( import (
"io"
"math/rand" "math/rand"
"sync"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/conf"
@ -12,7 +12,7 @@ import (
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/wgmesh/pkg/mesh"
) )
// Syncer: picks random nodes from the mesh // Syncer: picks random nodes from the meshs
type Syncer interface { type Syncer interface {
Sync(meshId string) error Sync(meshId string) error
SyncMeshes() error SyncMeshes() error
@ -25,71 +25,98 @@ type SyncerImpl struct {
syncCount int syncCount int
cluster conn.ConnCluster cluster conn.ConnCluster
conf *conf.WgMeshConfiguration conf *conf.WgMeshConfiguration
lastSync uint64
} }
// Sync: Sync random nodes // Sync: Sync random nodes
func (s *SyncerImpl) Sync(meshId string) error { func (s *SyncerImpl) Sync(meshId string) error {
if !s.manager.HasChanges(meshId) && s.infectionCount == 0 { // Self can be nil if the node is removed
self, _ := s.manager.GetSelf(meshId)
s.manager.GetMesh(meshId).Prune()
if self != nil && self.GetType() == conf.PEER_ROLE && !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
logging.Log.WriteInfof("No changes for %s", meshId) logging.Log.WriteInfof("No changes for %s", meshId)
return nil return nil
} }
logging.Log.WriteInfof("UPDATING WG CONF") before := time.Now()
s.manager.GetRouteManager().UpdateRoutes()
if s.manager.HasChanges(meshId) { publicKey := s.manager.GetPublicKey()
err := s.manager.ApplyConfig()
if err != nil { logging.Log.WriteInfof(publicKey.String())
logging.Log.WriteInfof("Failed to update config %w", err)
nodeNames := s.manager.GetMesh(meshId).GetPeers()
if self != nil {
nodeNames = lib.Filter(nodeNames, func(s string) bool {
return s != mesh.NodeID(self)
})
}
var gossipNodes []string
// Clients always pings its peer for configuration
if self != nil && self.GetType() == conf.CLIENT_ROLE {
keyFunc := lib.HashString
bucketFunc := lib.HashString
neighbour := lib.ConsistentHash(nodeNames, publicKey.String(), keyFunc, bucketFunc)
gossipNodes = make([]string, 1)
gossipNodes[0] = neighbour
} else {
neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
gossipNodes = lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
gossipNodes[len(gossipNodes)-1] = s.cluster.GetInterCluster(nodeNames, publicKey.String())
} }
} }
nodeNames := s.manager.GetMesh(meshId).GetNodeIds() var succeeded bool = false
self, err := s.manager.GetSelf(meshId) // Do this synchronously to conserve bandwidth
for _, node := range gossipNodes {
correspondingPeer := s.manager.GetNode(meshId, node)
if err != nil { if correspondingPeer == nil {
return err logging.Log.WriteErrorf("node %s does not exist", node)
}
err := s.requester.SyncMesh(meshId, correspondingPeer)
if err == nil || err == io.EOF {
succeeded = true
} else {
// If the synchronisation operation has failed them mark a gravestone
// preventing the peer from being re-contacted until it has updated
// itself
s.manager.GetMesh(meshId).Mark(node)
}
} }
neighbours := s.cluster.GetNeighbours(nodeNames, self.GetHostEndpoint())
randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
for _, node := range randomSubset {
logging.Log.WriteInfof("Random node: %s", node)
}
before := time.Now()
if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
logging.Log.WriteInfof("Sending to random cluster")
interCluster := s.cluster.GetInterCluster(nodeNames, self.GetHostEndpoint())
randomSubset = append(randomSubset, interCluster)
}
var waitGroup sync.WaitGroup
for index := range randomSubset {
waitGroup.Add(1)
go func(i int) error {
defer waitGroup.Done()
err := s.requester.SyncMesh(meshId, randomSubset[i])
return err
}(index)
}
waitGroup.Wait()
s.syncCount++ s.syncCount++
logging.Log.WriteInfof("SYNC TIME: %v", time.Since(before)) logging.Log.WriteInfof("SYNC TIME: %v", time.Since(before))
logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount) logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount)
s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount) s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount)
// Check if any changes have occurred and trigger callbacks if !succeeded {
// if changes have occurred. // If could not gossip with anyone then repeat.
// return s.manager.GetMonitor().Trigger() s.infectionCount++
}
s.manager.GetMesh(meshId).SaveChanges()
s.lastSync = uint64(time.Now().Unix())
logging.Log.WriteInfof("UPDATING WG CONF")
err := s.manager.ApplyConfig()
if err != nil {
logging.Log.WriteInfof("Failed to update config %w", err)
}
return nil return nil
} }
@ -99,7 +126,7 @@ func (s *SyncerImpl) SyncMeshes() error {
err := s.Sync(meshId) err := s.Sync(meshId)
if err != nil { if err != nil {
return err logging.Log.WriteErrorf(err.Error())
} }
} }

View File

@ -17,31 +17,20 @@ type SyncErrorHandlerImpl struct {
meshManager mesh.MeshManager meshManager mesh.MeshManager
} }
func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint string) bool { func (s *SyncErrorHandlerImpl) handleFailed(meshId string, nodeId string) bool {
mesh := s.meshManager.GetMesh(meshId) mesh := s.meshManager.GetMesh(meshId)
mesh.Mark(nodeId)
if mesh == nil {
return false
}
// self, err := s.meshManager.GetSelf(meshId)
// if err != nil {
// return false
// }
// mesh.DecrementHealth(endpoint, self.GetHostEndpoint())
return true return true
} }
func (s *SyncErrorHandlerImpl) Handle(meshId string, endpoint string, err error) bool { func (s *SyncErrorHandlerImpl) Handle(meshId string, nodeId string, err error) bool {
errStatus, _ := status.FromError(err) errStatus, _ := status.FromError(err)
logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message()) logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message())
switch errStatus.Code() { switch errStatus.Code() {
case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound: case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound:
return s.incrementFailedCount(meshId, endpoint) return s.handleFailed(meshId, nodeId)
} }
return false return false

View File

@ -15,7 +15,7 @@ import (
// SyncRequester: coordinates the syncing of meshes // SyncRequester: coordinates the syncing of meshes
type SyncRequester interface { type SyncRequester interface {
GetMesh(meshId string, ifName string, port int, endPoint string) error GetMesh(meshId string, ifName string, port int, endPoint string) error
SyncMesh(meshid string, endPoint string) error SyncMesh(meshid string, meshNode mesh.MeshNode) error
} }
type SyncRequesterImpl struct { type SyncRequesterImpl struct {
@ -50,15 +50,14 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, ifName string, port int, endP
err = s.server.MeshManager.AddMesh(&mesh.AddMeshParams{ err = s.server.MeshManager.AddMesh(&mesh.AddMeshParams{
MeshId: meshId, MeshId: meshId,
DevName: ifName,
WgPort: port, WgPort: port,
MeshBytes: reply.Mesh, MeshBytes: reply.Mesh,
}) })
return err return err
} }
func (s *SyncRequesterImpl) handleErr(meshId, endpoint string, err error) error { func (s *SyncRequesterImpl) handleErr(meshId, pubKey string, err error) error {
ok := s.errorHdlr.Handle(meshId, endpoint, err) ok := s.errorHdlr.Handle(meshId, pubKey, err)
if ok { if ok {
return nil return nil
@ -68,7 +67,10 @@ func (s *SyncRequesterImpl) handleErr(meshId, endpoint string, err error) error
} }
// SyncMesh: Proactively send a sync request to the other mesh // SyncMesh: Proactively send a sync request to the other mesh
func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error { func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) error {
endpoint := meshNode.GetHostEndpoint()
pubKey, _ := meshNode.GetPublicKey()
peerConnection, err := s.server.ConnectionManager.GetConnection(endpoint) peerConnection, err := s.server.ConnectionManager.GetConnection(endpoint)
if err != nil { if err != nil {
@ -97,7 +99,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
err = s.syncMesh(mesh, ctx, c) err = s.syncMesh(mesh, ctx, c)
if err != nil { if err != nil {
return s.handleErr(meshId, endpoint, err) return s.handleErr(meshId, pubKey.String(), err)
} }
logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId) logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId)

View File

@ -5,28 +5,14 @@ import (
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/lib"
) )
// SyncScheduler: Loops through all nodes in the mesh and runs a schedule to
// sync each event
type SyncScheduler interface {
Run() error
Stop() error
}
// SyncSchedulerImpl scheduler for sync scheduling
type SyncSchedulerImpl struct {
quit chan struct{}
server *ctrlserver.MeshCtrlServer
syncer Syncer
}
// Run implements SyncScheduler. // Run implements SyncScheduler.
func syncFunction(syncer Syncer) lib.TimerFunc { func syncFunction(syncer Syncer) lib.TimerFunc {
return func() error { return func() error {
return syncer.SyncMeshes() syncer.SyncMeshes()
return nil
} }
} }
func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester) *lib.Timer { func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester, syncer Syncer) *lib.Timer {
syncer := NewSyncer(s.MeshManager, s.Conf, syncRequester)
return lib.NewTimer(syncFunction(syncer), int(s.Conf.SyncRate)) return lib.NewTimer(syncFunction(syncer), int(s.Conf.SyncRate))
} }

View File

@ -64,11 +64,11 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error
syncer = mesh.GetSyncer() syncer = mesh.GetSyncer()
} else if meshId != in.MeshId { } else if meshId != in.MeshId {
return errors.New("Differing MeshIDs") return errors.New("differing meshids")
} }
if syncer == nil { if syncer == nil {
return errors.New("Syncer should not be nil") return errors.New("syncer should not be nil")
} }
msg, moreMessages := syncer.GenerateMessage() msg, moreMessages := syncer.GenerateMessage()

View File

@ -1,12 +1,14 @@
package timestamp package timer
import ( import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
) )
func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer { func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
timerFunc := func() error { timerFunc := func() error {
logging.Log.WriteInfof("Updated Timestamp")
return ctrlServer.MeshManager.UpdateTimeStamp() return ctrlServer.MeshManager.UpdateTimeStamp()
} }

View File

@ -2,8 +2,8 @@ package wg
type WgInterfaceManipulatorStub struct{} type WgInterfaceManipulatorStub struct{}
func (i *WgInterfaceManipulatorStub) CreateInterface(params *CreateInterfaceParams) error { func (i *WgInterfaceManipulatorStub) CreateInterface(port int) (string, error) {
return nil return "", nil
} }
func (i *WgInterfaceManipulatorStub) AddAddress(ifName string, addr string) error { func (i *WgInterfaceManipulatorStub) AddAddress(ifName string, addr string) error {

View File

@ -1,5 +1,7 @@
package wg package wg
import "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
type WgError struct { type WgError struct {
msg string msg string
} }
@ -8,14 +10,9 @@ func (m *WgError) Error() string {
return m.msg return m.msg
} }
type CreateInterfaceParams struct {
IfName string
Port int
}
type WgInterfaceManipulator interface { type WgInterfaceManipulator interface {
// CreateInterface creates a WireGuard interface // CreateInterface creates a WireGuard interface
CreateInterface(params *CreateInterfaceParams) error CreateInterface(port int, privateKey *wgtypes.Key) (string, error)
// AddAddress adds an address to the given interface name // AddAddress adds an address to the given interface name
AddAddress(ifName string, addr string) error AddAddress(ifName string, addr string) error
// RemoveInterface removes the specified interface // RemoveInterface removes the specified interface

View File

@ -1,6 +1,8 @@
package wg package wg
import ( import (
"crypto"
"crypto/rand"
"fmt" "fmt"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/lib"
@ -13,40 +15,47 @@ type WgInterfaceManipulatorImpl struct {
client *wgctrl.Client client *wgctrl.Client
} }
const hashLength = 6
// CreateInterface creates a WireGuard interface // CreateInterface creates a WireGuard interface
func (m *WgInterfaceManipulatorImpl) CreateInterface(params *CreateInterfaceParams) error { func (m *WgInterfaceManipulatorImpl) CreateInterface(port int, privKey *wgtypes.Key) (string, error) {
rtnl, err := lib.NewRtNetlinkConfig() rtnl, err := lib.NewRtNetlinkConfig()
if err != nil { if err != nil {
return fmt.Errorf("failed to access link: %w", err) return "", fmt.Errorf("failed to access link: %w", err)
} }
defer rtnl.Close() defer rtnl.Close()
err = rtnl.CreateLink(params.IfName) randomBuf := make([]byte, 32)
_, err = rand.Read(randomBuf)
if err != nil { if err != nil {
return fmt.Errorf("failed to create link: %w", err) return "", err
} }
privateKey, err := wgtypes.GeneratePrivateKey() md5 := crypto.MD5.New().Sum(randomBuf)
md5Str := fmt.Sprintf("wg%x", md5)[:hashLength]
err = rtnl.CreateLink(md5Str)
if err != nil { if err != nil {
return fmt.Errorf("failed to create private key: %w", err) return "", fmt.Errorf("failed to create link: %w", err)
} }
var cfg wgtypes.Config = wgtypes.Config{ var cfg wgtypes.Config = wgtypes.Config{
PrivateKey: &privateKey, PrivateKey: privKey,
ListenPort: &params.Port, ListenPort: &port,
} }
err = m.client.ConfigureDevice(params.IfName, cfg) err = m.client.ConfigureDevice(md5Str, cfg)
if err != nil { if err != nil {
return fmt.Errorf("failed to configure dev: %w", err) m.RemoveInterface(md5Str)
return "", fmt.Errorf("failed to configure dev: %w", err)
} }
logging.Log.WriteInfof("ip link set up dev %s type wireguard", params.IfName) logging.Log.WriteInfof("ip link set up dev %s type wireguard", md5Str)
return nil return md5Str, nil
} }
// Add an address to the given interface // Add an address to the given interface

View File

@ -0,0 +1,92 @@
// Package to convert an IPV6 addres into 8 words
package what8words
import (
"bufio"
"bytes"
"fmt"
"net"
"os"
"strings"
)
type What8Words struct {
words []string
}
// Convert implements What8Words.
func (w *What8Words) Convert(ipStr string) (string, error) {
ip, ipNet, err := net.ParseCIDR(ipStr)
if err != nil {
return "", err
}
ip16 := ip.To16()
if ip16 == nil {
return "", fmt.Errorf("cannot convert ip to 16 representation")
}
representation := make([]string, 7)
for i := 2; i <= net.IPv6len-2; i += 2 {
word1 := w.words[ip16[i]]
word2 := w.words[ip16[i+1]]
representation[i/2-1] = fmt.Sprintf("%s-%s", word1, word2)
}
prefixSize, _ := ipNet.Mask.Size()
return strings.Join(representation[:prefixSize/16-1], "."), nil
}
// Convert implements What8Words.
func (w *What8Words) ConvertIdentifier(ipStr string) (string, error) {
ip, err := w.Convert(ipStr)
if err != nil {
return "", err
}
constituents := strings.Split(ip, ".")
return strings.Join(constituents[3:], "."), nil
}
func NewWhat8Words(pathToWords string) (*What8Words, error) {
words, err := ReadWords(pathToWords)
if err != nil {
return nil, err
}
return &What8Words{words: words}, nil
}
// ReadWords reads the what 8 words txt file
func ReadWords(wordFile string) ([]string, error) {
f, err := os.ReadFile(wordFile)
if err != nil {
return nil, err
}
words := make([]string, 257)
reader := bufio.NewScanner(bytes.NewReader(f))
counter := 0
for reader.Scan() && counter <= len(words) {
text := reader.Text()
words[counter] = text
counter++
if reader.Err() != nil {
return nil, reader.Err()
}
}
return words, nil
}