1
0
forked from extern/smegmesh

Compare commits

...

37 Commits

Author SHA1 Message Date
dae9cd31a1 Merge pull request #50 from tim-beatham/50-give-client-ability-to-bridge-meshes
50-give-client-ability-to-bridge-meshes
2023-12-08 23:58:32 +00:00
f855f53fbf 50-give-client-ability-to-bridge-meshes
Client can act as a route bridging meshes. Cient send keepalives
to all of it's peers in the different meshes act as a bridge between
the meshes
2023-12-08 23:56:07 +00:00
52feb5767b Merge pull request #48 from tim-beatham/47-default-routing
47 default routing
2023-12-08 20:03:45 +00:00
815c4484ee 47-default-routing
Implemented default routing and improved size of gossip. Using 64 bit
hash funciton to identify vector.
2023-12-08 20:02:57 +00:00
0058c9f4c9 47-default-routing
Implementing default routing so that all traffic goes out of an
exit point.
2023-12-08 11:49:24 +00:00
92c0805275 Merge pull request #46 from tim-beatham/45-use-statistical-testing
45 use statistical testing
2023-12-07 18:20:25 +00:00
661fb0d54c 45-use-statistical-testing
Keepalive is based on per mesh and not per node.
Using total ordering mechanism similar to paxos to elect a leader
if leader doesn't update it's timestamp within 3 * keepAlive then
give the leader a gravestone and elect the next leader.
Leader is bassed on lexicographically ordered public key.
2023-12-07 18:18:13 +00:00
64885f1055 45-use-statistical-testing
Using statistical testing to test whether the node has failed.
2023-12-07 01:44:54 +00:00
2169f7796f Merge pull request #44 from tim-beatham/43-gravestones
43-use-gravestones
2023-12-06 22:46:05 +00:00
a3ceff019d 43-use-gravestones
Change of approach from keepalive to a noiseless protocol
2023-12-06 22:45:04 +00:00
b78d96986c Merge pull request #42 from tim-beatham/41-bugfix-fluctuating-ips
41 bugfix fluctuating ips
2023-12-06 14:37:14 +00:00
1b18d89c9f 41-bugfix-fluctuating-ips
Fluctuating ips creating hub and spoke.
2023-12-05 02:00:16 +00:00
245a2c5f58 41-bugfix-fluctuating-ips
If the node is a peer then add the client in the WG
configuration.
2023-12-04 17:40:24 +00:00
c40f7510b8 41-bugfix-fluctuating-ips
IPs of clients fluctuating because there isn't a strict order on
clients. Client's need to be processed before the peers.
2023-12-04 17:32:50 +00:00
78d748770c BUGIX Hash client by public key 2023-12-04 17:13:51 +00:00
0ff2a8eef9 BUGFIX: Allowed IPs fluctuating 2023-12-04 17:11:37 +00:00
fd7bd80485 BUGFIX
Don't get device each time it is an expensive operation.
2023-12-04 16:40:15 +00:00
3ef1b68ba5 BUGFIX: Hashing datastore to work out changes
Changed hashing implementation to work out if there are changes
in the data store
2023-11-30 15:58:26 +00:00
b9ba836ae3 Merge pull request #40 from tim-beatham/39-implement-two-phase-map
39-implement-two-phase-map
2023-11-30 02:03:36 +00:00
650901aba1 39-implement-two-phase-map
Implemented my own two phase map based on vector clocks
2023-11-30 02:02:38 +00:00
a82eab0686 Bugfix
Added replace peers so that deleted nodes are automatically removed
2023-11-28 14:43:55 +00:00
32e7e4c7df main
Bugfix. Fixed issue where consistent hashing was not working.
2023-11-28 14:42:09 +00:00
1fae0a6c2c Merge pull request #37 from tim-beatham/36-add-route-path-into-route-object
36-add-route-path-into-route-object
2023-11-27 21:03:56 +00:00
d8e156f13f 36-add-route-path-into-route-object
Added the route path into the route object so that we can
see what meshes packets are routed across.
2023-11-27 18:55:41 +00:00
3fca49a1c9 Merge pull request #35 from tim-beatham/34-fix-routing
34 fix routing
2023-11-27 16:05:06 +00:00
a2517a1e72 34-fix-routing
- Added mesh-to-mesh routing of hop count > 1
- If there is a tie-breaker with respect to the hop-count use consistent
hashing to determine the route to take based on the public key.
2023-11-27 15:56:30 +00:00
aef8b59f22 32-fix-routing
Flooding routes into other meshes a bit like BGP.
2023-11-25 03:15:58 +00:00
4030d17b41 Fixed routing issue 2023-11-24 17:49:06 +00:00
73db65660b Merge pull request #33 from tim-beatham/32-incorporate-dns
32-incorporate-dns
2023-11-24 15:05:40 +00:00
d1a74a7b95 32-incorporate-dns
Incorporated a DNS server. A DNS server can be run to resolve host
names.
2023-11-24 15:04:07 +00:00
f28ed8260d Merge pull request #30 from tim-beatham/29-only-ping-clients-who-have-updated-their-config
29-only-ping-clients-who-have-updated-their-config
2023-11-24 12:39:14 +00:00
2c406718df 29-only-ping-clients-who-have-updated-their-config
Only consider clients who have updated their config when synchronising
with peers. Consider a dead time where we don't have a handshake and
a prune time when we remove them from the WireGuard configuration.
2023-11-24 12:37:54 +00:00
11b003b549 Merge pull request #28 from tim-beatham/27-remove-client-grpc-endpoint
27-remove-client-grpc-endpoint
2023-11-24 12:08:42 +00:00
7be11dbaa3 27-remove-client-grpc-endpoint
Removed a client's grpc endpoint value. Client's aren't publicly
available so there is no need for a client's gRPC endpoint.
Also changed a node ID's to their public key. A node id's public
address is an issue for mobility of clients as their endpoint
is subject to change
2023-11-24 12:07:03 +00:00
e7ac8c5542 Only updating WireGuard config if node exists 2023-11-22 13:08:02 +00:00
09c64c4628 Fixed container file 2023-11-22 12:45:01 +00:00
2c4f18f52b Merge pull request #26 from tim-beatham/25-modify-code-to-use-public-api
25-modify-code-to-use-public-api
2023-11-22 10:42:48 +00:00
50 changed files with 2525 additions and 842 deletions

View File

@ -8,4 +8,5 @@ RUN apt-get update && apt-get install -y \
tmux \
vim
WORKDIR /wgmesh
RUN go mod tidy
RUN go build -o /usr/local/bin ./...

18
cmd/dns/main.go Normal file
View File

@ -0,0 +1,18 @@
package main
import (
"log"
smegdns "github.com/tim-beatham/wgmesh/pkg/dns"
)
func main() {
server, err := smegdns.NewDns(53)
if err != nil {
log.Fatal(err.Error())
}
defer server.Close()
server.Listen()
}

View File

@ -8,7 +8,9 @@ import (
"time"
"github.com/akamensky/argparse"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
)
@ -93,9 +95,13 @@ func getMesh(client *ipcRpc.Client, meshId string) {
fmt.Println("Control Endpoint: " + node.HostEndpoint)
fmt.Println("WireGuard Endpoint: " + node.WgEndpoint)
fmt.Println("Wg IP: " + node.WgHost)
fmt.Println(fmt.Sprintf("Timestamp: %s", time.Unix(node.Timestamp, 0).String()))
fmt.Printf("Timestamp: %s", time.Unix(node.Timestamp, 0).String())
advertiseRoutes := strings.Join(node.Routes, ",")
mapFunc := func(r ctrlserver.MeshRoute) string {
return r.Destination
}
advertiseRoutes := strings.Join(lib.Map(node.Routes, mapFunc), ",")
fmt.Printf("Routes: %s\n", advertiseRoutes)
fmt.Println("---")

View File

@ -13,7 +13,7 @@ import (
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/robin"
"github.com/tim-beatham/wgmesh/pkg/sync"
"github.com/tim-beatham/wgmesh/pkg/timestamp"
timer "github.com/tim-beatham/wgmesh/pkg/timers"
"golang.zx2c4.com/wireguard/wgctrl"
)
@ -45,20 +45,25 @@ func main() {
var robinRpc robin.WgRpc
var robinIpc robin.IpcHandler
var syncProvider sync.SyncServiceImpl
var syncRequester sync.SyncRequester
var syncer sync.Syncer
ctrlServerParams := ctrlserver.NewCtrlServerParams{
Conf: conf,
CtrlProvider: &robinRpc,
SyncProvider: &syncProvider,
Client: client,
OnDelete: func(mp mesh.MeshProvider) {
syncer.SyncMeshes()
},
}
ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams)
syncProvider.Server = ctrlServer
syncRequester := sync.NewSyncRequester(ctrlServer)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester)
timestampScheduler := timestamp.NewTimestampScheduler(ctrlServer)
pruneScheduler := mesh.NewPruner(ctrlServer.MeshManager, *conf)
syncRequester = sync.NewSyncRequester(ctrlServer)
syncer = sync.NewSyncer(ctrlServer.MeshManager, conf, syncRequester)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, syncer)
keepAlive := timer.NewTimestampScheduler(ctrlServer)
robinIpcParams := robin.RobinIpcParams{
CtrlServer: ctrlServer,
@ -76,13 +81,12 @@ func main() {
go ipc.RunIpcHandler(&robinIpc)
go syncScheduler.Run()
go timestampScheduler.Run()
go pruneScheduler.Run()
go keepAlive.Run()
closeResources := func() {
logging.Log.WriteInfof("Closing resources")
syncScheduler.Stop()
timestampScheduler.Stop()
keepAlive.Stop()
ctrlServer.Close()
client.Close()
}

View File

@ -30,15 +30,14 @@ func (s *SmegServer) routeToApiRoute(meshNode ctrlserver.MeshNode) []Route {
routes := make([]Route, len(meshNode.Routes))
for index, route := range meshNode.Routes {
word, err := s.words.Convert(route)
if err != nil {
fmt.Println(err.Error())
if route.Path == nil {
route.Path = make([]string, 0)
}
routes[index] = Route{
Prefix: route,
RouteId: word,
Prefix: route.Destination,
Path: route.Path,
}
}
@ -47,7 +46,7 @@ func (s *SmegServer) routeToApiRoute(meshNode ctrlserver.MeshNode) []Route {
func (s *SmegServer) meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNode {
if meshNode.Routes == nil {
meshNode.Routes = make([]string, 0)
meshNode.Routes = make([]ctrlserver.MeshRoute, 0)
}
alias := meshNode.Alias

View File

@ -1,8 +1,8 @@
package api
type Route struct {
RouteId string `json:"routeId"`
Prefix string `json:"prefix"`
Prefix string `json:"prefix"`
Path []string `json:"path"`
}
type SmegNode struct {

View File

@ -1,9 +1,10 @@
package crdt
package automerge
import (
"errors"
"fmt"
"net"
"slices"
"strings"
"time"
@ -35,11 +36,11 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
panic("node must be of type *MeshNodeCrdt")
}
crdt.Routes = make(map[string]interface{})
crdt.Routes = make(map[string]Route)
crdt.Services = make(map[string]string)
crdt.Timestamp = time.Now().Unix()
c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt)
c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt)
}
func (c *CrdtMeshManager) isPeer(nodeId string) bool {
@ -58,11 +59,30 @@ func (c *CrdtMeshManager) isPeer(nodeId string) bool {
return nodeType.Str() == string(conf.PEER_ROLE)
}
// isAlive: checks that the node's configuration has been updated
// since the rquired keep alive time
func (c *CrdtMeshManager) isAlive(nodeId string) bool {
node, err := c.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return false
}
timestamp, err := node.Map().Get("timestamp")
if err != nil || timestamp.Kind() != automerge.KindInt64 {
return false
}
keepAliveTime := timestamp.Int64()
return (time.Now().Unix() - keepAliveTime) < int64(c.conf.DeadTime)
}
func (c *CrdtMeshManager) GetPeers() []string {
keys, _ := c.doc.Path("nodes").Map().Keys()
keys = lib.Filter(keys, func(s string) bool {
return c.isPeer(s)
keys = lib.Filter(keys, func(publicKey string) bool {
return c.isPeer(publicKey) && c.isAlive(publicKey)
})
return keys
@ -270,7 +290,8 @@ func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
return fmt.Errorf("AddService: services property does not exist in node")
}
return service.Map().Set(key, value)
err = service.Map().Set(key, value)
return err
}
func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
@ -300,7 +321,7 @@ func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
}
// AddRoutes: adds routes to the specific nodeId
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
logging.Log.WriteInfof("Adding route to %s", nodeId)
@ -319,7 +340,32 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
}
for _, route := range routes {
err = routeMap.Map().Set(route, struct{}{})
prevRoute, err := routeMap.Map().Get(route.GetDestination().String())
if prevRoute.Kind() == automerge.KindVoid && err != nil {
path, err := prevRoute.Map().Get("path")
if err != nil {
return err
}
if path.Kind() != automerge.KindList {
return fmt.Errorf("path is not a list")
}
pathStr, err := automerge.As[[]string](path)
if err != nil {
return err
}
slices.Equal(route.GetPath(), pathStr)
}
err = routeMap.Map().Set(route.GetDestination().String(), Route{
Destination: route.GetDestination().String(),
Path: route.GetPath(),
})
if err != nil {
return err
@ -328,6 +374,80 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
return nil
}
func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
return nil, err
}
if nodeVal.Kind() != automerge.KindMap {
return nil, fmt.Errorf("node does not exist")
}
routeMap, err := nodeVal.Map().Get("routes")
if err != nil {
return nil, err
}
if routeMap.Kind() != automerge.KindMap {
return nil, fmt.Errorf("node %s is not a map", nodeId)
}
routes, err := automerge.As[map[string]Route](routeMap)
return lib.MapValues(routes), err
}
func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, error) {
node, err := m.GetNode(targetNode)
if err != nil {
return nil, err
}
routes := make(map[string]mesh.Route)
// Add routes that the node directly has
for _, route := range node.GetRoutes() {
routes[route.GetDestination().String()] = route
}
// Work out the other routes in the mesh
for _, node := range m.GetPeers() {
nodeRoutes, err := m.getRoutes(node)
if err != nil {
return nil, err
}
for _, route := range nodeRoutes {
otherRoute, ok := routes[route.GetDestination().String()]
hopCount := route.GetHopCount()
if node != targetNode {
hopCount += 1
}
if !ok || route.GetHopCount()+1 < otherRoute.GetHopCount() {
routes[route.GetDestination().String()] = &Route{
Destination: route.GetDestination().String(),
Path: append(route.Path, m.GetMeshId()),
}
}
}
}
return routes, nil
}
func (m *CrdtMeshManager) RemoveNode(nodeId string) error {
err := m.doc.Path("nodes").Map().Delete(nodeId)
return err
}
// DeleteRoutes deletes the specified routes
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -357,54 +477,54 @@ func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer {
return NewAutomergeSync(m)
}
func (m *CrdtMeshManager) Prune(pruneTime int) error {
nodes, err := m.doc.Path("nodes").Get()
func (m *CrdtMeshManager) Prune() error {
// nodes, err := m.doc.Path("nodes").Get()
if err != nil {
return err
}
// if err != nil {
// return err
// }
if nodes.Kind() != automerge.KindMap {
return errors.New("node must be a map")
}
// if nodes.Kind() != automerge.KindMap {
// return errors.New("node must be a map")
// }
values, err := nodes.Map().Values()
// values, err := nodes.Map().Values()
if err != nil {
return err
}
// if err != nil {
// return err
// }
deletionNodes := make([]string, 0)
// deletionNodes := make([]string, 0)
for nodeId, node := range values {
if node.Kind() != automerge.KindMap {
return errors.New("node must be a map")
}
// for nodeId, node := range values {
// if node.Kind() != automerge.KindMap {
// return errors.New("node must be a map")
// }
nodeMap := node.Map()
// nodeMap := node.Map()
timeStamp, err := nodeMap.Get("timestamp")
// timeStamp, err := nodeMap.Get("timestamp")
if err != nil {
return err
}
// if err != nil {
// return err
// }
if timeStamp.Kind() != automerge.KindInt64 {
return errors.New("timestamp is not int64")
}
// if timeStamp.Kind() != automerge.KindInt64 {
// return errors.New("timestamp is not int64")
// }
timeValue := timeStamp.Int64()
nowValue := time.Now().Unix()
// timeValue := timeStamp.Int64()
// nowValue := time.Now().Unix()
if nowValue-timeValue >= int64(pruneTime) {
deletionNodes = append(deletionNodes, nodeId)
}
}
// if nowValue-timeValue >= int64(pruneTime) {
// deletionNodes = append(deletionNodes, nodeId)
// }
// }
for _, node := range deletionNodes {
logging.Log.WriteInfof("Pruning %s", node)
nodes.Map().Delete(node)
}
// for _, node := range deletionNodes {
// logging.Log.WriteInfof("Pruning %s", node)
// nodes.Map().Delete(node)
// }
return nil
}
@ -429,7 +549,6 @@ func (m *MeshNodeCrdt) GetWgHost() *net.IPNet {
_, ipnet, err := net.ParseCIDR(m.WgHost)
if err != nil {
logging.Log.WriteErrorf("Cannot parse WgHost %s", err.Error())
return nil
}
@ -440,8 +559,13 @@ func (m *MeshNodeCrdt) GetTimeStamp() int64 {
return m.Timestamp
}
func (m *MeshNodeCrdt) GetRoutes() []string {
return lib.MapKeys(m.Routes)
func (m *MeshNodeCrdt) GetRoutes() []mesh.Route {
return lib.Map(lib.MapValues(m.Routes), func(r Route) mesh.Route {
return &Route{
Destination: r.Destination,
Path: r.Path,
}
})
}
func (m *MeshNodeCrdt) GetDescription() string {
@ -452,7 +576,6 @@ func (m *MeshNodeCrdt) GetIdentifier() string {
ipv6 := m.WgHost[:len(m.WgHost)-4]
constituents := strings.Split(ipv6, ":")
logging.Log.WriteInfof(ipv6)
constituents = constituents[4:]
return strings.Join(constituents, ":")
}
@ -497,3 +620,16 @@ func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
return nodes
}
func (r *Route) GetDestination() *net.IPNet {
_, ipnet, _ := net.ParseCIDR(r.Destination)
return ipnet
}
func (r *Route) GetHopCount() int {
return len(r.Path)
}
func (r *Route) GetPath() []string {
return r.Path
}

View File

@ -1,4 +1,4 @@
package crdt
package automerge
import (
"github.com/automerge/automerge-go"

View File

@ -1,4 +1,4 @@
package crdt
package automerge
import (
"slices"

View File

@ -1,4 +1,4 @@
package crdt
package automerge
import (
"fmt"
@ -28,17 +28,23 @@ type MeshNodeFactory struct {
func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode {
hostName := f.getAddress(params)
grpcEndpoint := fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort)
if f.Config.Role == conf.CLIENT_ROLE {
grpcEndpoint = "-"
}
return &MeshNodeCrdt{
HostEndpoint: fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort),
HostEndpoint: grpcEndpoint,
PublicKey: params.PublicKey.String(),
WgEndpoint: fmt.Sprintf("%s:%d", hostName, params.WgPort),
WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()),
// Always set the routes as empty.
// Routes handled by external component
Routes: map[string]interface{}{},
Routes: make(map[string]Route),
Description: "",
Alias: "",
Type: string(params.Role),
Type: string(f.Config.Role),
}
}
@ -51,7 +57,13 @@ func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string
} else if len(f.Config.Endpoint) != 0 {
hostName = f.Config.Endpoint
} else {
ip, err := lib.GetPublicIP()
ipFunc := lib.GetPublicIP
if f.Config.IPDiscovery == conf.DNS_IP_DISCOVERY {
ipFunc = lib.GetOutboundIP
}
ip, err := ipFunc()
if err != nil {
return ""

View File

@ -1,17 +1,23 @@
package crdt
package automerge
// Route: Represents a CRDT of the given route
type Route struct {
Destination string `automerge:"destination"`
Path []string `automerge:"path"`
}
// MeshNodeCrdt: Represents a CRDT for a mesh nodes
type MeshNodeCrdt struct {
HostEndpoint string `automerge:"hostEndpoint"`
WgEndpoint string `automerge:"wgEndpoint"`
PublicKey string `automerge:"publicKey"`
WgHost string `automerge:"wgHost"`
Timestamp int64 `automerge:"timestamp"`
Routes map[string]interface{} `automerge:"routes"`
Alias string `automerge:"alias"`
Description string `automerge:"description"`
Services map[string]string `automerge:"services"`
Type string `automerge:"type"`
HostEndpoint string `automerge:"hostEndpoint"`
WgEndpoint string `automerge:"wgEndpoint"`
PublicKey string `automerge:"publicKey"`
WgHost string `automerge:"wgHost"`
Timestamp int64 `automerge:"timestamp"`
Routes map[string]Route `automerge:"routes"`
Alias string `automerge:"alias"`
Description string `automerge:"description"`
Services map[string]string `automerge:"services"`
Type string `automerge:"type"`
}
// MeshCrdt: Represents the mesh network as a whole

View File

@ -23,6 +23,13 @@ const (
CLIENT_ROLE NodeType = "client"
)
type IPDiscovery string
const (
PUBLIC_IP_DISCOVERY = "public"
DNS_IP_DISCOVERY = "dns"
)
type WgMeshConfiguration struct {
// CertificatePath is the path to the certificate to use in mTLS
CertificatePath string `yaml:"certificatePath"`
@ -35,8 +42,13 @@ type WgMeshConfiguration struct {
SkipCertVerification bool `yaml:"skipCertVerification"`
// Port to run the GrpcServer on
GrpcPort string `yaml:"gRPCPort"`
// IPDIscovery: how to discover your IP if not specified. Use DNS server 8.8.8.8 or
// use public IP discovery library
IPDiscovery IPDiscovery `yaml:"ipDiscovery"`
// AdvertiseRoutes advertises other meshes if the node is in multiple meshes
AdvertiseRoutes bool `yaml:"advertiseRoutes"`
// AdvertiseDefaultRoute advertises a default route out of the mesh.
AdvertiseDefaultRoute bool `yaml:"advertiseDefaults"`
// Endpoint is the IP in which this computer is publicly reachable.
// usecase is when the node has multiple IP addresses
Endpoint string `yaml:"publicEndpoint"`
@ -54,8 +66,11 @@ type WgMeshConfiguration struct {
KeepAliveTime int `yaml:"keepAliveTime"`
// Timeout number of seconds before we consider the node as dead
Timeout int `yaml:"timeout"`
// PruneTime number of seconds before we consider the 'node' as dead
// PruneTime number of seconds before we remove nodes that are likely to be dead
PruneTime int `yaml:"pruneTime"`
// DeadTime: number of seconds before we consider the node as dead and stop considering it
// when picking a random peer
DeadTime int `yaml:"deadTime"`
// Profile whether or not to include a http server that profiles the code
Profile bool `yaml:"profile"`
// StubWg whether or not to stub the WireGuard types
@ -135,9 +150,15 @@ func ValidateConfiguration(c *WgMeshConfiguration) error {
}
}
if c.PruneTime <= 1 {
if c.PruneTime < 1 {
return &WgMeshConfigurationError{
msg: "Prune time cannot be <= 1",
msg: "Prune time cannot be < 1",
}
}
if c.DeadTime < 1 {
return &WgMeshConfigurationError{
msg: "Dead time cannot be < 1",
}
}
@ -151,6 +172,10 @@ func ValidateConfiguration(c *WgMeshConfiguration) error {
c.Role = PEER_ROLE
}
if c.IPDiscovery == "" {
c.IPDiscovery = PUBLIC_IP_DISCOVERY
}
return nil
}

501
pkg/crdt/datastore.go Normal file
View File

@ -0,0 +1,501 @@
package crdt
import (
"bytes"
"encoding/gob"
"fmt"
"net"
"slices"
"strings"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type Route struct {
Destination string
Path []string
}
// GetDestination implements mesh.Route.
func (r *Route) GetDestination() *net.IPNet {
_, ipnet, _ := net.ParseCIDR(r.Destination)
return ipnet
}
// GetHopCount implements mesh.Route.
func (r *Route) GetHopCount() int {
return len(r.Path)
}
// GetPath implements mesh.Route.
func (r *Route) GetPath() []string {
return r.Path
}
type MeshNode struct {
HostEndpoint string
WgEndpoint string
PublicKey string
WgHost string
Timestamp int64
Routes map[string]Route
Alias string
Description string
Services map[string]string
Type string
Tombstone bool
}
// Mark: marks the node is unreachable. This is not broadcast on
// syncrhonisation
func (m *TwoPhaseStoreMeshManager) Mark(nodeId string) {
m.store.Mark(nodeId)
}
// GetHostEndpoint: gets the gRPC endpoint of the node
func (n *MeshNode) GetHostEndpoint() string {
return n.HostEndpoint
}
// GetPublicKey: gets the public key of the node
func (n *MeshNode) GetPublicKey() (wgtypes.Key, error) {
return wgtypes.ParseKey(n.PublicKey)
}
// GetWgEndpoint(): get IP and port of the wireguard endpoint
func (n *MeshNode) GetWgEndpoint() string {
return n.WgEndpoint
}
// GetWgHost: get the IP address of the WireGuard node
func (n *MeshNode) GetWgHost() *net.IPNet {
_, ipnet, _ := net.ParseCIDR(n.WgHost)
return ipnet
}
// GetTimestamp: get the UNIX time stamp of the ndoe
func (n *MeshNode) GetTimeStamp() int64 {
return n.Timestamp
}
// GetRoutes: returns the routes that the nodes provides
func (n *MeshNode) GetRoutes() []mesh.Route {
routes := make([]mesh.Route, len(n.Routes))
for index, route := range lib.MapValues(n.Routes) {
routes[index] = &Route{
Destination: route.Destination,
Path: route.Path,
}
}
return routes
}
// GetIdentifier: returns the identifier of the node
func (m *MeshNode) GetIdentifier() string {
ipv6 := m.WgHost[:len(m.WgHost)-4]
constituents := strings.Split(ipv6, ":")
constituents = constituents[4:]
return strings.Join(constituents, ":")
}
// GetDescription: returns the description for this node
func (n *MeshNode) GetDescription() string {
return n.Description
}
// GetAlias: associates the node with an alias. Potentially used
// for DNS and so forth.
func (n *MeshNode) GetAlias() string {
return n.Alias
}
// GetServices: returns a list of services offered by the node
func (n *MeshNode) GetServices() map[string]string {
return n.Services
}
func (n *MeshNode) GetType() conf.NodeType {
return conf.NodeType(n.Type)
}
type MeshSnapshot struct {
Nodes map[string]MeshNode
}
// GetNodes() returns the nodes in the mesh
func (m *MeshSnapshot) GetNodes() map[string]mesh.MeshNode {
newMap := make(map[string]mesh.MeshNode)
for key, value := range m.Nodes {
newMap[key] = &MeshNode{
HostEndpoint: value.HostEndpoint,
PublicKey: value.PublicKey,
WgHost: value.WgHost,
WgEndpoint: value.WgEndpoint,
Timestamp: value.Timestamp,
Routes: value.Routes,
Alias: value.Alias,
Description: value.Description,
Services: value.Services,
Type: value.Type,
}
}
return newMap
}
type TwoPhaseStoreMeshManager struct {
MeshId string
IfName string
Client *wgctrl.Client
LastClock uint64
conf *conf.WgMeshConfiguration
store *TwoPhaseMap[string, MeshNode]
}
// AddNode() adds a node to the mesh
func (m *TwoPhaseStoreMeshManager) AddNode(node mesh.MeshNode) {
crdt, ok := node.(*MeshNode)
if !ok {
panic("node must be of type mesh node")
}
crdt.Routes = make(map[string]Route)
crdt.Services = make(map[string]string)
crdt.Timestamp = time.Now().Unix()
m.store.Put(crdt.PublicKey, *crdt)
}
// GetMesh() returns a snapshot of the mesh provided by the mesh provider.
func (m *TwoPhaseStoreMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
nodes := m.store.AsList()
snapshot := make(map[string]MeshNode)
for _, node := range nodes {
snapshot[node.PublicKey] = node
}
return &MeshSnapshot{
Nodes: snapshot,
}, nil
}
// GetMeshId() returns the ID of the mesh network
func (m *TwoPhaseStoreMeshManager) GetMeshId() string {
return m.MeshId
}
// Save() saves the mesh network
func (m *TwoPhaseStoreMeshManager) Save() []byte {
snapshot := m.store.Snapshot()
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
err := enc.Encode(*snapshot)
if err != nil {
logging.Log.WriteInfof(err.Error())
}
return buf.Bytes()
}
// Load() loads a mesh network
func (m *TwoPhaseStoreMeshManager) Load(bs []byte) error {
buf := bytes.NewBuffer(bs)
dec := gob.NewDecoder(buf)
var snapshot TwoPhaseMapSnapshot[string, MeshNode]
err := dec.Decode(&snapshot)
m.store.Merge(snapshot)
return err
}
// GetDevice() get the device corresponding with the mesh
func (m *TwoPhaseStoreMeshManager) GetDevice() (*wgtypes.Device, error) {
dev, err := m.Client.Device(m.IfName)
if err != nil {
return nil, err
}
return dev, nil
}
// HasChanges returns true if we have changes since last time we synced
func (m *TwoPhaseStoreMeshManager) HasChanges() bool {
clockValue := m.store.GetHash()
return clockValue != m.LastClock
}
// Record that we have changes and save the corresponding changes
func (m *TwoPhaseStoreMeshManager) SaveChanges() {
clockValue := m.store.GetHash()
m.LastClock = clockValue
}
// UpdateTimeStamp: update the timestamp of the given node
func (m *TwoPhaseStoreMeshManager) UpdateTimeStamp(nodeId string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
// Sort nodes by their public key
peers := m.GetPeers()
slices.Sort(peers)
if len(peers) == 0 {
return nil
}
peerToUpdate := peers[0]
if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.conf.KeepAliveTime) {
m.store.Mark(peerToUpdate)
if len(peers) < 2 {
return nil
}
peerToUpdate = peers[1]
}
if peerToUpdate != nodeId {
return nil
}
// Refresh causing node to update it's time stamp
node := m.store.Get(nodeId)
node.Timestamp = time.Now().Unix()
m.store.Put(nodeId, node)
return nil
}
// AddRoutes: adds routes to the given node
func (m *TwoPhaseStoreMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
if len(routes) == 0 {
return nil
}
node := m.store.Get(nodeId)
changes := false
for _, route := range routes {
prevRoute, ok := node.Routes[route.GetDestination().String()]
if !ok || route.GetHopCount() < prevRoute.GetHopCount() {
changes = true
node.Routes[route.GetDestination().String()] = Route{
Destination: route.GetDestination().String(),
Path: route.GetPath(),
}
}
}
if changes {
m.store.Put(nodeId, node)
}
return nil
}
// DeleteRoutes: deletes the routes from the node
func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
if len(routes) == 0 {
return nil
}
node := m.store.Get(nodeId)
for _, route := range routes {
delete(node.Routes, route)
}
return nil
}
// GetSyncer: returns the automerge syncer for sync
func (m *TwoPhaseStoreMeshManager) GetSyncer() mesh.MeshSyncer {
return NewTwoPhaseSyncer(m)
}
// GetNode get a particular not within the mesh
func (m *TwoPhaseStoreMeshManager) GetNode(nodeId string) (mesh.MeshNode, error) {
if !m.store.Contains(nodeId) {
return nil, fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
node := m.store.Get(nodeId)
return &node, nil
}
// NodeExists: returns true if a particular node exists false otherwise
func (m *TwoPhaseStoreMeshManager) NodeExists(nodeId string) bool {
return m.store.Contains(nodeId)
}
// SetDescription: sets the description of this automerge data type
func (m *TwoPhaseStoreMeshManager) SetDescription(nodeId string, description string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
node := m.store.Get(nodeId)
node.Description = description
m.store.Put(nodeId, node)
return nil
}
// SetAlias: set the alias of the nodeId
func (m *TwoPhaseStoreMeshManager) SetAlias(nodeId string, alias string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
node := m.store.Get(nodeId)
node.Description = alias
m.store.Put(nodeId, node)
return nil
}
// AddService: adds the service to the given node
func (m *TwoPhaseStoreMeshManager) AddService(nodeId string, key string, value string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
node := m.store.Get(nodeId)
node.Services[key] = value
m.store.Put(nodeId, node)
return nil
}
// RemoveService: removes the service form the node. throws an error if the service does not exist
func (m *TwoPhaseStoreMeshManager) RemoveService(nodeId string, key string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
node := m.store.Get(nodeId)
delete(node.Services, key)
m.store.Put(nodeId, node)
return nil
}
// Prune: prunes all nodes that have not updated their timestamp in
func (m *TwoPhaseStoreMeshManager) Prune() error {
m.store.Prune()
return nil
}
// GetPeers: get a list of contactable peers
func (m *TwoPhaseStoreMeshManager) GetPeers() []string {
nodes := m.store.AsList()
nodes = lib.Filter(nodes, func(mn MeshNode) bool {
if mn.Type != string(conf.PEER_ROLE) {
return false
}
// If the node is marked as unreachable don't consider it a peer.
// this help to optimize convergence time for unreachable nodes.
// However advertising it to other nodes could result in flapping.
if m.store.IsMarked(mn.PublicKey) {
return false
}
return true
})
return lib.Map(nodes, func(mn MeshNode) string {
return mn.PublicKey
})
}
func (m *TwoPhaseStoreMeshManager) getRoutes(targetNode string) (map[string]Route, error) {
if !m.store.Contains(targetNode) {
return nil, fmt.Errorf("getRoute: cannot get route %s does not exist", targetNode)
}
node := m.store.Get(targetNode)
return node.Routes, nil
}
// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen
func (m *TwoPhaseStoreMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, error) {
node, err := m.GetNode(targetNode)
if err != nil {
return nil, err
}
routes := make(map[string]mesh.Route)
// Add routes that the node directly has
for _, route := range node.GetRoutes() {
routes[route.GetDestination().String()] = route
}
// Work out the other routes in the mesh
for _, node := range m.GetPeers() {
nodeRoutes, err := m.getRoutes(node)
if err != nil {
return nil, err
}
for _, route := range nodeRoutes {
otherRoute, ok := routes[route.GetDestination().String()]
hopCount := route.GetHopCount()
if node != targetNode {
hopCount += 1
}
if !ok || route.GetHopCount()+1 < otherRoute.GetHopCount() {
routes[route.GetDestination().String()] = &Route{
Destination: route.GetDestination().String(),
Path: append(route.GetPath(), m.GetMeshId()),
}
}
}
}
return routes, nil
}
// RemoveNode(): remove the node from the mesh
func (m *TwoPhaseStoreMeshManager) RemoveNode(nodeId string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
m.store.Remove(nodeId)
return nil
}

78
pkg/crdt/factory.go Normal file
View File

@ -0,0 +1,78 @@
package crdt
import (
"fmt"
"hash/fnv"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
)
type TwoPhaseMapFactory struct{}
func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) {
return &TwoPhaseStoreMeshManager{
MeshId: params.MeshId,
IfName: params.DevName,
Client: params.Client,
conf: params.Conf,
store: NewTwoPhaseMap[string, MeshNode](params.NodeID, func(s string) uint64 {
h := fnv.New64a()
h.Write([]byte(s))
return h.Sum64()
}, uint64(3*params.Conf.KeepAliveTime)),
}, nil
}
type MeshNodeFactory struct {
Config conf.WgMeshConfiguration
}
func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode {
hostName := f.getAddress(params)
grpcEndpoint := fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort)
if f.Config.Role == conf.CLIENT_ROLE {
grpcEndpoint = "-"
}
return &MeshNode{
HostEndpoint: grpcEndpoint,
PublicKey: params.PublicKey.String(),
WgEndpoint: fmt.Sprintf("%s:%d", hostName, params.WgPort),
WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()),
Routes: make(map[string]Route),
Description: "",
Alias: "",
Type: string(f.Config.Role),
}
}
// getAddress returns the routable address of the machine.
func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string {
var hostName string = ""
if params.Endpoint != "" {
hostName = params.Endpoint
} else if len(f.Config.Endpoint) != 0 {
hostName = f.Config.Endpoint
} else {
ipFunc := lib.GetPublicIP
if f.Config.IPDiscovery == conf.DNS_IP_DISCOVERY {
ipFunc = lib.GetOutboundIP
}
ip, err := ipFunc()
if err != nil {
return ""
}
hostName = ip.String()
}
return hostName
}

176
pkg/crdt/g_map.go Normal file
View File

@ -0,0 +1,176 @@
// crdt is a golang implementation of a crdt
package crdt
import (
"cmp"
"sync"
)
type Bucket[D any] struct {
Vector uint64
Contents D
Gravestone bool
}
// GMap is a set that can only grow in size
type GMap[K cmp.Ordered, D any] struct {
lock sync.RWMutex
contents map[uint64]Bucket[D]
clock *VectorClock[K]
}
func (g *GMap[K, D]) Put(key K, value D) {
g.lock.Lock()
clock := g.clock.IncrementClock()
g.contents[g.clock.hashFunc(key)] = Bucket[D]{
Vector: clock,
Contents: value,
}
g.lock.Unlock()
}
func (g *GMap[K, D]) Contains(key K) bool {
return g.contains(g.clock.hashFunc(key))
}
func (g *GMap[K, D]) contains(key uint64) bool {
g.lock.RLock()
_, ok := g.contents[key]
g.lock.RUnlock()
return ok
}
func (g *GMap[K, D]) put(key uint64, b Bucket[D]) {
g.lock.Lock()
if g.contents[key].Vector < b.Vector {
g.contents[key] = b
}
g.lock.Unlock()
}
func (g *GMap[K, D]) get(key uint64) Bucket[D] {
g.lock.RLock()
bucket := g.contents[key]
g.lock.RUnlock()
return bucket
}
func (g *GMap[K, D]) Get(key K) D {
return g.get(g.clock.hashFunc(key)).Contents
}
func (g *GMap[K, D]) Mark(key K) {
g.lock.Lock()
bucket := g.contents[g.clock.hashFunc(key)]
bucket.Gravestone = true
g.contents[g.clock.hashFunc(key)] = bucket
g.lock.Unlock()
}
// IsMarked: returns true if the node is marked
func (g *GMap[K, D]) IsMarked(key K) bool {
marked := false
g.lock.RLock()
bucket, ok := g.contents[g.clock.hashFunc(key)]
if ok {
marked = bucket.Gravestone
}
g.lock.RUnlock()
return marked
}
func (g *GMap[K, D]) Keys() []uint64 {
g.lock.RLock()
contents := make([]uint64, len(g.contents))
index := 0
for key := range g.contents {
contents[index] = key
index++
}
g.lock.RUnlock()
return contents
}
func (g *GMap[K, D]) Save() map[uint64]Bucket[D] {
buckets := make(map[uint64]Bucket[D])
g.lock.RLock()
for key, value := range g.contents {
buckets[key] = value
}
g.lock.RUnlock()
return buckets
}
func (g *GMap[K, D]) SaveWithKeys(keys []uint64) map[uint64]Bucket[D] {
buckets := make(map[uint64]Bucket[D])
g.lock.RLock()
for _, key := range keys {
buckets[key] = g.contents[key]
}
g.lock.RUnlock()
return buckets
}
func (g *GMap[K, D]) GetClock() map[uint64]uint64 {
clock := make(map[uint64]uint64)
g.lock.RLock()
for key, bucket := range g.contents {
clock[key] = bucket.Vector
}
g.lock.RUnlock()
return clock
}
func (g *GMap[K, D]) GetHash() uint64 {
hash := uint64(0)
g.lock.RLock()
for _, value := range g.contents {
hash += value.Vector
}
g.lock.RUnlock()
return hash
}
func (g *GMap[K, D]) Prune() {
stale := g.clock.getStale()
g.lock.Lock()
for _, outlier := range stale {
delete(g.contents, outlier)
}
g.lock.Unlock()
}
func NewGMap[K cmp.Ordered, D any](clock *VectorClock[K]) *GMap[K, D] {
return &GMap[K, D]{
contents: make(map[uint64]Bucket[D]),
clock: clock,
}
}

211
pkg/crdt/two_phase_map.go Normal file
View File

@ -0,0 +1,211 @@
package crdt
import (
"cmp"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
type TwoPhaseMap[K cmp.Ordered, D any] struct {
addMap *GMap[K, D]
removeMap *GMap[K, bool]
Clock *VectorClock[K]
processId K
}
type TwoPhaseMapSnapshot[K cmp.Ordered, D any] struct {
Add map[uint64]Bucket[D]
Remove map[uint64]Bucket[bool]
}
// Contains checks whether the value exists in the map
func (m *TwoPhaseMap[K, D]) Contains(key K) bool {
return m.contains(m.Clock.hashFunc(key))
}
// Contains checks whether the value exists in the map
func (m *TwoPhaseMap[K, D]) contains(key uint64) bool {
if !m.addMap.contains(key) {
return false
}
addValue := m.addMap.get(key)
if !m.removeMap.contains(key) {
return true
}
removeValue := m.removeMap.get(key)
return addValue.Vector >= removeValue.Vector
}
func (m *TwoPhaseMap[K, D]) Get(key K) D {
var result D
if !m.Contains(key) {
return result
}
return m.addMap.Get(key)
}
func (m *TwoPhaseMap[K, D]) get(key uint64) D {
var result D
if !m.contains(key) {
return result
}
return m.addMap.get(key).Contents
}
// Put places the key K in the map
func (m *TwoPhaseMap[K, D]) Put(key K, data D) {
msgSequence := m.Clock.IncrementClock()
m.Clock.Put(key, msgSequence)
m.addMap.Put(key, data)
}
func (m *TwoPhaseMap[K, D]) Mark(key K) {
m.addMap.Mark(key)
}
// Remove removes the value from the map
func (m *TwoPhaseMap[K, D]) Remove(key K) {
m.removeMap.Put(key, true)
}
func (m *TwoPhaseMap[K, D]) keys() []uint64 {
keys := make([]uint64, 0)
addKeys := m.addMap.Keys()
for _, key := range addKeys {
if !m.contains(key) {
continue
}
keys = append(keys, key)
}
return keys
}
func (m *TwoPhaseMap[K, D]) AsList() []D {
theList := make([]D, 0)
keys := m.keys()
for _, key := range keys {
theList = append(theList, m.get(key))
}
return theList
}
func (m *TwoPhaseMap[K, D]) Snapshot() *TwoPhaseMapSnapshot[K, D] {
return &TwoPhaseMapSnapshot[K, D]{
Add: m.addMap.Save(),
Remove: m.removeMap.Save(),
}
}
func (m *TwoPhaseMap[K, D]) SnapShotFromState(state *TwoPhaseMapState[K]) *TwoPhaseMapSnapshot[K, D] {
addKeys := lib.MapKeys(state.AddContents)
removeKeys := lib.MapKeys(state.RemoveContents)
return &TwoPhaseMapSnapshot[K, D]{
Add: m.addMap.SaveWithKeys(addKeys),
Remove: m.removeMap.SaveWithKeys(removeKeys),
}
}
type TwoPhaseMapState[K cmp.Ordered] struct {
Vectors map[uint64]uint64
AddContents map[uint64]uint64
RemoveContents map[uint64]uint64
}
func (m *TwoPhaseMap[K, D]) IsMarked(key K) bool {
return m.addMap.IsMarked(key)
}
// GetHash: Get the hash of the current state of the map
// Sums the current values of the vectors. Provides good approximation
// of increasing numbers
func (m *TwoPhaseMap[K, D]) GetHash() uint64 {
return (m.addMap.GetHash() + 1) * (m.removeMap.GetHash() + 1)
}
// GetState: get the current vector clock of the add and remove
// map
func (m *TwoPhaseMap[K, D]) GenerateMessage() *TwoPhaseMapState[K] {
addContents := m.addMap.GetClock()
removeContents := m.removeMap.GetClock()
return &TwoPhaseMapState[K]{
Vectors: m.Clock.GetClock(),
AddContents: addContents,
RemoveContents: removeContents,
}
}
func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMapState[K] {
mapState := &TwoPhaseMapState[K]{
AddContents: make(map[uint64]uint64),
RemoveContents: make(map[uint64]uint64),
}
for key, value := range state.AddContents {
otherValue, ok := m.AddContents[key]
if !ok || otherValue < value {
mapState.AddContents[key] = value
}
}
for key, value := range state.RemoveContents {
otherValue, ok := m.RemoveContents[key]
if !ok || otherValue < value {
mapState.RemoveContents[key] = value
}
}
return mapState
}
func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) {
for key, value := range snapshot.Add {
// Gravestone is local only to that node.
// Discover ourselves if the node is alive
m.addMap.put(key, value)
m.Clock.put(key, value.Vector)
}
for key, value := range snapshot.Remove {
m.removeMap.put(key, value)
m.Clock.put(key, value.Vector)
}
}
func (m *TwoPhaseMap[K, D]) Prune() {
m.addMap.Prune()
m.removeMap.Prune()
m.Clock.Prune()
}
// NewTwoPhaseMap: create a new two phase map. Consists of two maps
// a grow map and a remove map. If both timestamps equal then favour keeping
// it in the map
func NewTwoPhaseMap[K cmp.Ordered, D any](processId K, hashKey func(K) uint64, staleTime uint64) *TwoPhaseMap[K, D] {
m := TwoPhaseMap[K, D]{
processId: processId,
Clock: NewVectorClock(processId, hashKey, staleTime),
}
m.addMap = NewGMap[K, D](m.Clock)
m.removeMap = NewGMap[K, bool](m.Clock)
return &m
}

View File

@ -0,0 +1,187 @@
package crdt
import (
"bytes"
"encoding/gob"
logging "github.com/tim-beatham/wgmesh/pkg/log"
)
type SyncState int
const (
HASH SyncState = iota
PREPARE
PRESENT
EXCHANGE
MERGE
FINISHED
)
// TwoPhaseSyncer is a type to sync a TwoPhase data store
type TwoPhaseSyncer struct {
manager *TwoPhaseStoreMeshManager
generateMessageFSM SyncFSM
state SyncState
mapState *TwoPhaseMapState[string]
peerMsg []byte
}
type TwoPhaseHash struct {
Hash uint64
}
type SyncFSM map[SyncState]func(*TwoPhaseSyncer) ([]byte, bool)
func hash(syncer *TwoPhaseSyncer) ([]byte, bool) {
hash := TwoPhaseHash{
Hash: syncer.manager.store.Clock.GetHash(),
}
var buffer bytes.Buffer
enc := gob.NewEncoder(&buffer)
err := enc.Encode(hash)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
syncer.IncrementState()
return buffer.Bytes(), true
}
func prepare(syncer *TwoPhaseSyncer) ([]byte, bool) {
var recvBuffer = bytes.NewBuffer(syncer.peerMsg)
dec := gob.NewDecoder(recvBuffer)
var hash TwoPhaseHash
err := dec.Decode(&hash)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
// If vector clocks are equal then no need to merge state
// Helps to reduce bandwidth by detecting early
if hash.Hash == syncer.manager.store.Clock.GetHash() {
return nil, false
}
var buffer bytes.Buffer
enc := gob.NewEncoder(&buffer)
err = enc.Encode(*syncer.mapState)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
syncer.IncrementState()
return buffer.Bytes(), true
}
func present(syncer *TwoPhaseSyncer) ([]byte, bool) {
if syncer.peerMsg == nil {
panic("peer msg is nil")
}
var recvBuffer = bytes.NewBuffer(syncer.peerMsg)
dec := gob.NewDecoder(recvBuffer)
var mapState TwoPhaseMapState[string]
err := dec.Decode(&mapState)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
difference := syncer.mapState.Difference(&mapState)
syncer.manager.store.Clock.Merge(mapState.Vectors)
var sendBuffer bytes.Buffer
enc := gob.NewEncoder(&sendBuffer)
enc.Encode(*difference)
syncer.IncrementState()
return sendBuffer.Bytes(), true
}
func exchange(syncer *TwoPhaseSyncer) ([]byte, bool) {
if syncer.peerMsg == nil {
panic("peer msg is nil")
}
var recvBuffer = bytes.NewBuffer(syncer.peerMsg)
dec := gob.NewDecoder(recvBuffer)
var mapState TwoPhaseMapState[string]
dec.Decode(&mapState)
snapshot := syncer.manager.store.SnapShotFromState(&mapState)
var sendBuffer bytes.Buffer
enc := gob.NewEncoder(&sendBuffer)
enc.Encode(*snapshot)
syncer.IncrementState()
return sendBuffer.Bytes(), true
}
func merge(syncer *TwoPhaseSyncer) ([]byte, bool) {
if syncer.peerMsg == nil {
panic("peer msg is nil")
}
var recvBuffer = bytes.NewBuffer(syncer.peerMsg)
dec := gob.NewDecoder(recvBuffer)
var snapshot TwoPhaseMapSnapshot[string, MeshNode]
dec.Decode(&snapshot)
syncer.manager.store.Merge(snapshot)
return nil, false
}
func (t *TwoPhaseSyncer) IncrementState() {
t.state = min(t.state+1, FINISHED)
}
func (t *TwoPhaseSyncer) GenerateMessage() ([]byte, bool) {
fsmFunc, ok := t.generateMessageFSM[t.state]
if !ok {
panic("state not handled")
}
return fsmFunc(t)
}
func (t *TwoPhaseSyncer) RecvMessage(msg []byte) error {
t.peerMsg = msg
return nil
}
func (t *TwoPhaseSyncer) Complete() {
logging.Log.WriteInfof("SYNC COMPLETED")
if t.state >= MERGE {
t.manager.store.Clock.IncrementClock()
}
}
func NewTwoPhaseSyncer(manager *TwoPhaseStoreMeshManager) *TwoPhaseSyncer {
var generateMessageFsm SyncFSM = SyncFSM{
HASH: hash,
PREPARE: prepare,
PRESENT: present,
EXCHANGE: exchange,
MERGE: merge,
}
return &TwoPhaseSyncer{
manager: manager,
state: HASH,
mapState: manager.store.GenerateMessage(),
generateMessageFSM: generateMessageFsm,
}
}

149
pkg/crdt/vector_clock.go Normal file
View File

@ -0,0 +1,149 @@
package crdt
import (
"cmp"
"sync"
"time"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
type VectorBucket struct {
// clock current value of the node's clock
clock uint64
// lastUpdate we've seen
lastUpdate uint64
}
// Vector clock defines an abstract data type
// for a vector clock implementation
type VectorClock[K cmp.Ordered] struct {
vectors map[uint64]*VectorBucket
lock sync.RWMutex
processID K
staleTime uint64
hashFunc func(K) uint64
}
// IncrementClock: increments the node's value in the vector clock
func (m *VectorClock[K]) IncrementClock() uint64 {
maxClock := uint64(0)
m.lock.Lock()
for _, value := range m.vectors {
maxClock = max(maxClock, value.clock)
}
newBucket := VectorBucket{
clock: maxClock + 1,
lastUpdate: uint64(time.Now().Unix()),
}
m.vectors[m.hashFunc(m.processID)] = &newBucket
m.lock.Unlock()
return maxClock
}
// GetHash: gets the hash of the vector clock used to determine if there
// are any changes
func (m *VectorClock[K]) GetHash() uint64 {
m.lock.RLock()
hash := uint64(0)
for key, bucket := range m.vectors {
hash += key * (bucket.clock + 1)
}
m.lock.RUnlock()
return hash
}
func (m *VectorClock[K]) Merge(vectors map[uint64]uint64) {
for key, value := range vectors {
m.put(key, value)
}
}
// getStale: get all entries that are stale within the mesh
func (m *VectorClock[K]) getStale() []uint64 {
m.lock.RLock()
maxTimeStamp := lib.Reduce(0, lib.MapValues(m.vectors), func(i uint64, vb *VectorBucket) uint64 {
return max(i, vb.lastUpdate)
})
toRemove := make([]uint64, 0)
for key, bucket := range m.vectors {
if maxTimeStamp-bucket.lastUpdate > m.staleTime {
toRemove = append(toRemove, key)
}
}
m.lock.RUnlock()
return toRemove
}
func (m *VectorClock[K]) Prune() {
stale := m.getStale()
m.lock.Lock()
for _, key := range stale {
delete(m.vectors, key)
}
m.lock.Unlock()
}
func (m *VectorClock[K]) GetTimestamp(processId K) uint64 {
return m.vectors[m.hashFunc(m.processID)].lastUpdate
}
func (m *VectorClock[K]) Put(key K, value uint64) {
m.put(m.hashFunc(key), value)
}
func (m *VectorClock[K]) put(key uint64, value uint64) {
clockValue := uint64(0)
m.lock.Lock()
bucket, ok := m.vectors[key]
if ok {
clockValue = bucket.clock
}
if value > clockValue {
newBucket := VectorBucket{
clock: value,
lastUpdate: uint64(time.Now().Unix()),
}
m.vectors[key] = &newBucket
}
m.lock.Unlock()
}
func (m *VectorClock[K]) GetClock() map[uint64]uint64 {
clock := make(map[uint64]uint64)
m.lock.RLock()
for key, value := range m.vectors {
clock[key] = value.clock
}
m.lock.RUnlock()
return clock
}
func NewVectorClock[K cmp.Ordered](processID K, hashFunc func(K) uint64, staleTime uint64) *VectorClock[K] {
return &VectorClock[K]{
vectors: make(map[uint64]*VectorBucket),
processID: processID,
staleTime: staleTime,
hashFunc: hashFunc,
}
}

View File

@ -1,9 +1,9 @@
package ctrlserver
import (
crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/crdt"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
@ -21,17 +21,18 @@ type NewCtrlServerParams struct {
CtrlProvider rpc.MeshCtrlServerServer
SyncProvider rpc.SyncServiceServer
Querier query.Querier
OnDelete func(mesh.MeshProvider)
}
// Create a new instance of the MeshCtrlServer or error if the
// operation failed
func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
ctrlServer := new(MeshCtrlServer)
meshFactory := crdt.CrdtProviderFactory{}
nodeFactory := crdt.MeshNodeFactory{
meshFactory := &crdt.TwoPhaseMapFactory{}
nodeFactory := &crdt.MeshNodeFactory{
Config: *params.Conf,
}
idGenerator := &lib.UUIDGenerator{}
idGenerator := &lib.IDNameGenerator{}
ipAllocator := &ip.ULABuilder{}
interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client)
@ -40,12 +41,13 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
meshManagerParams := &mesh.NewMeshManagerParams{
Conf: *params.Conf,
Client: params.Client,
MeshProvider: &meshFactory,
NodeFactory: &nodeFactory,
MeshProvider: meshFactory,
NodeFactory: nodeFactory,
IdGenerator: idGenerator,
IPAllocator: ipAllocator,
InterfaceManipulator: interfaceManipulator,
ConfigApplyer: configApplyer,
OnDelete: params.OnDelete,
}
ctrlServer.MeshManager = mesh.NewMeshManager(meshManagerParams)

View File

@ -9,6 +9,11 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type MeshRoute struct {
Destination string
Path []string
}
// Represents a WireGuard MeshNode
type MeshNode struct {
HostEndpoint string
@ -16,7 +21,7 @@ type MeshNode struct {
PublicKey string
WgHost string
Timestamp int64
Routes []string
Routes []MeshRoute
Description string
Alias string
Services map[string]string

114
pkg/dns/dns.go Normal file
View File

@ -0,0 +1,114 @@
package smegdns
import (
"encoding/json"
"fmt"
"net"
"net/rpc"
"github.com/miekg/dns"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/query"
)
const SockAddr = "/tmp/wgmesh_ipc.sock"
const MeshRegularExpression = `(?P<meshId>.+)\.(?P<alias>.+)\.smeg\.`
type DNSHandler struct {
client *rpc.Client
server *dns.Server
}
// queryMesh: queries the mesh network for the given meshId and node
// with alias
func (d *DNSHandler) queryMesh(meshId, alias string) net.IP {
var reply string
err := d.client.Call("IpcHandler.Query", &ipc.QueryMesh{
MeshId: meshId,
Query: fmt.Sprintf("[?alias == '%s'] | [0]", alias),
}, &reply)
if err != nil {
return nil
}
var node *query.QueryNode
err = json.Unmarshal([]byte(reply), &node)
if err != nil || node == nil {
return nil
}
ip, _, _ := net.ParseCIDR(node.WgHost)
return ip
}
func (d *DNSHandler) handleQuery(m *dns.Msg) {
for _, q := range m.Question {
switch q.Qtype {
case dns.TypeAAAA:
logging.Log.WriteInfof("Query for %s", q.Name)
groups := lib.MatchCaptureGroup(MeshRegularExpression, q.Name)
if len(groups) == 0 {
continue
}
ip := d.queryMesh(groups["meshId"], groups["alias"])
rr, err := dns.NewRR(fmt.Sprintf("%s AAAA %s", q.Name, ip))
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
if err == nil {
m.Answer = append(m.Answer, rr)
}
}
}
}
func (h *DNSHandler) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
msg := new(dns.Msg)
msg.SetReply(r)
msg.Authoritative = true
switch r.Opcode {
case dns.OpcodeQuery:
h.handleQuery(msg)
}
w.WriteMsg(msg)
}
func (h *DNSHandler) Listen() error {
return h.server.ListenAndServe()
}
func (h *DNSHandler) Close() error {
return h.server.Shutdown()
}
func NewDns(udpPort int) (*DNSHandler, error) {
client, err := rpc.DialHTTP("unix", SockAddr)
if err != nil {
return nil, err
}
dnsHander := DNSHandler{
client: client,
}
dns.HandleFunc("smeg.", dnsHander.handleDnsRequest)
dnsHander.server = &dns.Server{Addr: fmt.Sprintf(":%d", udpPort), Net: "udp"}
return &dnsHander, nil
}

View File

@ -4,13 +4,13 @@ package rpctypes;
option go_package = "pkg/rpc";
service MeshCtrlServer {
rpc JoinMesh(JoinMeshRequest) returns (JoinMeshReply) {}
rpc GetMesh(GetMeshRequest) returns (GetMeshReply) {}
}
message JoinMeshRequest {
string meshId = 2;
message GetMeshRequest {
string meshId = 1;
}
message JoinMeshReply {
bool success = 1;
message GetMeshReply {
bytes mesh = 1;
}

View File

@ -62,7 +62,6 @@ type MeshIpc interface {
JoinMesh(args JoinMeshArgs, reply *string) error
LeaveMesh(meshId string, reply *string) error
GetMesh(meshId string, reply *GetMeshReply) error
EnableInterface(meshId string, reply *string) error
GetDOT(meshId string, reply *string) error
Query(query QueryMesh, reply *string) error
PutDescription(description string, reply *string) error

View File

@ -1,11 +1,13 @@
package lib
import "cmp"
// MapToSlice converts a map to a slice in go
func MapValues[K comparable, V any](m map[K]V) []V {
func MapValues[K cmp.Ordered, V any](m map[K]V) []V {
return MapValuesWithExclude(m, map[K]struct{}{})
}
func MapValuesWithExclude[K comparable, V any](m map[K]V, exclude map[K]struct{}) []V {
func MapValuesWithExclude[K cmp.Ordered, V any](m map[K]V, exclude map[K]struct{}) []V {
values := make([]V, len(m)-len(exclude))
i := 0
@ -26,7 +28,7 @@ func MapValuesWithExclude[K comparable, V any](m map[K]V, exclude map[K]struct{}
return values
}
func MapKeys[K comparable, V any](m map[K]V) []K {
func MapKeys[K cmp.Ordered, V any](m map[K]V) []K {
values := make([]K, len(m))
i := 0
@ -66,3 +68,23 @@ func Filter[V any](list []V, f filterFunc[V]) []V {
return newList
}
func Contains[V any](list []V, proposition func(V) bool) bool {
for _, elem := range list {
if proposition(elem) {
return true
}
}
return false
}
func Reduce[A any, V any](start A, values []V, reduce func(A, V) A) A {
accum := start
for _, elem := range values {
accum = reduce(accum, elem)
}
return accum
}

View File

@ -18,7 +18,7 @@ func HashString(value string) int {
// ConsistentHash implementation. Traverse the values until we find a key
// less than ours.
func ConsistentHash[V any](values []V, client V, keyFunc func(V) int) V {
func ConsistentHash[V any, K any](values []V, client K, bucketFunc func(V) int, keyFunc func(K) int) V {
if len(values) == 0 {
panic("values is empty")
}
@ -26,7 +26,7 @@ func ConsistentHash[V any](values []V, client V, keyFunc func(V) int) V {
vs := Map(values, func(v V) consistentHashRecord[V] {
return consistentHashRecord[V]{
v,
keyFunc(v),
bucketFunc(v),
}
})

View File

@ -1,6 +1,9 @@
package lib
import "github.com/google/uuid"
import (
"github.com/anandvarma/namegen"
"github.com/google/uuid"
)
// IdGenerator generates unique ids
type IdGenerator interface {
@ -15,3 +18,11 @@ func (g *UUIDGenerator) GetId() (string, error) {
id := uuid.New()
return id.String(), nil
}
type IDNameGenerator struct {
}
func (i *IDNameGenerator) GetId() (string, error) {
name_schema := namegen.New()
return name_schema.Get(), nil
}

View File

@ -9,14 +9,14 @@ import (
)
// GetOutboundIP: gets the oubound IP of this packet
func GetOutboundIP() net.IP {
func GetOutboundIP() (net.IP, error) {
conn, err := net.Dial("udp", "8.8.8.8:80")
if err != nil {
log.Fatal(err)
}
defer conn.Close()
localAddr := conn.LocalAddr().(*net.UDPAddr)
return localAddr.IP
return localAddr.IP, nil
}
const IP_SERVICE = "https://api.ipify.org?format=json"

19
pkg/lib/regex.go Normal file
View File

@ -0,0 +1,19 @@
package lib
import "regexp"
func MatchCaptureGroup(pattern, payload string) map[string]string {
patterns := make(map[string]string)
expr := regexp.MustCompile(pattern)
match := expr.FindStringSubmatch(payload)
for i, name := range expr.SubexpNames() {
if i != 0 && name != "" {
patterns[name] = match[i]
}
}
return patterns
}

View File

@ -201,7 +201,7 @@ func (c *RtNetlinkConfig) DeleteRoute(ifName string, route Route) error {
})
if err != nil {
return fmt.Errorf("failed to delete route %w", err)
return fmt.Errorf("failed to delete route %s", dst.IP.String())
}
return nil
@ -219,22 +219,15 @@ func (r1 Route) equal(r2 Route) bool {
// DeleteRoutes deletes all routes not in exclude
func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...Route) error {
routes := make([]rtnetlink.RouteMessage, 0)
routes, err := c.listRoutes(ifName, family)
if len(exclude) != 0 {
lRoutes, err := c.listRoutes(ifName, family, exclude[0].Gateway)
if err != nil {
return err
}
routes = lRoutes
if err != nil {
return err
}
ifRoutes := make([]Route, 0)
for _, rtRoute := range routes {
logging.Log.WriteInfof("Routes: %s", rtRoute.Attributes.Dst.String())
maskSize := 128
if family == unix.AF_INET {
@ -255,6 +248,14 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
if route.equal(r) {
return false
}
if family == unix.AF_INET && route.Destination.IP.To4() == nil {
return false
}
if family == unix.AF_INET6 && route.Destination.IP.To16() == nil {
return false
}
}
return true
}
@ -262,7 +263,7 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
toDelete := Filter(ifRoutes, shouldExclude)
for _, route := range toDelete {
logging.Log.WriteInfof("Deleting route %s", route.Destination.String())
logging.Log.WriteInfof("Deleting route: %s", route.Destination.String())
err := c.DeleteRoute(ifName, route)
if err != nil {
@ -274,7 +275,7 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
}
// listRoutes lists all routes on the interface
func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8, gateway net.IP) ([]rtnetlink.RouteMessage, error) {
func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8) ([]rtnetlink.RouteMessage, error) {
iface, err := net.InterfaceByName(ifName)
if err != nil {
@ -288,7 +289,7 @@ func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8, gateway net.IP
}
filterFunc := func(r rtnetlink.RouteMessage) bool {
return r.Attributes.Gateway.Equal(gateway) && r.Attributes.OutIface == uint32(iface.Index)
return r.Attributes.Gateway != nil && r.Attributes.OutIface == uint32(iface.Index)
}
routes = Filter(routes, filterFunc)

40
pkg/lib/stats.go Normal file
View File

@ -0,0 +1,40 @@
// lib contains helper functions for the implementation
package lib
import (
"cmp"
"math"
"gonum.org/v1/gonum/stat"
"gonum.org/v1/gonum/stat/distuv"
)
// Modelling the distribution using a normal distribution get the count
// of the outliers
func GetOutliers[K cmp.Ordered](counts map[K]uint64, alpha float64) []K {
n := float64(len(counts))
keys := MapKeys(counts)
values := make([]float64, len(keys))
for index, key := range keys {
values[index] = float64(counts[key])
}
mean := stat.Mean(values, nil)
stdDev := stat.StdDev(values, nil)
moe := distuv.Normal{Mu: 0, Sigma: 1}.Quantile(1-alpha/2) * (stdDev / math.Sqrt(n))
lowerBound := mean - moe
var outliers []K
for i, count := range values {
if count < lowerBound {
outliers = append(outliers, keys[i])
}
}
return outliers
}

View File

@ -3,10 +3,14 @@ package mesh
import (
"fmt"
"net"
"slices"
"strings"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/route"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@ -23,14 +27,17 @@ type WgMeshConfigApplyer struct {
meshManager MeshManager
config *conf.WgMeshConfiguration
routeInstaller route.RouteInstaller
hashFunc func(MeshNode) int
}
func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, peerToClients map[string][]net.IPNet) (*wgtypes.PeerConfig, error) {
endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint())
type routeNode struct {
gateway string
route Route
}
if err != nil {
return nil, err
}
func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Device,
peerToClients map[string][]net.IPNet,
routes map[string][]routeNode) (*wgtypes.PeerConfig, error) {
pubKey, err := node.GetPublicKey()
@ -41,29 +48,285 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, peerToClients map[s
allowedips := make([]net.IPNet, 1)
allowedips[0] = *node.GetWgHost()
for _, route := range node.GetRoutes() {
_, ipnet, _ := net.ParseCIDR(route)
allowedips = append(allowedips, *ipnet)
}
clients, ok := peerToClients[node.GetWgHost().String()]
clients, ok := peerToClients[pubKey.String()]
if ok {
allowedips = append(allowedips, clients...)
}
for _, route := range node.GetRoutes() {
bestRoutes := routes[route.GetDestination().String()]
var pickedRoute routeNode
if len(bestRoutes) == 1 {
pickedRoute = bestRoutes[0]
} else if len(bestRoutes) > 1 {
bucketFunc := func(rn routeNode) int {
return lib.HashString(rn.gateway)
}
// Else there is more than one candidate so consistently hash
pickedRoute = lib.ConsistentHash(bestRoutes, node, bucketFunc, m.hashFunc)
}
if pickedRoute.gateway == pubKey.String() {
allowedips = append(allowedips, *pickedRoute.route.GetDestination())
}
}
keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second
existing := slices.IndexFunc(device.Peers, func(p wgtypes.Peer) bool {
pubKey, _ := node.GetPublicKey()
return p.PublicKey.String() == pubKey.String()
})
endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint())
if err != nil {
return nil, err
}
// Don't override the existing IP in case it already exists
if existing != -1 {
endpoint = device.Peers[existing].Endpoint
}
peerConfig := wgtypes.PeerConfig{
PublicKey: pubKey,
Endpoint: endpoint,
AllowedIPs: allowedips,
PersistentKeepaliveInterval: &keepAlive,
ReplaceAllowedIPs: true,
}
return &peerConfig, nil
}
// getRoutes: finds the routes with the least hop distance. If more than one route exists
// consistently hash to evenly spread the distribution of traffic
func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]routeNode {
mesh, _ := meshProvider.GetMesh()
routes := make(map[string][]routeNode)
peers := lib.Filter(lib.MapValues(mesh.GetNodes()), func(p MeshNode) bool {
return p.GetType() == conf.PEER_ROLE
})
meshPrefixes := lib.Map(lib.MapValues(m.meshManager.GetMeshes()), func(mesh MeshProvider) *net.IPNet {
ula := &ip.ULABuilder{}
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
return ipNet
})
for _, node := range mesh.GetNodes() {
pubKey, _ := node.GetPublicKey()
for _, route := range node.GetRoutes() {
if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool {
v6Default, _, _ := net.ParseCIDR("::/0")
v4Default, _, _ := net.ParseCIDR("0.0.0.0/0")
if (prefix.IP.Equal(v6Default) || prefix.IP.Equal(v4Default)) && m.config.AdvertiseDefaultRoute {
return true
}
return prefix.Contains(route.GetDestination().IP)
}) {
continue
}
destination := route.GetDestination().String()
otherRoute, ok := routes[destination]
rn := routeNode{
gateway: pubKey.String(),
route: route,
}
// Client's only acessible by another peer
if node.GetType() == conf.CLIENT_ROLE {
peer := m.getCorrespondingPeer(peers, node)
self, _ := m.meshManager.GetSelf(meshProvider.GetMeshId())
// If the node isn't the self use that peer as the gateway
if !NodeEquals(peer, self) {
peerPub, _ := peer.GetPublicKey()
rn.gateway = peerPub.String()
rn.route = &RouteStub{
Destination: rn.route.GetDestination(),
HopCount: rn.route.GetHopCount() + 1,
// Append the path to this peer
Path: append(rn.route.GetPath(), peer.GetWgHost().IP.String()),
}
}
}
if !ok {
otherRoute = make([]routeNode, 1)
otherRoute[0] = rn
routes[destination] = otherRoute
} else if route.GetHopCount() < otherRoute[0].route.GetHopCount() {
otherRoute[0] = rn
} else if otherRoute[0].route.GetHopCount() == route.GetHopCount() {
logging.Log.WriteInfof("Other Route Hop: %d", otherRoute[0].route.GetHopCount())
logging.Log.WriteInfof("Route gateway %s, route hop %d", rn.gateway, route.GetHopCount())
routes[destination] = append(otherRoute, rn)
}
}
}
return routes
}
// getCorrespondignPeer: gets the peer corresponding to the client
func (m *WgMeshConfigApplyer) getCorrespondingPeer(peers []MeshNode, client MeshNode) MeshNode {
peer := lib.ConsistentHash(peers, client, m.hashFunc, m.hashFunc)
return peer
}
type GetConfigParams struct {
mesh MeshProvider
peers []MeshNode
clients []MeshNode
dev *wgtypes.Device
routes map[string][]routeNode
}
func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes.Config, error) {
self, err := m.meshManager.GetSelf(params.mesh.GetMeshId())
ula := &ip.ULABuilder{}
meshNet, _ := ula.GetIPNet(params.mesh.GetMeshId())
routes := lib.Map(lib.MapKeys(params.routes), func(destination string) net.IPNet {
_, ipNet, _ := net.ParseCIDR(destination)
return *ipNet
})
routes = append(routes, *meshNet)
if err != nil {
return nil, err
}
peer := m.getCorrespondingPeer(params.peers, self)
pubKey, _ := peer.GetPublicKey()
keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second
endpoint, err := net.ResolveUDPAddr("udp", peer.GetWgEndpoint())
if err != nil {
return nil, err
}
peerCfgs := make([]wgtypes.PeerConfig, 1)
peerCfgs[0] = wgtypes.PeerConfig{
PublicKey: pubKey,
Endpoint: endpoint,
PersistentKeepaliveInterval: &keepAlive,
AllowedIPs: routes,
ReplaceAllowedIPs: true,
}
installedRoutes := make([]lib.Route, 0)
for _, route := range peerCfgs[0].AllowedIPs {
installedRoutes = append(installedRoutes, lib.Route{
Gateway: peer.GetWgHost().IP,
Destination: route,
})
}
cfg := wgtypes.Config{
Peers: peerCfgs,
}
m.routeInstaller.InstallRoutes(params.dev.Name, installedRoutes...)
return &cfg, err
}
func (m *WgMeshConfigApplyer) getRoutesToInstall(wgNode *wgtypes.PeerConfig, mesh MeshProvider, node MeshNode) []lib.Route {
routes := make([]lib.Route, 0)
for _, route := range wgNode.AllowedIPs {
ula := &ip.ULABuilder{}
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
_, defaultRoute, _ := net.ParseCIDR("::/0")
if !ipNet.Contains(route.IP) && !ipNet.IP.Equal(defaultRoute.IP) {
routes = append(routes, lib.Route{
Gateway: node.GetWgHost().IP,
Destination: route,
})
}
}
return routes
}
func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.Config, error) {
peerToClients := make(map[string][]net.IPNet)
installedRoutes := make([]lib.Route, 0)
peerConfigs := make([]wgtypes.PeerConfig, 0)
self, err := m.meshManager.GetSelf(params.mesh.GetMeshId())
if err != nil {
return nil, err
}
for _, n := range params.clients {
if len(params.peers) > 0 {
peer := m.getCorrespondingPeer(params.peers, n)
pubKey, _ := peer.GetPublicKey()
clients, ok := peerToClients[pubKey.String()]
if !ok {
clients = make([]net.IPNet, 0)
peerToClients[pubKey.String()] = clients
}
peerToClients[pubKey.String()] = append(clients, *n.GetWgHost())
if NodeEquals(self, peer) {
cfg, err := m.convertMeshNode(n, params.dev, peerToClients, params.routes)
if err != nil {
return nil, err
}
installedRoutes = append(installedRoutes, m.getRoutesToInstall(cfg, params.mesh, n)...)
peerConfigs = append(peerConfigs, *cfg)
}
}
}
for _, n := range params.peers {
if NodeEquals(n, self) {
continue
}
peer, err := m.convertMeshNode(n, params.dev, peerToClients, params.routes)
if err != nil {
return nil, err
}
installedRoutes = append(installedRoutes, m.getRoutesToInstall(peer, params.mesh, n)...)
peerConfigs = append(peerConfigs, *peer)
}
cfg := wgtypes.Config{
Peers: peerConfigs,
ReplacePeers: true,
}
err = m.routeInstaller.InstallRoutes(params.dev.Name, installedRoutes...)
return &cfg, err
}
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
snap, err := mesh.GetMesh()
@ -72,13 +335,19 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
}
nodes := lib.MapValues(snap.GetNodes())
peerConfigs := make([]wgtypes.PeerConfig, len(nodes))
dev, _ := mesh.GetDevice()
slices.SortFunc(nodes, func(a, b MeshNode) int {
return strings.Compare(string(a.GetType()), string(b.GetType()))
})
peers := lib.Filter(nodes, func(mn MeshNode) bool {
return mn.GetType() == conf.PEER_ROLE
})
var count int = 0
clients := lib.Filter(nodes, func(mn MeshNode) bool {
return mn.GetType() == conf.CLIENT_ROLE
})
self, err := m.meshManager.GetSelf(mesh.GetMeshId())
@ -86,71 +355,35 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
return err
}
rtnl, err := lib.NewRtNetlinkConfig()
var cfg *wgtypes.Config = nil
routes := m.getRoutes(mesh)
configParams := &GetConfigParams{
mesh: mesh,
peers: peers,
clients: clients,
dev: dev,
routes: routes,
}
switch self.GetType() {
case conf.PEER_ROLE:
cfg, err = m.getPeerConfig(configParams)
case conf.CLIENT_ROLE:
cfg, err = m.getClientConfig(configParams)
}
if err != nil {
return err
}
peerToClients := make(map[string][]net.IPNet)
for _, n := range nodes {
if NodeEquals(n, self) {
continue
}
if n.GetType() == conf.CLIENT_ROLE && len(peers) > 0 && self.GetType() == conf.CLIENT_ROLE {
peer := lib.ConsistentHash(peers, n, func(mn MeshNode) int {
return lib.HashString(mn.GetWgHost().String())
})
dev, err := mesh.GetDevice()
if err != nil {
return err
}
rtnl.AddRoute(dev.Name, lib.Route{
Gateway: peer.GetWgHost().IP,
Destination: *n.GetWgHost(),
})
if err != nil {
return err
}
clients, ok := peerToClients[peer.GetWgHost().String()]
if !ok {
clients = make([]net.IPNet, 0)
peerToClients[peer.GetWgHost().String()] = clients
}
peerToClients[peer.GetWgHost().String()] = append(clients, *n.GetWgHost())
continue
}
peer, err := m.convertMeshNode(n, peerToClients)
if err != nil {
return err
}
peerConfigs[count] = *peer
count++
}
cfg := wgtypes.Config{
Peers: peerConfigs,
}
dev, err := mesh.GetDevice()
err = m.meshManager.GetClient().ConfigureDevice(dev.Name, *cfg)
if err != nil {
return err
}
return m.meshManager.GetClient().ConfigureDevice(dev.Name, cfg)
return nil
}
func (m *WgMeshConfigApplyer) ApplyConfig() error {
@ -179,8 +412,8 @@ func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error {
}
m.meshManager.GetClient().ConfigureDevice(dev.Name, wgtypes.Config{
Peers: make([]wgtypes.PeerConfig, 0),
ReplacePeers: true,
Peers: make([]wgtypes.PeerConfig, 1),
})
return nil
@ -194,5 +427,9 @@ func NewWgMeshConfigApplyer(config *conf.WgMeshConfiguration) MeshConfigApplyer
return &WgMeshConfigApplyer{
config: config,
routeInstaller: route.NewRouteInstaller(),
hashFunc: func(mn MeshNode) int {
pubKey, _ := mn.GetPublicKey()
return lib.HashString(pubKey.String())
},
}
}

View File

@ -61,7 +61,7 @@ func (c *MeshDOTConverter) graphNode(g *graph.Graph, node MeshNode, meshId strin
self, _ := c.manager.GetSelf(meshId)
if node.GetHostEndpoint() == self.GetHostEndpoint() {
if NodeEquals(self, node) {
return
}

View File

@ -3,11 +3,11 @@ package mesh
import (
"errors"
"fmt"
"sync"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/wg"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@ -18,8 +18,7 @@ type MeshManager interface {
AddMesh(params *AddMeshParams) error
HasChanges(meshid string) bool
GetMesh(meshId string) MeshProvider
EnableInterface(meshId string) error
GetPublicKey(meshId string) (*wgtypes.Key, error)
GetPublicKey() *wgtypes.Key
AddSelf(params *AddSelfParams) error
LeaveMesh(meshId string) error
GetSelf(meshId string) (MeshNode, error)
@ -35,9 +34,11 @@ type MeshManager interface {
Close() error
GetMonitor() MeshMonitor
GetNode(string, string) MeshNode
GetRouteManager() RouteManager
}
type MeshManagerImpl struct {
lock sync.RWMutex
Meshes map[string]MeshProvider
RouteManager RouteManager
Client *wgctrl.Client
@ -52,12 +53,18 @@ type MeshManagerImpl struct {
ipAllocator ip.IPAllocator
interfaceManipulator wg.WgInterfaceManipulator
Monitor MeshMonitor
OnDelete func(MeshProvider)
}
// GetRouteManager implements MeshManager.
func (m *MeshManagerImpl) GetRouteManager() RouteManager {
return m.RouteManager
}
// RemoveService implements MeshManager.
func (m *MeshManagerImpl) RemoveService(service string) error {
for _, mesh := range m.Meshes {
err := mesh.RemoveService(m.HostParameters.HostEndpoint, service)
err := mesh.RemoveService(m.HostParameters.GetPublicKey(), service)
if err != nil {
return err
@ -70,7 +77,7 @@ func (m *MeshManagerImpl) RemoveService(service string) error {
// SetService implements MeshManager.
func (m *MeshManagerImpl) SetService(service string, value string) error {
for _, mesh := range m.Meshes {
err := mesh.AddService(m.HostParameters.HostEndpoint, service, value)
err := mesh.AddService(m.HostParameters.GetPublicKey(), service, value)
if err != nil {
return err
@ -104,7 +111,7 @@ func (m *MeshManagerImpl) GetMonitor() MeshMonitor {
// Prune implements MeshManager.
func (m *MeshManagerImpl) Prune() error {
for _, mesh := range m.Meshes {
err := mesh.Prune(m.conf.PruneTime)
err := mesh.Prune()
if err != nil {
return err
@ -125,7 +132,7 @@ func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
}
if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(port)
ifName, err = m.interfaceManipulator.CreateInterface(port, m.HostParameters.PrivateKey)
if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err)
@ -138,13 +145,16 @@ func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
Conf: m.conf,
Client: m.Client,
MeshId: meshId,
NodeID: m.HostParameters.GetPublicKey(),
})
if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err)
}
m.lock.Lock()
m.Meshes[meshId] = nodeManager
m.lock.Unlock()
return meshId, nil
}
@ -160,7 +170,7 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
var err error
if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(params.WgPort)
ifName, err = m.interfaceManipulator.CreateInterface(params.WgPort, m.HostParameters.PrivateKey)
if err != nil {
return err
@ -173,6 +183,7 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
Conf: m.conf,
Client: m.Client,
MeshId: params.MeshId,
NodeID: m.HostParameters.GetPublicKey(),
})
if err != nil {
@ -185,7 +196,9 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
return err
}
m.lock.Lock()
m.Meshes[params.MeshId] = meshProvider
m.lock.Unlock()
return nil
}
@ -200,43 +213,10 @@ func (m *MeshManagerImpl) GetMesh(meshId string) MeshProvider {
return theMesh
}
// EnableInterface: Enables the given WireGuard interface.
func (s *MeshManagerImpl) EnableInterface(meshId string) error {
err := s.configApplyer.ApplyConfig()
if err != nil {
return err
}
err = s.RouteManager.InstallRoutes()
if err != nil {
return err
}
return nil
}
// GetPublicKey: Gets the public key of the WireGuard mesh
func (s *MeshManagerImpl) GetPublicKey(meshId string) (*wgtypes.Key, error) {
if s.conf.StubWg {
zeroedKey := make([]byte, wgtypes.KeyLen)
return (*wgtypes.Key)(zeroedKey), nil
}
mesh, ok := s.Meshes[meshId]
if !ok {
return nil, errors.New("mesh does not exist")
}
dev, err := mesh.GetDevice()
if err != nil {
return nil, err
}
return &dev.PublicKey, nil
func (s *MeshManagerImpl) GetPublicKey() *wgtypes.Key {
key := s.HostParameters.PrivateKey.PublicKey()
return &key
}
type AddSelfParams struct {
@ -266,24 +246,19 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
params.WgPort = device.ListenPort
}
pubKey, err := s.GetPublicKey(params.MeshId)
pubKey := s.HostParameters.PrivateKey.PublicKey()
if err != nil {
return err
}
nodeIP, err := s.ipAllocator.GetIP(*pubKey, params.MeshId)
nodeIP, err := s.ipAllocator.GetIP(pubKey, params.MeshId)
if err != nil {
return err
}
node := s.nodeFactory.Build(&MeshNodeFactoryParams{
PublicKey: pubKey,
PublicKey: &pubKey,
NodeIP: nodeIP,
WgPort: params.WgPort,
Endpoint: params.Endpoint,
Role: s.conf.Role,
})
if !s.conf.StubWg {
@ -306,29 +281,43 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
// LeaveMesh leaves the mesh network
func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
mesh, exists := s.Meshes[meshId]
mesh := s.GetMesh(meshId)
if !exists {
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
err := s.RouteManager.RemoveRoutes(meshId)
var err error
s.RouteManager.RemoveRoutes(meshId)
err = mesh.RemoveNode(s.HostParameters.GetPublicKey())
if err != nil {
return err
}
if !s.conf.StubWg {
device, e := mesh.GetDevice()
if s.OnDelete != nil {
s.OnDelete(mesh)
}
if e != nil {
s.lock.Lock()
delete(s.Meshes, meshId)
s.lock.Unlock()
if !s.conf.StubWg {
device, err := mesh.GetDevice()
if err != nil {
return err
}
err = s.interfaceManipulator.RemoveInterface(device.Name)
if err != nil {
return err
}
}
delete(s.Meshes, meshId)
return err
}
@ -339,8 +328,7 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
return nil, fmt.Errorf("mesh %s does not exist", meshId)
}
logging.Log.WriteInfof(s.HostParameters.HostEndpoint)
node, err := meshInstance.GetNode(s.HostParameters.HostEndpoint)
node, err := meshInstance.GetNode(s.HostParameters.GetPublicKey())
if err != nil {
return nil, errors.New("the node doesn't exist in the mesh")
@ -364,9 +352,10 @@ func (s *MeshManagerImpl) ApplyConfig() error {
}
func (s *MeshManagerImpl) SetDescription(description string) error {
for _, mesh := range s.Meshes {
if mesh.NodeExists(s.HostParameters.HostEndpoint) {
err := mesh.SetDescription(s.HostParameters.HostEndpoint, description)
meshes := s.GetMeshes()
for _, mesh := range meshes {
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetDescription(s.HostParameters.GetPublicKey(), description)
if err != nil {
return err
@ -379,9 +368,10 @@ func (s *MeshManagerImpl) SetDescription(description string) error {
// SetAlias implements MeshManager.
func (s *MeshManagerImpl) SetAlias(alias string) error {
for _, mesh := range s.Meshes {
if mesh.NodeExists(s.HostParameters.HostEndpoint) {
err := mesh.SetAlias(s.HostParameters.HostEndpoint, alias)
meshes := s.GetMeshes()
for _, mesh := range meshes {
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetAlias(s.HostParameters.GetPublicKey(), alias)
if err != nil {
return err
@ -393,9 +383,10 @@ func (s *MeshManagerImpl) SetAlias(alias string) error {
// UpdateTimeStamp updates the timestamp of this node in all meshes
func (s *MeshManagerImpl) UpdateTimeStamp() error {
for _, mesh := range s.Meshes {
if mesh.NodeExists(s.HostParameters.HostEndpoint) {
err := mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint)
meshes := s.GetMeshes()
for _, mesh := range meshes {
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.UpdateTimeStamp(s.HostParameters.GetPublicKey())
if err != nil {
return err
@ -411,7 +402,16 @@ func (s *MeshManagerImpl) GetClient() *wgctrl.Client {
}
func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider {
return s.Meshes
meshes := make(map[string]MeshProvider)
s.lock.RLock()
for id, mesh := range s.Meshes {
meshes[id] = mesh
}
s.lock.RUnlock()
return meshes
}
// Close the mesh manager
@ -448,22 +448,16 @@ type NewMeshManagerParams struct {
InterfaceManipulator wg.WgInterfaceManipulator
ConfigApplyer MeshConfigApplyer
RouteManager RouteManager
OnDelete func(MeshProvider)
}
// Creates a new instance of a mesh manager with the given parameters
func NewMeshManager(params *NewMeshManagerParams) MeshManager {
hostParams := HostParameters{}
switch params.Conf.Endpoint {
case "":
ip, _ := lib.GetPublicIP()
hostParams.HostEndpoint = fmt.Sprintf("%s:%s", ip.String(), params.Conf.GrpcPort)
default:
hostParams.HostEndpoint = fmt.Sprintf("%s:%s", params.Conf.Endpoint, params.Conf.GrpcPort)
privateKey, _ := wgtypes.GeneratePrivateKey()
hostParams := HostParameters{
PrivateKey: &privateKey,
}
logging.Log.WriteInfof("Endpoint %s", hostParams.HostEndpoint)
m := &MeshManagerImpl{
Meshes: make(map[string]MeshProvider),
HostParameters: &hostParams,
@ -477,7 +471,7 @@ func NewMeshManager(params *NewMeshManagerParams) MeshManager {
m.RouteManager = params.RouteManager
if m.RouteManager == nil {
m.RouteManager = NewRouteManager(m)
m.RouteManager = NewRouteManager(m, &params.Conf)
}
m.idGenerator = params.IdGenerator
@ -489,5 +483,6 @@ func NewMeshManager(params *NewMeshManagerParams) MeshManager {
aliasManager := NewAliasManager()
m.Monitor.AddUpdateCallback(aliasManager.AddAliases)
m.Monitor.AddRemoveCallback(aliasManager.RemoveAliases)
m.OnDelete = params.OnDelete
return m
}

View File

@ -64,7 +64,6 @@ func TestAddMeshAddsAMesh(t *testing.T) {
manager.AddMesh(&AddMeshParams{
MeshId: meshId,
DevName: "wg0",
WgPort: 6000,
MeshBytes: make([]byte, 0),
})
@ -83,7 +82,6 @@ func TestAddMeshMeshAlreadyExistsReplacesIt(t *testing.T) {
for i := 0; i < 2; i++ {
err := manager.AddMesh(&AddMeshParams{
MeshId: meshId,
DevName: "wg0",
WgPort: 6000,
MeshBytes: make([]byte, 0),
})
@ -106,7 +104,6 @@ func TestAddSelfAddsSelfToTheMesh(t *testing.T) {
err := manager.AddMesh(&AddMeshParams{
MeshId: meshId,
DevName: "wg0",
WgPort: 6000,
MeshBytes: make([]byte, 0),
})
@ -175,7 +172,6 @@ func TestLeaveMeshDeletesMesh(t *testing.T) {
err := manager.AddMesh(&AddMeshParams{
MeshId: meshId,
DevName: "wg0",
WgPort: 6000,
MeshBytes: make([]byte, 0),
})

View File

@ -1,25 +1,22 @@
package mesh
import (
"fmt"
"net"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/route"
"golang.org/x/sys/unix"
)
type RouteManager interface {
UpdateRoutes() error
InstallRoutes() error
RemoveRoutes(meshId string) error
}
type RouteManagerImpl struct {
meshManager MeshManager
routeInstaller route.RouteInstaller
meshManager MeshManager
conf *conf.WgMeshConfiguration
}
func (r *RouteManagerImpl) UpdateRoutes() error {
@ -27,6 +24,35 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
ulaBuilder := new(ip.ULABuilder)
for _, mesh1 := range meshes {
self, err := r.meshManager.GetSelf(mesh1.GetMeshId())
if err != nil {
return err
}
pubKey, err := self.GetPublicKey()
if err != nil {
return err
}
routeMap, err := mesh1.GetRoutes(pubKey.String())
if err != nil {
return err
}
if r.conf.AdvertiseDefaultRoute {
_, ipv6Default, _ := net.ParseCIDR("::/0")
mesh1.AddRoutes(NodeID(self),
&RouteStub{
Destination: ipv6Default,
HopCount: 0,
Path: make([]string, 0),
})
}
for _, mesh2 := range meshes {
if mesh1 == mesh2 {
continue
@ -39,13 +65,13 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
return err
}
self, err := r.meshManager.GetSelf(mesh1.GetMeshId())
routes := lib.MapValues(routeMap)
if err != nil {
return err
}
err = mesh1.AddRoutes(self.GetHostEndpoint(), ipNet.String())
err = mesh2.AddRoutes(NodeID(self), append(routes, &RouteStub{
Destination: ipNet,
HopCount: 0,
Path: make([]string, 0),
})...)
if err != nil {
return err
@ -74,111 +100,11 @@ func (r *RouteManagerImpl) RemoveRoutes(meshId string) error {
return err
}
mesh1.RemoveRoutes(self.GetHostEndpoint(), ipNet.String())
mesh1.RemoveRoutes(NodeID(self), ipNet.String())
}
return nil
}
// AddRoute adds a route to the given interface
func (m *RouteManagerImpl) addRoute(ifName string, meshPrefix string, routes ...lib.Route) error {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return fmt.Errorf("failed to create config: %w", err)
}
defer rtnl.Close()
// Delete any routes that may be vacant
err = rtnl.DeleteRoutes(ifName, unix.AF_INET6, routes...)
if err != nil {
return err
}
for _, route := range routes {
if route.Destination.String() == meshPrefix {
continue
}
err = rtnl.AddRoute(ifName, route)
if err != nil {
return err
}
}
return nil
}
func (m *RouteManagerImpl) installRoute(ifName string, meshid string, node MeshNode) error {
routeMapFunc := func(route string) lib.Route {
_, cidr, _ := net.ParseCIDR(route)
r := lib.Route{
Destination: *cidr,
Gateway: node.GetWgHost().IP,
}
return r
}
ipBuilder := &ip.ULABuilder{}
ipNet, err := ipBuilder.GetIPNet(meshid)
if err != nil {
return err
}
routes := lib.Map(append(node.GetRoutes(), ipNet.String()), routeMapFunc)
return m.addRoute(ifName, ipNet.String(), routes...)
}
func (m *RouteManagerImpl) installRoutes(meshProvider MeshProvider) error {
mesh, err := meshProvider.GetMesh()
if err != nil {
return err
}
dev, err := meshProvider.GetDevice()
if err != nil {
return err
}
self, err := m.meshManager.GetSelf(meshProvider.GetMeshId())
if err != nil {
return err
}
for _, node := range mesh.GetNodes() {
if self.GetHostEndpoint() == node.GetHostEndpoint() {
continue
}
err = m.installRoute(dev.Name, meshProvider.GetMeshId(), node)
if err != nil {
return err
}
}
return nil
}
// InstallRoutes installs all routes to the RIB
func (r *RouteManagerImpl) InstallRoutes() error {
for _, mesh := range r.meshManager.GetMeshes() {
err := r.installRoutes(mesh)
if err != nil {
return err
}
}
return nil
}
func NewRouteManager(m MeshManager) RouteManager {
return &RouteManagerImpl{meshManager: m, routeInstaller: route.NewRouteInstaller()}
func NewRouteManager(m MeshManager, conf *conf.WgMeshConfiguration) RouteManager {
return &RouteManagerImpl{meshManager: m, conf: conf}
}

View File

@ -16,14 +16,14 @@ type MeshNodeStub struct {
wgEndpoint string
wgHost *net.IPNet
timeStamp int64
routes []string
routes []Route
identifier string
description string
}
// GetType implements MeshNode.
func (*MeshNodeStub) GetType() conf.NodeType {
return PEER
return conf.PEER_ROLE
}
// GetServices implements MeshNode.
@ -56,7 +56,7 @@ func (m *MeshNodeStub) GetTimeStamp() int64 {
return m.timeStamp
}
func (m *MeshNodeStub) GetRoutes() []string {
func (m *MeshNodeStub) GetRoutes() []Route {
return m.routes
}
@ -81,6 +81,20 @@ type MeshProviderStub struct {
snapshot *MeshSnapshotStub
}
// Mark implements MeshProvider.
func (*MeshProviderStub) Mark(nodeId string) {
panic("unimplemented")
}
// RemoveNode implements MeshProvider.
func (*MeshProviderStub) RemoveNode(nodeId string) error {
panic("unimplemented")
}
func (*MeshProviderStub) GetRoutes(targetId string) (map[string]Route, error) {
return nil, nil
}
// GetNodeIds implements MeshProvider.
func (*MeshProviderStub) GetPeers() []string {
return make([]string, 0)
@ -108,7 +122,7 @@ func (*MeshProviderStub) RemoveService(nodeId string, key string) error {
// SetAlias implements MeshProvider.
func (*MeshProviderStub) SetAlias(nodeId string, alias string) error {
panic("unimplemented")
return nil
}
// RemoveRoutes implements MeshProvider.
@ -117,7 +131,7 @@ func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error {
}
// Prune implements MeshProvider.
func (*MeshProviderStub) Prune(pruneAmount int) error {
func (*MeshProviderStub) Prune() error {
return nil
}
@ -159,7 +173,7 @@ func (s *MeshProviderStub) HasChanges() bool {
return false
}
func (s *MeshProviderStub) AddRoutes(nodeId string, route ...string) error {
func (s *MeshProviderStub) AddRoutes(nodeId string, route ...Route) error {
return nil
}
@ -193,7 +207,7 @@ func (s *StubNodeFactory) Build(params *MeshNodeFactoryParams) MeshNode {
wgEndpoint: fmt.Sprintf("%s:%s", params.Endpoint, s.Config.GrpcPort),
wgHost: wgHost,
timeStamp: time.Now().Unix(),
routes: make([]string, 0),
routes: make([]Route, 0),
identifier: "abc",
description: "A Mesh Node Stub",
}
@ -216,6 +230,11 @@ type MeshManagerStub struct {
meshes map[string]MeshProvider
}
// GetRouteManager implements MeshManager.
func (*MeshManagerStub) GetRouteManager() RouteManager {
panic("unimplemented")
}
// GetNode implements MeshManager.
func (*MeshManagerStub) GetNode(string, string) MeshNode {
panic("unimplemented")
@ -278,13 +297,9 @@ func (m *MeshManagerStub) GetMesh(meshId string) MeshProvider {
snapshot: &MeshSnapshotStub{nodes: make(map[string]MeshNode)}}
}
func (m *MeshManagerStub) EnableInterface(meshId string) error {
return nil
}
func (m *MeshManagerStub) GetPublicKey(meshId string) (*wgtypes.Key, error) {
func (m *MeshManagerStub) GetPublicKey() *wgtypes.Key {
key, _ := wgtypes.GenerateKey()
return &key, nil
return &key
}
func (m *MeshManagerStub) AddSelf(params *AddSelfParams) error {

View File

@ -10,12 +10,32 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
const (
// Data Exchanged Between Peers
PEER conf.NodeType = "peer"
// Data Exchanged Between Clients
CLIENT conf.NodeType = "client"
)
type Route interface {
// GetDestination: returns the destination of the route
GetDestination() *net.IPNet
// GetHopCount: get the total hopcount of the prefix
GetHopCount() int
// GetPath: get a list of AS paths to get to the destination
GetPath() []string
}
type RouteStub struct {
Destination *net.IPNet
HopCount int
Path []string
}
func (r *RouteStub) GetDestination() *net.IPNet {
return r.Destination
}
func (r *RouteStub) GetHopCount() int {
return r.HopCount
}
func (r *RouteStub) GetPath() []string {
return r.Path
}
// MeshNode represents an implementation of a node in a mesh
type MeshNode interface {
@ -30,7 +50,7 @@ type MeshNode interface {
// GetTimestamp: get the UNIX time stamp of the ndoe
GetTimeStamp() int64
// GetRoutes: returns the routes that the nodes provides
GetRoutes() []string
GetRoutes() []Route
// GetIdentifier: returns the identifier of the node
GetIdentifier() string
// GetDescription: returns the description for this node
@ -45,7 +65,20 @@ type MeshNode interface {
// NodeEquals: determines if two mesh nodes are equivalent to one another
func NodeEquals(node1, node2 MeshNode) bool {
return node1.GetHostEndpoint() == node2.GetHostEndpoint()
key1, _ := node1.GetPublicKey()
key2, _ := node2.GetPublicKey()
return key1.String() == key2.String()
}
func RouteEquals(route1, route2 Route) bool {
return route1.GetDestination().String() == route2.GetDestination().String() &&
route1.GetHopCount() == route2.GetHopCount()
}
func NodeID(node MeshNode) string {
key, _ := node.GetPublicKey()
return key.String()
}
type MeshSnapshot interface {
@ -81,7 +114,7 @@ type MeshProvider interface {
// UpdateTimeStamp: update the timestamp of the given node
UpdateTimeStamp(nodeId string) error
// AddRoutes: adds routes to the given node
AddRoutes(nodeId string, route ...string) error
AddRoutes(nodeId string, route ...Route) error
// DeleteRoutes: deletes the routes from the node
RemoveRoutes(nodeId string, route ...string) error
// GetSyncer: returns the automerge syncer for sync
@ -98,15 +131,28 @@ type MeshProvider interface {
AddService(nodeId, key, value string) error
// RemoveService: removes the service form the node. throws an error if the service does not exist
RemoveService(nodeId, key string) error
// Prune: prunes all nodes that have not updated their timestamp in
// pruneAmount seconds
Prune(pruneAmount int) error
// Prune: prunes all nodes that have not updated their
// vector clock
Prune() error
// GetPeers: get a list of contactable peers
GetPeers() []string
// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen
GetRoutes(targetNode string) (map[string]Route, error)
// RemoveNode(): remove the node from the mesh
RemoveNode(nodeId string) error
// Mark: marks the node as unreachable. This is not broadcast to the entire
// this is not considered when syncing node state
Mark(nodeId string)
}
// HostParameters contains the IDs of a node
type HostParameters struct {
HostEndpoint string
PrivateKey *wgtypes.Key
}
// GetPublicKey: gets the public key of the node
func (h *HostParameters) GetPublicKey() string {
return h.PrivateKey.PublicKey().String()
}
// MeshProviderFactoryParams parameters required to build a mesh provider
@ -116,6 +162,7 @@ type MeshProviderFactoryParams struct {
Port int
Conf *conf.WgMeshConfiguration
Client *wgctrl.Client
NodeID string
}
// MeshProviderFactory creates an instance of a mesh provider
@ -130,7 +177,6 @@ type MeshNodeFactoryParams struct {
NodeIP net.IP
WgPort int
Endpoint string
Role conf.NodeType
}
// MeshBuilder build the hosts mesh node for it to be added to the mesh

View File

@ -3,6 +3,7 @@ package query
import (
"encoding/json"
"fmt"
"strings"
"github.com/jmespath/go-jmespath"
"github.com/tim-beatham/wgmesh/pkg/conf"
@ -24,6 +25,12 @@ type QueryError struct {
msg string
}
type QueryRoute struct {
Destination string `json:"destination"`
HopCount int `json:"hopCount"`
Path string `json:"path"`
}
type QueryNode struct {
HostEndpoint string `json:"hostEndpoint"`
PublicKey string `json:"publicKey"`
@ -31,7 +38,7 @@ type QueryNode struct {
WgHost string `json:"wgHost"`
Timestamp int64 `json:"timestamp"`
Description string `json:"description"`
Routes []string `json:"routes"`
Routes []QueryRoute `json:"routes"`
Alias string `json:"alias"`
Services map[string]string `json:"services"`
Type conf.NodeType `json:"type"`
@ -78,7 +85,13 @@ func MeshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode.WgHost = node.GetWgHost().String()
queryNode.Timestamp = node.GetTimeStamp()
queryNode.Routes = node.GetRoutes()
queryNode.Routes = lib.Map(node.GetRoutes(), func(r mesh.Route) QueryRoute {
return QueryRoute{
Destination: r.GetDestination().String(),
HopCount: r.GetHopCount(),
Path: strings.Join(r.GetPath(), ","),
}
})
queryNode.Description = node.GetDescription()
queryNode.Alias = node.GetAlias()
queryNode.Services = node.GetServices()

View File

@ -10,6 +10,7 @@ import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query"
"github.com/tim-beatham/wgmesh/pkg/rpc"
@ -72,7 +73,9 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
configuration := n.Server.GetConfiguration()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(configuration.Timeout))
defer cancel()
meshReply, err := c.GetMesh(ctx, &rpc.GetMeshRequest{MeshId: args.MeshId})
@ -117,19 +120,19 @@ func (n *IpcHandler) LeaveMesh(meshId string, reply *string) error {
}
func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
mesh := n.Server.GetMeshManager().GetMesh(meshId)
theMesh := n.Server.GetMeshManager().GetMesh(meshId)
if mesh == nil {
if theMesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
meshSnapshot, err := mesh.GetMesh()
meshSnapshot, err := theMesh.GetMesh()
if err != nil {
return err
}
if mesh == nil {
if theMesh == nil {
return errors.New("mesh does not exist")
}
@ -149,10 +152,15 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
PublicKey: pubKey.String(),
WgHost: node.GetWgHost().String(),
Timestamp: node.GetTimeStamp(),
Routes: node.GetRoutes(),
Description: node.GetDescription(),
Alias: node.GetAlias(),
Services: node.GetServices(),
Routes: lib.Map(node.GetRoutes(), func(r mesh.Route) ctrlserver.MeshRoute {
return ctrlserver.MeshRoute{
Destination: r.GetDestination().String(),
Path: r.GetPath(),
}
}),
Description: node.GetDescription(),
Alias: node.GetAlias(),
Services: node.GetServices(),
}
nodes[i] = node
@ -163,18 +171,6 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
return nil
}
func (n *IpcHandler) EnableInterface(meshId string, reply *string) error {
err := n.Server.GetMeshManager().EnableInterface(meshId)
if err != nil {
*reply = err.Error()
return err
}
*reply = "up"
return nil
}
func (n *IpcHandler) GetDOT(meshId string, reply *string) error {
g := mesh.NewMeshDotConverter(n.Server.GetMeshManager())

View File

@ -28,7 +28,3 @@ func (m *WgRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.
return &reply, nil
}
func (m *WgRpc) JoinMesh(ctx context.Context, request *rpc.JoinMeshRequest) (*rpc.JoinMeshReply, error) {
return &rpc.JoinMeshReply{Success: true}, nil
}

View File

@ -1,22 +1,36 @@
package route
import (
"net"
"os/exec"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/lib"
"golang.org/x/sys/unix"
)
type RouteInstaller interface {
InstallRoutes(devName string, routes ...*net.IPNet) error
InstallRoutes(devName string, routes ...lib.Route) error
}
type RouteInstallerImpl struct{}
// InstallRoutes: installs a route into the routing table
func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...*net.IPNet) error {
func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...lib.Route) error {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return err
}
ip6Routes := lib.Filter(routes, func(r lib.Route) bool {
return r.Destination.IP.To4() == nil
})
err = rtnl.DeleteRoutes(devName, unix.AF_INET6, ip6Routes...)
if err != nil {
return err
}
for _, route := range routes {
err := r.installRoute(devName, route)
err := rtnl.AddRoute(devName, route)
if err != nil {
return err
@ -26,22 +40,6 @@ func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...*net.IPNet)
return nil
}
// installRoute: installs a route into the linux table
func (r *RouteInstallerImpl) installRoute(devName string, route *net.IPNet) error {
// TODO: Find a library that automates this
cmd := exec.Command("/usr/bin/ip", "-6", "route", "add", route.String(), "dev", devName)
logging.Log.WriteInfof("%s %s", route.String(), devName)
if msg, err := cmd.CombinedOutput(); err != nil {
logging.Log.WriteErrorf(err.Error())
logging.Log.WriteErrorf(string(msg))
return err
}
return nil
}
func NewRouteInstaller() RouteInstaller {
return &RouteInstallerImpl{}
}

View File

@ -20,77 +20,6 @@ const (
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type MeshNode struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
PublicKey string `protobuf:"bytes,1,opt,name=publicKey,proto3" json:"publicKey,omitempty"`
WgEndpoint string `protobuf:"bytes,2,opt,name=wgEndpoint,proto3" json:"wgEndpoint,omitempty"`
Endpoint string `protobuf:"bytes,3,opt,name=endpoint,proto3" json:"endpoint,omitempty"`
WgHost string `protobuf:"bytes,4,opt,name=wgHost,proto3" json:"wgHost,omitempty"`
}
func (x *MeshNode) Reset() {
*x = MeshNode{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *MeshNode) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*MeshNode) ProtoMessage() {}
func (x *MeshNode) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use MeshNode.ProtoReflect.Descriptor instead.
func (*MeshNode) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{0}
}
func (x *MeshNode) GetPublicKey() string {
if x != nil {
return x.PublicKey
}
return ""
}
func (x *MeshNode) GetWgEndpoint() string {
if x != nil {
return x.WgEndpoint
}
return ""
}
func (x *MeshNode) GetEndpoint() string {
if x != nil {
return x.Endpoint
}
return ""
}
func (x *MeshNode) GetWgHost() string {
if x != nil {
return x.WgHost
}
return ""
}
type GetMeshRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
@ -102,7 +31,7 @@ type GetMeshRequest struct {
func (x *GetMeshRequest) Reset() {
*x = GetMeshRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1]
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@ -115,7 +44,7 @@ func (x *GetMeshRequest) String() string {
func (*GetMeshRequest) ProtoMessage() {}
func (x *GetMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1]
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@ -128,7 +57,7 @@ func (x *GetMeshRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetMeshRequest.ProtoReflect.Descriptor instead.
func (*GetMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{1}
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{0}
}
func (x *GetMeshRequest) GetMeshId() string {
@ -149,7 +78,7 @@ type GetMeshReply struct {
func (x *GetMeshReply) Reset() {
*x = GetMeshReply{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2]
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@ -162,7 +91,7 @@ func (x *GetMeshReply) String() string {
func (*GetMeshReply) ProtoMessage() {}
func (x *GetMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2]
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@ -175,7 +104,7 @@ func (x *GetMeshReply) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetMeshReply.ProtoReflect.Descriptor instead.
func (*GetMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{2}
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{1}
}
func (x *GetMeshReply) GetMesh() []byte {
@ -185,145 +114,24 @@ func (x *GetMeshReply) GetMesh() []byte {
return nil
}
type JoinMeshRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Changes []byte `protobuf:"bytes,1,opt,name=changes,proto3" json:"changes,omitempty"`
MeshId string `protobuf:"bytes,2,opt,name=meshId,proto3" json:"meshId,omitempty"`
}
func (x *JoinMeshRequest) Reset() {
*x = JoinMeshRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *JoinMeshRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JoinMeshRequest) ProtoMessage() {}
func (x *JoinMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JoinMeshRequest.ProtoReflect.Descriptor instead.
func (*JoinMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{3}
}
func (x *JoinMeshRequest) GetChanges() []byte {
if x != nil {
return x.Changes
}
return nil
}
func (x *JoinMeshRequest) GetMeshId() string {
if x != nil {
return x.MeshId
}
return ""
}
type JoinMeshReply struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"`
}
func (x *JoinMeshReply) Reset() {
*x = JoinMeshReply{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *JoinMeshReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JoinMeshReply) ProtoMessage() {}
func (x *JoinMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JoinMeshReply.ProtoReflect.Descriptor instead.
func (*JoinMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{4}
}
func (x *JoinMeshReply) GetSuccess() bool {
if x != nil {
return x.Success
}
return false
}
var File_pkg_grpc_ctrlserver_ctrlserver_proto protoreflect.FileDescriptor
var file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc = []byte{
0x0a, 0x24, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73,
0x65, 0x72, 0x76, 0x65, 0x72, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72,
0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73,
0x22, 0x7c, 0x0a, 0x08, 0x4d, 0x65, 0x73, 0x68, 0x4e, 0x6f, 0x64, 0x65, 0x12, 0x1c, 0x0a, 0x09,
0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52,
0x09, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x77, 0x67,
0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a,
0x77, 0x67, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e,
0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x65, 0x6e,
0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x77, 0x67, 0x48, 0x6f, 0x73, 0x74,
0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x77, 0x67, 0x48, 0x6f, 0x73, 0x74, 0x22, 0x28,
0x0a, 0x0e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09,
0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x22, 0x22, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x4d,
0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x6d, 0x65, 0x73, 0x68,
0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x22, 0x43, 0x0a, 0x0f,
0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12,
0x18, 0x0a, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c,
0x52, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73,
0x68, 0x49, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49,
0x64, 0x22, 0x29, 0x0a, 0x0d, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70,
0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20,
0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x32, 0x91, 0x01, 0x0a,
0x0e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x74, 0x72, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12,
0x3d, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x18, 0x2e, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e,
0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x40,
0x0a, 0x08, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x19, 0x2e, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73,
0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00,
0x42, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x33,
0x22, 0x28, 0x0a, 0x0e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65,
0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01,
0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x22, 0x22, 0x0a, 0x0c, 0x47, 0x65,
0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x6d, 0x65,
0x73, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x32, 0x4f,
0x0a, 0x0e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x74, 0x72, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72,
0x12, 0x3d, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x18, 0x2e, 0x72, 0x70,
0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73,
0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x42,
0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x33,
}
var (
@ -338,21 +146,16 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP() []byte {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescData
}
var file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
var file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_pkg_grpc_ctrlserver_ctrlserver_proto_goTypes = []interface{}{
(*MeshNode)(nil), // 0: rpctypes.MeshNode
(*GetMeshRequest)(nil), // 1: rpctypes.GetMeshRequest
(*GetMeshReply)(nil), // 2: rpctypes.GetMeshReply
(*JoinMeshRequest)(nil), // 3: rpctypes.JoinMeshRequest
(*JoinMeshReply)(nil), // 4: rpctypes.JoinMeshReply
(*GetMeshRequest)(nil), // 0: rpctypes.GetMeshRequest
(*GetMeshReply)(nil), // 1: rpctypes.GetMeshReply
}
var file_pkg_grpc_ctrlserver_ctrlserver_proto_depIdxs = []int32{
1, // 0: rpctypes.MeshCtrlServer.GetMesh:input_type -> rpctypes.GetMeshRequest
3, // 1: rpctypes.MeshCtrlServer.JoinMesh:input_type -> rpctypes.JoinMeshRequest
2, // 2: rpctypes.MeshCtrlServer.GetMesh:output_type -> rpctypes.GetMeshReply
4, // 3: rpctypes.MeshCtrlServer.JoinMesh:output_type -> rpctypes.JoinMeshReply
2, // [2:4] is the sub-list for method output_type
0, // [0:2] is the sub-list for method input_type
0, // 0: rpctypes.MeshCtrlServer.GetMesh:input_type -> rpctypes.GetMeshRequest
1, // 1: rpctypes.MeshCtrlServer.GetMesh:output_type -> rpctypes.GetMeshReply
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
@ -365,18 +168,6 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
}
if !protoimpl.UnsafeEnabled {
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*MeshNode); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetMeshRequest); i {
case 0:
return &v.state
@ -388,7 +179,7 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
return nil
}
}
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetMeshReply); i {
case 0:
return &v.state
@ -400,30 +191,6 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
return nil
}
}
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*JoinMeshRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*JoinMeshReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
@ -431,7 +198,7 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc,
NumEnums: 0,
NumMessages: 5,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},

View File

@ -23,7 +23,6 @@ const _ = grpc.SupportPackageIsVersion7
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type MeshCtrlServerClient interface {
GetMesh(ctx context.Context, in *GetMeshRequest, opts ...grpc.CallOption) (*GetMeshReply, error)
JoinMesh(ctx context.Context, in *JoinMeshRequest, opts ...grpc.CallOption) (*JoinMeshReply, error)
}
type meshCtrlServerClient struct {
@ -43,21 +42,11 @@ func (c *meshCtrlServerClient) GetMesh(ctx context.Context, in *GetMeshRequest,
return out, nil
}
func (c *meshCtrlServerClient) JoinMesh(ctx context.Context, in *JoinMeshRequest, opts ...grpc.CallOption) (*JoinMeshReply, error) {
out := new(JoinMeshReply)
err := c.cc.Invoke(ctx, "/rpctypes.MeshCtrlServer/JoinMesh", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// MeshCtrlServerServer is the server API for MeshCtrlServer service.
// All implementations must embed UnimplementedMeshCtrlServerServer
// for forward compatibility
type MeshCtrlServerServer interface {
GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error)
JoinMesh(context.Context, *JoinMeshRequest) (*JoinMeshReply, error)
mustEmbedUnimplementedMeshCtrlServerServer()
}
@ -68,9 +57,6 @@ type UnimplementedMeshCtrlServerServer struct {
func (UnimplementedMeshCtrlServerServer) GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetMesh not implemented")
}
func (UnimplementedMeshCtrlServerServer) JoinMesh(context.Context, *JoinMeshRequest) (*JoinMeshReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method JoinMesh not implemented")
}
func (UnimplementedMeshCtrlServerServer) mustEmbedUnimplementedMeshCtrlServerServer() {}
// UnsafeMeshCtrlServerServer may be embedded to opt out of forward compatibility for this service.
@ -102,24 +88,6 @@ func _MeshCtrlServer_GetMesh_Handler(srv interface{}, ctx context.Context, dec f
return interceptor(ctx, in, info, handler)
}
func _MeshCtrlServer_JoinMesh_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(JoinMeshRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(MeshCtrlServerServer).JoinMesh(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/rpctypes.MeshCtrlServer/JoinMesh",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(MeshCtrlServerServer).JoinMesh(ctx, req.(*JoinMeshRequest))
}
return interceptor(ctx, in, info, handler)
}
// MeshCtrlServer_ServiceDesc is the grpc.ServiceDesc for MeshCtrlServer service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@ -131,10 +99,6 @@ var MeshCtrlServer_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetMesh",
Handler: _MeshCtrlServer_GetMesh_Handler,
},
{
MethodName: "JoinMesh",
Handler: _MeshCtrlServer_JoinMesh_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "pkg/grpc/ctrlserver/ctrlserver.proto",

View File

@ -1,8 +1,8 @@
package sync
import (
"io"
"math/rand"
"sync"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
@ -12,7 +12,7 @@ import (
"github.com/tim-beatham/wgmesh/pkg/mesh"
)
// Syncer: picks random nodes from the mesh
// Syncer: picks random nodes from the meshs
type Syncer interface {
Sync(meshId string) error
SyncMeshes() error
@ -25,71 +25,95 @@ type SyncerImpl struct {
syncCount int
cluster conn.ConnCluster
conf *conf.WgMeshConfiguration
lastSync uint64
}
// Sync: Sync random nodes
func (s *SyncerImpl) Sync(meshId string) error {
if !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
logging.Log.WriteInfof("No changes for %s", meshId)
return nil
}
logging.Log.WriteInfof("UPDATING WG CONF")
if s.manager.HasChanges(meshId) {
err := s.manager.ApplyConfig()
if err != nil {
logging.Log.WriteInfof("Failed to update config %w", err)
}
}
nodeNames := s.manager.GetMesh(meshId).GetPeers()
self, err := s.manager.GetSelf(meshId)
if err != nil {
return err
}
neighbours := s.cluster.GetNeighbours(nodeNames, self.GetHostEndpoint())
randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
s.manager.GetMesh(meshId).Prune()
for _, node := range randomSubset {
logging.Log.WriteInfof("Random node: %s", node)
if self.GetType() == conf.PEER_ROLE && !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
logging.Log.WriteInfof("No changes for %s", meshId)
return nil
}
before := time.Now()
s.manager.GetRouteManager().UpdateRoutes()
if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
logging.Log.WriteInfof("Sending to random cluster")
interCluster := s.cluster.GetInterCluster(nodeNames, self.GetHostEndpoint())
randomSubset = append(randomSubset, interCluster)
publicKey := s.manager.GetPublicKey()
logging.Log.WriteInfof(publicKey.String())
nodeNames := s.manager.GetMesh(meshId).GetPeers()
var gossipNodes []string
// Clients always pings its peer for configuration
if self.GetType() == conf.CLIENT_ROLE {
keyFunc := lib.HashString
bucketFunc := lib.HashString
neighbour := lib.ConsistentHash(nodeNames, publicKey.String(), keyFunc, bucketFunc)
gossipNodes = make([]string, 1)
gossipNodes[0] = neighbour
} else {
neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
gossipNodes = lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
gossipNodes[len(gossipNodes)-1] = s.cluster.GetInterCluster(nodeNames, publicKey.String())
}
}
var waitGroup sync.WaitGroup
var succeeded bool = false
for index := range randomSubset {
waitGroup.Add(1)
// Do this synchronously to conserve bandwidth
for _, node := range gossipNodes {
correspondingPeer := s.manager.GetNode(meshId, node)
go func(i int) error {
defer waitGroup.Done()
err := s.requester.SyncMesh(meshId, randomSubset[i])
return err
}(index)
if correspondingPeer == nil {
logging.Log.WriteErrorf("node %s does not exist", node)
}
err := s.requester.SyncMesh(meshId, correspondingPeer)
if err == nil || err == io.EOF {
succeeded = true
} else {
// If the synchronisation operation has failed them mark a gravestone
// preventing the peer from being re-contacted until it has updated
// itself
s.manager.GetMesh(meshId).Mark(node)
}
}
waitGroup.Wait()
s.syncCount++
logging.Log.WriteInfof("SYNC TIME: %v", time.Since(before))
logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount)
s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount)
// Check if any changes have occurred and trigger callbacks
// if changes have occurred.
// return s.manager.GetMonitor().Trigger()
if !succeeded {
// If could not gossip with anyone then repeat.
s.infectionCount++
}
s.manager.GetMesh(meshId).SaveChanges()
s.lastSync = uint64(time.Now().Unix())
logging.Log.WriteInfof("UPDATING WG CONF")
err = s.manager.ApplyConfig()
if err != nil {
logging.Log.WriteInfof("Failed to update config %w", err)
}
return nil
}
@ -99,7 +123,7 @@ func (s *SyncerImpl) SyncMeshes() error {
err := s.Sync(meshId)
if err != nil {
return err
logging.Log.WriteErrorf(err.Error())
}
}

View File

@ -17,31 +17,20 @@ type SyncErrorHandlerImpl struct {
meshManager mesh.MeshManager
}
func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint string) bool {
func (s *SyncErrorHandlerImpl) handleFailed(meshId string, nodeId string) bool {
mesh := s.meshManager.GetMesh(meshId)
if mesh == nil {
return false
}
// self, err := s.meshManager.GetSelf(meshId)
// if err != nil {
// return false
// }
// mesh.DecrementHealth(endpoint, self.GetHostEndpoint())
mesh.Mark(nodeId)
return true
}
func (s *SyncErrorHandlerImpl) Handle(meshId string, endpoint string, err error) bool {
func (s *SyncErrorHandlerImpl) Handle(meshId string, nodeId string, err error) bool {
errStatus, _ := status.FromError(err)
logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message())
switch errStatus.Code() {
case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound:
return s.incrementFailedCount(meshId, endpoint)
return s.handleFailed(meshId, nodeId)
}
return false

View File

@ -15,7 +15,7 @@ import (
// SyncRequester: coordinates the syncing of meshes
type SyncRequester interface {
GetMesh(meshId string, ifName string, port int, endPoint string) error
SyncMesh(meshid string, endPoint string) error
SyncMesh(meshid string, meshNode mesh.MeshNode) error
}
type SyncRequesterImpl struct {
@ -56,8 +56,8 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, ifName string, port int, endP
return err
}
func (s *SyncRequesterImpl) handleErr(meshId, endpoint string, err error) error {
ok := s.errorHdlr.Handle(meshId, endpoint, err)
func (s *SyncRequesterImpl) handleErr(meshId, pubKey string, err error) error {
ok := s.errorHdlr.Handle(meshId, pubKey, err)
if ok {
return nil
@ -67,7 +67,10 @@ func (s *SyncRequesterImpl) handleErr(meshId, endpoint string, err error) error
}
// SyncMesh: Proactively send a sync request to the other mesh
func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) error {
endpoint := meshNode.GetHostEndpoint()
pubKey, _ := meshNode.GetPublicKey()
peerConnection, err := s.server.ConnectionManager.GetConnection(endpoint)
if err != nil {
@ -96,7 +99,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
err = s.syncMesh(mesh, ctx, c)
if err != nil {
return s.handleErr(meshId, endpoint, err)
return s.handleErr(meshId, pubKey.String(), err)
}
logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId)

View File

@ -8,11 +8,11 @@ import (
// Run implements SyncScheduler.
func syncFunction(syncer Syncer) lib.TimerFunc {
return func() error {
return syncer.SyncMeshes()
syncer.SyncMeshes()
return nil
}
}
func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester) *lib.Timer {
syncer := NewSyncer(s.MeshManager, s.Conf, syncRequester)
func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester, syncer Syncer) *lib.Timer {
return lib.NewTimer(syncFunction(syncer), int(s.Conf.SyncRate))
}

View File

@ -1,12 +1,14 @@
package timestamp
package timer
import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
)
func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
timerFunc := func() error {
logging.Log.WriteInfof("Updated Timestamp")
return ctrlServer.MeshManager.UpdateTimeStamp()
}

View File

@ -1,5 +1,7 @@
package wg
import "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
type WgError struct {
msg string
}
@ -10,7 +12,7 @@ func (m *WgError) Error() string {
type WgInterfaceManipulator interface {
// CreateInterface creates a WireGuard interface
CreateInterface(port int) (string, error)
CreateInterface(port int, privateKey *wgtypes.Key) (string, error)
// AddAddress adds an address to the given interface name
AddAddress(ifName string, addr string) error
// RemoveInterface removes the specified interface

View File

@ -3,7 +3,6 @@ package wg
import (
"crypto"
"crypto/rand"
"encoding/base64"
"fmt"
"github.com/tim-beatham/wgmesh/pkg/lib"
@ -19,7 +18,7 @@ type WgInterfaceManipulatorImpl struct {
const hashLength = 6
// CreateInterface creates a WireGuard interface
func (m *WgInterfaceManipulatorImpl) CreateInterface(port int) (string, error) {
func (m *WgInterfaceManipulatorImpl) CreateInterface(port int, privKey *wgtypes.Key) (string, error) {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
@ -35,8 +34,7 @@ func (m *WgInterfaceManipulatorImpl) CreateInterface(port int) (string, error) {
}
md5 := crypto.MD5.New().Sum(randomBuf)
md5Str := fmt.Sprintf("wg%s", base64.StdEncoding.EncodeToString(md5)[:hashLength])
md5Str := fmt.Sprintf("wg%x", md5)[:hashLength]
err = rtnl.CreateLink(md5Str)
@ -44,14 +42,8 @@ func (m *WgInterfaceManipulatorImpl) CreateInterface(port int) (string, error) {
return "", fmt.Errorf("failed to create link: %w", err)
}
privateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
return "", fmt.Errorf("failed to create private key: %w", err)
}
var cfg wgtypes.Config = wgtypes.Config{
PrivateKey: &privateKey,
PrivateKey: privKey,
ListenPort: &port,
}