1
0
forked from extern/smegmesh

Compare commits

...

69 Commits

Author SHA1 Message Date
1f0914e2df bugfix-node-not-leaving
- Add lock when perform synchronisation on concurrent access
2024-01-04 00:23:20 +00:00
27e00196cd main
- Not waiting in the waitgroup
2024-01-02 20:31:24 +00:00
dea6f1a22d main
- error in code invalid check for nil
2024-01-02 20:19:34 +00:00
913de57568 main
- Fixed bug
2024-01-02 20:11:11 +00:00
ce829114b1 bugfix
- on synchornisation node is not leaving mesh
2024-01-02 19:41:20 +00:00
cd844ff46e - Fixing DNS error 2024-01-02 00:15:23 +00:00
d0b1913796 74-perform-dad
- Fixing nil pointer dereference
2024-01-02 00:13:04 +00:00
90cfe820d2 - Fixing errors with stale paths 2024-01-02 00:09:31 +00:00
8a49809855 74-perform-dad
- Adding go.sum to fix errors
2024-01-01 23:59:04 +00:00
dbc18bddc6 74-perform-dad
- Performing DAD to check if IPv6 address present before adding
  outselves to mesh
- Changing name from wgmesh to smegmesh
2024-01-01 23:55:50 +00:00
36e82dba47 72-pull-rate-in-configuration
- Refactored pull rate into the configuration
- code freeze so no more code changes
2023-12-31 14:25:06 +00:00
3cc87bc252 72-pull-rate-in-configuration
- Updated examples
2023-12-31 12:47:45 +00:00
a9ed7c0a20 72-pull-rate-in-configuration
- Removing libp2p reference
2023-12-31 12:47:45 +00:00
fd29af73e3 72-pull-rate-in-configuration
- Added pull rate to configuration (finally) so this can
be modified by an administrator.
2023-12-31 12:47:45 +00:00
9e1058e0f2 72-pull-rate-in-configuration
- Added the pull rate to the configuration file
2023-12-31 12:47:45 +00:00
c29eb197f3 Merge pull request #71 from tim-beatham/66-ipv6-address-not-conforming-to-spec
66 ipv6 address not conforming to spec
2023-12-30 22:26:53 +00:00
1a9d9d61ad 66-ipv6-address-not-conforming-to-spec
- Missing commit
2023-12-30 22:26:08 +00:00
6954608c32 66-ipv6-address-not-confirming-to-spec
- UUID is not random just a name generator needs changing to shortuuid
- When in multiple meshes there is no wait group
2023-12-30 22:24:43 +00:00
2e6aed6f93 main
- Fixing issue with nil pointer de-reference due to bad design of mesh
  manager.
- Going forward all references to GetSelf should be depracated. It
  introduces a race condition when leaving a mesh network
2023-12-30 00:44:57 +00:00
b0893a0b8e Merge pull request #69 from tim-beatham/60-unit-test-crdt-data-store
60-unit-test-crdt-data-store
2023-12-29 22:06:20 +00:00
e7d6055fa3 60-unit-test-crdt-data-store
Provided unit tests for datastore.go
And fixed unit tets failing by different way of providing CA
2023-12-29 22:05:05 +00:00
e0f3f116b9 main
- Stale serverConfig entry causing certificate authorities
to not become authorised
2023-12-29 19:54:08 +00:00
352648b7cb main
- Fixed problem where connection not removed on error
2023-12-29 11:12:40 +00:00
2d5df25b1d main
- If deadline exceeded error remove connection from
connection manager
2023-12-29 01:29:11 +00:00
cabe173831 main
Adding retry parameter
2023-12-29 01:10:26 +00:00
d2c8a52ec6 main
- Adding retry policy for mobility
2023-12-29 00:58:43 +00:00
bf53108384 main
- Bugfix, fix consistent hash problem where
if failure happens then causes panic
2023-12-28 23:24:38 +00:00
77aac5534b main
- Bugfix in client where "-" was attempted to be parsed as a UDP addr
2023-12-28 17:46:04 +00:00
58439fcd56 main
- Bugfix when keepalivewg is not set causes segmentation fault
- give keepalive a default value of 0 if not set
2023-12-28 17:32:54 +00:00
311a15363a Merge pull request #67 from tim-beatham/66-improve-graph-dot-tool
66 improve graph dot tool
2023-12-25 01:26:15 +00:00
255d3c8b39 66-improve-graph-dot-tool
- Showing services a node provides
- Showing all meshes not just one
- Showing the default route
2023-12-25 01:25:20 +00:00
41899c5831 66-improve-graph-dot-tool
Improving the graph dot tool so that it shows all
meshes
2023-12-25 01:10:11 +00:00
fe4ca66ff6 Merge pull request #65 from tim-beatham/64-2p-set-unit-test
64 2p set unit test
2023-12-22 23:58:59 +00:00
0b91ba744a 61-improve-unit-test-coverage
- Provided unit tests for g_map and 2p_map
2023-12-22 23:57:10 +00:00
67483c2a90 64-unit-test-two-phase-set
Provide unit tests for two phase set to make it more
transparent what exactly they are doing.
2023-12-22 23:57:10 +00:00
af26e81bd3 Merge pull request #63 from tim-beatham/61-improve-unit-testing-coverage
61-improve-unit-testing-coverage
2023-12-22 21:52:46 +00:00
0cc3141b58 61-improve-unit-testing-coverage
- Added missing files to commit
2023-12-22 21:49:47 +00:00
186acbe915 Merge pull request #62 from tim-beatham/61-improve-unit-testing-coverage
61-improve-unit-testing-coverage
2023-12-22 21:49:06 +00:00
ceb43a1db1 61-improve-unit-testing-coverage
- Got unit tests passing
- Improved manager unit tests
2023-12-22 21:47:56 +00:00
bed59f120f Merge pull request #60 from tim-beatham/59-error-when-peer-not-selected
59-error-when-peer-not-selected
2023-12-22 19:12:30 +00:00
8aab4e99d8 59-error-when-peer-not-selected
In the CLI when the peer is not selected
as the type throwing an error stating
either client or peer must be selected
2023-12-22 19:08:20 +00:00
cf4be1ccab Merge pull request #58 from tim-beatham/bugfix-pull-only
Bugfix pull only
2023-12-22 18:49:09 +00:00
6ed32f3a79 bugfix-push-pull
Organised groups as a tree so that there
isn't a limit to dissemination
2023-12-19 00:50:17 +00:00
b6199892f0 bugfix-pull-only
Bugfix with inter-cluster communication pull not working
2023-12-18 22:17:46 +00:00
ad22f04b0d bugfix-pull-only
After certain period of time if no changes have
occurred then pull
2023-12-18 20:45:56 +00:00
092d9a4af5 checking-latency-for-pull-only 2023-12-17 09:44:32 +00:00
19abf712a6 Fixing bug with nodes being removed 2023-12-12 12:45:41 +00:00
b296e1f45a Merge pull request #57 from tim-beatham/55-cli-option-for-peer-type
55-cli-optionifor-peer-type
2023-12-12 12:00:42 +00:00
2dc89d171b 55-cli-optionifor-peer-type
- Ability to specify WireGuard keepalive in the CLI formatter
- Ability to specify publicly routeable endpoint
- Ability to specify whether to advetise routes into the mesh,
and whether to advertise default routes.
2023-12-12 11:58:47 +00:00
13bea10638 main - bugfix
- Nodes not being removed when deleted because when node gossips again
  it is readded.
- Keep track of highest vector clock we have removed and used this as a
  mark for determining if something is stale.
2023-12-11 11:09:02 +00:00
3222d7e388 main - adding WireGuard stats to JSON objects
- Adding WireGuard stats through to IPC calls so that they can be used
by the API
2023-12-11 09:55:25 +00:00
1789d203f6 main - fix default routing being deleted
Default route keeps fluctuating on configuration
update.
2023-12-10 23:35:00 +00:00
a5074a536e main - BUGFIX
- segfault BUGFIX
2023-12-10 22:31:24 +00:00
acb90a5679 main - go.sum should be tracked into the git
- go.sum should be contained in the git history
2023-12-10 22:11:09 +00:00
27ec23f133 Merge pull request #54 from tim-beatham/53-run-commands-pre-up-and-post-down
53-run-commands-pre-up-and-post-down
2023-12-10 19:22:59 +00:00
fe14f63217 53-run-commands-pre-up-and-post-down
- Ability to run a command pre up and post down
- Ability to be a client in one mesh and a peer in the other
- Added dev card to specify different sync rate, keepalive rate per
  mesh.
2023-12-10 19:21:54 +00:00
4a8a39601f Merge pull request #52 from tim-beatham/51-bufix-not-removing-when-withdrawn
51-bugfix-routes-not-removing-when-withdrawn
2023-12-10 15:13:57 +00:00
1e263cc6a8 51-bugfix-routes-not-removing-when-withdrawn
- Routes are not being removed despite being withdrawn from the
configuration.
- Best path routes are not shared across interfaces
- Bug in consistent hashing wrong parameter passed caused by
refactorings.
2023-12-10 15:10:36 +00:00
dae9cd31a1 Merge pull request #50 from tim-beatham/50-give-client-ability-to-bridge-meshes
50-give-client-ability-to-bridge-meshes
2023-12-08 23:58:32 +00:00
f855f53fbf 50-give-client-ability-to-bridge-meshes
Client can act as a route bridging meshes. Cient send keepalives
to all of it's peers in the different meshes act as a bridge between
the meshes
2023-12-08 23:56:07 +00:00
52feb5767b Merge pull request #48 from tim-beatham/47-default-routing
47 default routing
2023-12-08 20:03:45 +00:00
815c4484ee 47-default-routing
Implemented default routing and improved size of gossip. Using 64 bit
hash funciton to identify vector.
2023-12-08 20:02:57 +00:00
0058c9f4c9 47-default-routing
Implementing default routing so that all traffic goes out of an
exit point.
2023-12-08 11:49:24 +00:00
92c0805275 Merge pull request #46 from tim-beatham/45-use-statistical-testing
45 use statistical testing
2023-12-07 18:20:25 +00:00
661fb0d54c 45-use-statistical-testing
Keepalive is based on per mesh and not per node.
Using total ordering mechanism similar to paxos to elect a leader
if leader doesn't update it's timestamp within 3 * keepAlive then
give the leader a gravestone and elect the next leader.
Leader is bassed on lexicographically ordered public key.
2023-12-07 18:18:13 +00:00
64885f1055 45-use-statistical-testing
Using statistical testing to test whether the node has failed.
2023-12-07 01:44:54 +00:00
2169f7796f Merge pull request #44 from tim-beatham/43-gravestones
43-use-gravestones
2023-12-06 22:46:05 +00:00
a3ceff019d 43-use-gravestones
Change of approach from keepalive to a noiseless protocol
2023-12-06 22:45:04 +00:00
b78d96986c Merge pull request #42 from tim-beatham/41-bugfix-fluctuating-ips
41 bugfix fluctuating ips
2023-12-06 14:37:14 +00:00
76 changed files with 4177 additions and 2141 deletions

BIN
api Executable file

Binary file not shown.

View File

@ -3,7 +3,7 @@ package main
import (
"log"
"github.com/tim-beatham/wgmesh/pkg/api"
"github.com/tim-beatham/smegmesh/pkg/api"
)
func main() {

View File

@ -3,7 +3,7 @@ package main
import (
"log"
smegdns "github.com/tim-beatham/wgmesh/pkg/dns"
smegdns "github.com/tim-beatham/smegmesh/pkg/dns"
)
func main() {

409
cmd/smegctl/main.go Normal file
View File

@ -0,0 +1,409 @@
package main
import (
"fmt"
ipcRpc "net/rpc"
"os"
"github.com/akamensky/argparse"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
graph "github.com/tim-beatham/smegmesh/pkg/dot"
"github.com/tim-beatham/smegmesh/pkg/ipc"
logging "github.com/tim-beatham/smegmesh/pkg/log"
)
const SockAddr = "/tmp/wgmesh_ipc.sock"
type CreateMeshParams struct {
Client *ipcRpc.Client
Endpoint string
WgArgs ipc.WireGuardArgs
AdvertiseRoutes bool
AdvertiseDefault bool
}
func createMesh(client *ipc.SmegmeshIpc, args *ipc.NewMeshArgs) {
var reply string
err := client.CreateMesh(args, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func listMeshes(client *ipc.SmegmeshIpc) {
reply := new(ipc.ListMeshReply)
err := client.ListMeshes(reply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
return
}
for _, meshId := range reply.Meshes {
fmt.Println(meshId)
}
}
func joinMesh(client *ipc.SmegmeshIpc, args ipc.JoinMeshArgs) {
var reply string
err := client.JoinMesh(args, &reply)
if err != nil {
fmt.Println(err.Error())
}
fmt.Println(reply)
}
func leaveMesh(client *ipc.SmegmeshIpc, meshId string) {
var reply string
err := client.LeaveMesh(meshId, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func getGraph(client *ipc.SmegmeshIpc) {
listMeshesReply := new(ipc.ListMeshReply)
err := client.ListMeshes(listMeshesReply)
if err != nil {
fmt.Println(err.Error())
return
}
meshes := make(map[string][]ctrlserver.MeshNode)
for _, meshId := range listMeshesReply.Meshes {
var meshReply ipc.GetMeshReply
err := client.GetMesh(meshId, &meshReply)
if err != nil {
fmt.Println(err.Error())
return
}
meshes[meshId] = meshReply.Nodes
}
dotGenerator := graph.NewMeshGraphConverter(meshes)
dot, err := dotGenerator.Generate()
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(dot)
}
func queryMesh(client *ipc.SmegmeshIpc, meshId, query string) {
var reply string
args := ipc.QueryMesh{
MeshId: meshId,
Query: query,
}
err := client.Query(args, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func putDescription(client *ipc.SmegmeshIpc, meshId, description string) {
var reply string
err := client.PutDescription(ipc.PutDescriptionArgs{
MeshId: meshId,
Description: description,
}, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
// putAlias: puts an alias for the node
func putAlias(client *ipc.SmegmeshIpc, meshid, alias string) {
var reply string
err := client.PutAlias(ipc.PutAliasArgs{
MeshId: meshid,
Alias: alias,
}, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func setService(client *ipc.SmegmeshIpc, meshId, service, value string) {
var reply string
err := client.PutService(ipc.PutServiceArgs{
MeshId: meshId,
Service: service,
Value: value,
}, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func deleteService(client *ipc.SmegmeshIpc, meshId, service string) {
var reply string
err := client.DeleteService(ipc.DeleteServiceArgs{
MeshId: meshId,
Service: service,
}, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func main() {
parser := argparse.NewParser("smgctl",
"smegctl Manipulate WireGuard mesh networks")
newMeshCmd := parser.NewCommand("new-mesh", "Create a new mesh")
listMeshCmd := parser.NewCommand("list-meshes", "List meshes the node is connected to")
joinMeshCmd := parser.NewCommand("join-mesh", "Join a mesh network")
getGraphCmd := parser.NewCommand("get-graph", "Convert a mesh into DOT format")
leaveMeshCmd := parser.NewCommand("leave-mesh", "Leave a mesh network")
queryMeshCmd := parser.NewCommand("query-mesh", "Query a mesh network using JMESPath")
putDescriptionCmd := parser.NewCommand("put-description", "Place a description for the node")
putAliasCmd := parser.NewCommand("put-alias", "Place an alias for the node")
setServiceCmd := parser.NewCommand("set-service", "Place a service into your advertisements")
deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements")
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{
Default: 0,
Help: "WireGuard port to use to the interface. A default of 0 uses an unused ephmeral port.",
})
var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{
Help: "Publicly routeable endpoint to advertise within the mesh",
})
var newMeshRole *string = newMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{
Help: "Role in the mesh network. A value of peer means that the node is publicly routeable and thus considered" +
" in the gossip protocol. Client means that the node is not publicly routeable and is not a candidate in the gossip" +
" protocol",
})
var newMeshKeepAliveWg *int = newMeshCmd.Int("k", "KeepAliveWg", &argparse.Options{
Default: 0,
Help: "WireGuard KeepAlive value for NAT traversal and firewall holepunching",
})
var newMeshAdvertiseRoutes *bool = newMeshCmd.Flag("a", "advertise", &argparse.Options{
Help: "Advertise routes to other mesh network into the mesh",
})
var newMeshAdvertiseDefaults *bool = newMeshCmd.Flag("d", "defaults", &argparse.Options{
Help: "Advertise ::/0 into the mesh network",
})
var joinMeshId *string = joinMeshCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{
Required: true,
Help: "IP address of the bootstrapping node to join through",
})
var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{
Help: "Publicly routeable endpoint to advertise within the mesh",
})
var joinMeshRole *string = joinMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{
Help: "Role in the mesh network. A value of peer means that the node is publicly routeable and thus considered" +
" in the gossip protocol. Client means that the node is not publicly routeable and is not a candidate in the gossip" +
" protocol",
})
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{
Default: 0,
Help: "WireGuard port to use to the interface. A default of 0 uses an unused ephmeral port.",
})
var joinMeshKeepAliveWg *int = joinMeshCmd.Int("k", "KeepAliveWg", &argparse.Options{
Default: 0,
Help: "WireGuard KeepAlive value for NAT traversal and firewall ho;lepunching",
})
var joinMeshAdvertiseRoutes *bool = joinMeshCmd.Flag("a", "advertise", &argparse.Options{
Help: "Advertise routes to other mesh network into the mesh",
})
var joinMeshAdvertiseDefaults *bool = joinMeshCmd.Flag("d", "defaults", &argparse.Options{
Help: "Advertise ::/0 into the mesh network",
})
var leaveMeshMeshId *string = leaveMeshCmd.String("m", "mesh", &argparse.Options{
Required: true,
Help: "MeshID of the mesh to leave",
})
var queryMeshMeshId *string = queryMeshCmd.String("m", "mesh", &argparse.Options{
Required: true,
Help: "MeshID of the mesh to query",
})
var queryMeshQuery *string = queryMeshCmd.String("q", "query", &argparse.Options{
Required: true,
Help: "JMESPath Query Of The Mesh Network To Query",
})
var description *string = putDescriptionCmd.String("d", "description", &argparse.Options{
Required: true,
Help: "Description of the node in the mesh",
})
var descriptionMeshId *string = putDescriptionCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
var aliasMeshId *string = putAliasCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
var alias *string = putAliasCmd.String("a", "alias", &argparse.Options{
Required: true,
Help: "Alias of the node to set can be used in DNS to lookup an IP address",
})
var serviceKey *string = setServiceCmd.String("s", "service", &argparse.Options{
Required: true,
Help: "Key of the service to advertise in the mesh network",
})
var serviceValue *string = setServiceCmd.String("v", "value", &argparse.Options{
Required: true,
Help: "Value of the service to advertise in the mesh network",
})
var serviceMeshId *string = setServiceCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{
Required: true,
Help: "Key of the service to remove",
})
var deleteServiceMeshid *string = deleteServiceCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
err := parser.Parse(os.Args)
if err != nil {
fmt.Print(parser.Usage(err))
return
}
client, err := ipc.NewClientIpc()
if err != nil {
panic(err)
}
if newMeshCmd.Happened() {
args := &ipc.NewMeshArgs{
WgArgs: ipc.WireGuardArgs{
Endpoint: *newMeshEndpoint,
Role: *newMeshRole,
WgPort: *newMeshPort,
KeepAliveWg: *newMeshKeepAliveWg,
AdvertiseDefaultRoute: *newMeshAdvertiseDefaults,
AdvertiseRoutes: *newMeshAdvertiseRoutes,
},
}
createMesh(client, args)
}
if listMeshCmd.Happened() {
listMeshes(client)
}
if joinMeshCmd.Happened() {
args := ipc.JoinMeshArgs{
IpAddress: *joinMeshIpAddress,
MeshId: *joinMeshId,
WgArgs: ipc.WireGuardArgs{
Endpoint: *joinMeshEndpoint,
Role: *joinMeshRole,
WgPort: *joinMeshPort,
KeepAliveWg: *joinMeshKeepAliveWg,
AdvertiseDefaultRoute: *joinMeshAdvertiseDefaults,
AdvertiseRoutes: *joinMeshAdvertiseRoutes,
},
}
joinMesh(client, args)
}
if getGraphCmd.Happened() {
getGraph(client)
}
if leaveMeshCmd.Happened() {
leaveMesh(client, *leaveMeshMeshId)
}
if queryMeshCmd.Happened() {
queryMesh(client, *queryMeshMeshId, *queryMeshQuery)
}
if putDescriptionCmd.Happened() {
putDescription(client, *descriptionMeshId, *description)
}
if putAliasCmd.Happened() {
putAlias(client, *aliasMeshId, *alias)
}
if setServiceCmd.Happened() {
setService(client, *serviceMeshId, *serviceKey, *serviceValue)
}
if deleteServiceCmd.Happened() {
deleteService(client, *deleteServiceMeshid, *deleteServiceKey)
}
}

View File

@ -6,26 +6,26 @@ import (
"os"
"os/signal"
"github.com/tim-beatham/wgmesh/pkg/conf"
ctrlserver "github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/robin"
"github.com/tim-beatham/wgmesh/pkg/sync"
timer "github.com/tim-beatham/wgmesh/pkg/timers"
"github.com/tim-beatham/smegmesh/pkg/conf"
ctrlserver "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/smegmesh/pkg/ipc"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/robin"
"github.com/tim-beatham/smegmesh/pkg/sync"
timer "github.com/tim-beatham/smegmesh/pkg/timers"
"golang.zx2c4.com/wireguard/wgctrl"
)
func main() {
if len(os.Args) != 2 {
logging.Log.WriteErrorf("Need to provide configuration.yaml")
logging.Log.WriteErrorf("Did not provide configuration")
return
}
conf, err := conf.ParseConfiguration(os.Args[1])
conf, err := conf.ParseDaemonConfiguration(os.Args[1])
if err != nil {
logging.Log.WriteInfof("Could not parse configuration")
logging.Log.WriteErrorf("Could not parse configuration: %s", err.Error())
return
}
@ -59,12 +59,16 @@ func main() {
}
ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams)
if err != nil {
panic(err)
}
syncProvider.Server = ctrlServer
syncRequester = sync.NewSyncRequester(ctrlServer)
syncer = sync.NewSyncer(ctrlServer.MeshManager, conf, syncRequester)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, syncer)
timestampScheduler := timer.NewTimestampScheduler(ctrlServer)
pruneScheduler := mesh.NewPruner(ctrlServer.MeshManager, *conf)
keepAlive := timer.NewTimestampScheduler(ctrlServer)
robinIpcParams := robin.RobinIpcParams{
CtrlServer: ctrlServer,
@ -82,13 +86,12 @@ func main() {
go ipc.RunIpcHandler(&robinIpc)
go syncScheduler.Run()
go timestampScheduler.Run()
go pruneScheduler.Run()
go keepAlive.Run()
closeResources := func() {
logging.Log.WriteInfof("Closing resources")
syncScheduler.Stop()
timestampScheduler.Stop()
keepAlive.Stop()
ctrlServer.Close()
client.Close()
}

View File

@ -1,355 +0,0 @@
package main
import (
"fmt"
ipcRpc "net/rpc"
"os"
"strings"
"time"
"github.com/akamensky/argparse"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
)
const SockAddr = "/tmp/wgmesh_ipc.sock"
type CreateMeshParams struct {
Client *ipcRpc.Client
WgPort int
Endpoint string
}
func createMesh(args *CreateMeshParams) string {
var reply string
newMeshParams := ipc.NewMeshArgs{
WgPort: args.WgPort,
Endpoint: args.Endpoint,
}
err := args.Client.Call("IpcHandler.CreateMesh", &newMeshParams, &reply)
if err != nil {
return err.Error()
}
return reply
}
func listMeshes(client *ipcRpc.Client) {
reply := new(ipc.ListMeshReply)
err := client.Call("IpcHandler.ListMeshes", "", &reply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
return
}
for _, meshId := range reply.Meshes {
fmt.Println(meshId)
}
}
type JoinMeshParams struct {
Client *ipcRpc.Client
MeshId string
IpAddress string
IfName string
WgPort int
Endpoint string
}
func joinMesh(params *JoinMeshParams) string {
var reply string
args := ipc.JoinMeshArgs{
MeshId: params.MeshId,
IpAdress: params.IpAddress,
Port: params.WgPort,
}
err := params.Client.Call("IpcHandler.JoinMesh", &args, &reply)
if err != nil {
return err.Error()
}
return reply
}
func getMesh(client *ipcRpc.Client, meshId string) {
reply := new(ipc.GetMeshReply)
err := client.Call("IpcHandler.GetMesh", &meshId, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
for _, node := range reply.Nodes {
fmt.Println("Public Key: " + node.PublicKey)
fmt.Println("Control Endpoint: " + node.HostEndpoint)
fmt.Println("WireGuard Endpoint: " + node.WgEndpoint)
fmt.Println("Wg IP: " + node.WgHost)
fmt.Printf("Timestamp: %s", time.Unix(node.Timestamp, 0).String())
mapFunc := func(r ctrlserver.MeshRoute) string {
return r.Destination
}
advertiseRoutes := strings.Join(lib.Map(node.Routes, mapFunc), ",")
fmt.Printf("Routes: %s\n", advertiseRoutes)
fmt.Println("---")
}
}
func leaveMesh(client *ipcRpc.Client, meshId string) {
var reply string
err := client.Call("IpcHandler.LeaveMesh", &meshId, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func enableInterface(client *ipcRpc.Client, meshId string) {
var reply string
err := client.Call("IpcHandler.EnableInterface", &meshId, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func getGraph(client *ipcRpc.Client, meshId string) {
var reply string
err := client.Call("IpcHandler.GetDOT", &meshId, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func queryMesh(client *ipcRpc.Client, meshId, query string) {
var reply string
err := client.Call("IpcHandler.Query", &ipc.QueryMesh{MeshId: meshId, Query: query}, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
// putDescription: puts updates the description about the node to the meshes
func putDescription(client *ipcRpc.Client, description string) {
var reply string
err := client.Call("IpcHandler.PutDescription", &description, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
// putAlias: puts an alias for the node
func putAlias(client *ipcRpc.Client, alias string) {
var reply string
err := client.Call("IpcHandler.PutAlias", &alias, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func setService(client *ipcRpc.Client, service, value string) {
var reply string
serviceArgs := &ipc.PutServiceArgs{
Service: service,
Value: value,
}
err := client.Call("IpcHandler.PutService", serviceArgs, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func deleteService(client *ipcRpc.Client, service string) {
var reply string
err := client.Call("IpcHandler.PutService", &service, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func getNode(client *ipcRpc.Client, nodeId, meshId string) {
var reply string
args := &ipc.GetNodeArgs{
NodeId: nodeId,
MeshId: meshId,
}
err := client.Call("IpcHandler.GetNode", &args, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func main() {
parser := argparse.NewParser("wg-mesh",
"wg-mesh Manipulate WireGuard meshes")
newMeshCmd := parser.NewCommand("new-mesh", "Create a new mesh")
listMeshCmd := parser.NewCommand("list-meshes", "List meshes the node is connected to")
joinMeshCmd := parser.NewCommand("join-mesh", "Join a mesh network")
enableInterfaceCmd := parser.NewCommand("enable-interface", "Enable A Specific Mesh Interface")
getGraphCmd := parser.NewCommand("get-graph", "Convert a mesh into DOT format")
leaveMeshCmd := parser.NewCommand("leave-mesh", "Leave a mesh network")
queryMeshCmd := parser.NewCommand("query-mesh", "Query a mesh network using JMESPath")
putDescriptionCmd := parser.NewCommand("put-description", "Place a description for the node")
putAliasCmd := parser.NewCommand("put-alias", "Place an alias for the node")
setServiceCmd := parser.NewCommand("set-service", "Place a service into your advertisements")
deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements")
getNodeCmd := parser.NewCommand("get-node", "Get a specific node from the mesh")
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{})
var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{})
var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true})
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{})
var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{})
var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true})
var getGraphMeshId *string = getGraphCmd.String("m", "mesh", &argparse.Options{Required: true})
var leaveMeshMeshId *string = leaveMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var queryMeshMeshId *string = queryMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var queryMeshQuery *string = queryMeshCmd.String("q", "query", &argparse.Options{Required: true})
var description *string = putDescriptionCmd.String("d", "description", &argparse.Options{Required: true})
var alias *string = putAliasCmd.String("a", "alias", &argparse.Options{Required: true})
var serviceKey *string = setServiceCmd.String("s", "service", &argparse.Options{Required: true})
var serviceValue *string = setServiceCmd.String("v", "value", &argparse.Options{Required: true})
var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{Required: true})
var getNodeNodeId *string = getNodeCmd.String("n", "nodeid", &argparse.Options{Required: true})
var getNodeMeshId *string = getNodeCmd.String("m", "meshid", &argparse.Options{Required: true})
err := parser.Parse(os.Args)
if err != nil {
fmt.Print(parser.Usage(err))
return
}
client, err := ipcRpc.DialHTTP("unix", SockAddr)
if err != nil {
fmt.Println(err.Error())
return
}
if newMeshCmd.Happened() {
fmt.Println(createMesh(&CreateMeshParams{
Client: client,
WgPort: *newMeshPort,
Endpoint: *newMeshEndpoint,
}))
}
if listMeshCmd.Happened() {
listMeshes(client)
}
if joinMeshCmd.Happened() {
fmt.Println(joinMesh(&JoinMeshParams{
Client: client,
WgPort: *joinMeshPort,
IpAddress: *joinMeshIpAddress,
MeshId: *joinMeshId,
Endpoint: *joinMeshEndpoint,
}))
}
if getGraphCmd.Happened() {
getGraph(client, *getGraphMeshId)
}
if enableInterfaceCmd.Happened() {
enableInterface(client, *enableInterfaceMeshId)
}
if leaveMeshCmd.Happened() {
leaveMesh(client, *leaveMeshMeshId)
}
if queryMeshCmd.Happened() {
queryMesh(client, *queryMeshMeshId, *queryMeshQuery)
}
if putDescriptionCmd.Happened() {
putDescription(client, *description)
}
if putAliasCmd.Happened() {
putAlias(client, *alias)
}
if setServiceCmd.Happened() {
setService(client, *serviceKey, *serviceValue)
}
if deleteServiceCmd.Happened() {
deleteService(client, *deleteServiceKey)
}
if getNodeCmd.Happened() {
getNode(client, *getNodeNodeId, *getNodeMeshId)
}
}

View File

@ -10,5 +10,5 @@ syncRate: 1
interClusterChance: 0.15
branchRate: 3
infectionCount: 3
keepAliveTime: 10
heartBeatTime: 10
pruneTime: 20

View File

@ -10,5 +10,5 @@ syncRate: 1
interClusterChance: 0.15
branchRate: 3
infectionCount: 3
keepAliveTime: 10
heartBeatTime: 10
pruneTime: 20

View File

@ -10,5 +10,5 @@ syncRate: 1
interClusterChance: 0.15
branchRate: 3
infectionCount: 3
keepAliveTime: 10
heartBeatTime: 10
pruneTime: 20

16
go.mod
View File

@ -1,14 +1,18 @@
module github.com/tim-beatham/wgmesh
module github.com/tim-beatham/smegmesh
go 1.21.3
require (
github.com/akamensky/argparse v1.4.0
github.com/anandvarma/namegen v0.0.0-20230727084436-5197c6ea3255
github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9
github.com/gin-gonic/gin v1.9.1
github.com/go-playground/validator/v10 v10.16.0
github.com/google/uuid v1.3.0
github.com/jmespath/go-jmespath v0.4.0
github.com/jsimonetti/rtnetlink v1.3.5
github.com/lithammer/shortuuid v3.0.0+incompatible
github.com/miekg/dns v1.1.57
github.com/sirupsen/logrus v1.9.3
golang.org/x/sys v0.14.0
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
@ -24,7 +28,6 @@ require (
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/go-cmp v0.5.9 // indirect
@ -42,10 +45,13 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.13.0 // indirect
golang.org/x/net v0.15.0 // indirect
golang.org/x/sync v0.3.0 // indirect
golang.org/x/crypto v0.14.0 // indirect
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 // indirect
golang.org/x/mod v0.12.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sync v0.4.0 // indirect
golang.org/x/text v0.13.0 // indirect
golang.org/x/tools v0.13.0 // indirect
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect
)

141
go.sum Normal file
View File

@ -0,0 +1,141 @@
github.com/akamensky/argparse v1.4.0 h1:YGzvsTqCvbEZhL8zZu2AiA5nq805NZh75JNj4ajn1xc=
github.com/akamensky/argparse v1.4.0/go.mod h1:S5kwC7IuDcEr5VeXtGPRVZ5o/FdhcMlQz4IZQuw64xA=
github.com/anandvarma/namegen v0.0.0-20230727084436-5197c6ea3255 h1:aIAyyj4XPrke9Tc/umbBCzP5SKX/CHf3dKrL/PhH2lo=
github.com/anandvarma/namegen v0.0.0-20230727084436-5197c6ea3255/go.mod h1:MFyILur9tG8PxaCXGZVr/2BOnHtRIgxYejYFZdWLxr0=
github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9 h1:+6JSfuxZgmURoIlGdnYnY/FLRGWGagLyiBjt/VLtwi4=
github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9/go.mod h1:6UxoDE+thWsISXK93pxaOuOfkcAfCvDbg0eAnFmxL5E=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/cilium/ebpf v0.11.0 h1:V8gS/bTCCjX9uUnkUFUpPsksM8n1lXBAvHcpiFk1X2Y=
github.com/cilium/ebpf v0.11.0/go.mod h1:WE7CZAnqOL2RouJ4f1uyNhqr2P4CCvXFIqdRDUgWsVs=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE=
github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/jsimonetti/rtnetlink v1.3.5 h1:hVlNQNRlLDGZz31gBPicsG7Q53rnlsz1l1Ix/9XlpVA=
github.com/jsimonetti/rtnetlink v1.3.5/go.mod h1:0LFedyiTkebnd43tE4YAkWGIq9jQphow4CcwxaT2Y00=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/lithammer/shortuuid v3.0.0+incompatible h1:NcD0xWW/MZYXEHa6ITy6kaXN5nwm/V115vj2YXfhS0w=
github.com/lithammer/shortuuid v3.0.0+incompatible/go.mod h1:FR74pbAuElzOUuenUHTK2Tciko1/vKuIKS9dSkDrA4w=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.5.0 h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI=
github.com/mdlayher/socket v0.5.0/go.mod h1:WkcBFfvyG8QENs5+hfQPl1X6Jpd2yeLIYgrGFmJiJxI=
github.com/miekg/dns v1.1.57 h1:Jzi7ApEIzwEPLHWRcafCN9LZSBbqQpxjt/wpgvg7wcM=
github.com/miekg/dns v1.1.57/go.mod h1:uqRjCRUuEAA6qsOiJvDd+CFo/vW+y5WR6SNmHE55hZk=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug=
golang.org/x/exp v0.0.0-20230321023759-10a507213a29/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ=
golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 h1:EY138uSo1JYlDq+97u1FtcOUwPpIU6WL1Lkt7WpYjPA=
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 h1:bVf09lpb+OJbByTj913DRJioFFAjf/ZGxEz7MajTp2U=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98/go.mod h1:TUfxEVdsvPg18p6AslUXFoLdpED4oBnGwyqk3dV1XzM=
google.golang.org/grpc v1.58.1 h1:OL+Vz23DTtrrldqHK49FUOPHyY75rvFqJfXC84NYW58=
google.golang.org/grpc v1.58.1/go.mod h1:tgX3ZQDlNJGU96V6yHh1T/JeoBQ2TXdr43YbYSsCJk0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

View File

@ -4,17 +4,13 @@ import (
"fmt"
"net/http"
ipcRpc "net/rpc"
"github.com/gin-gonic/gin"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/what8words"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/smegmesh/pkg/ipc"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/what8words"
)
const SockAddr = "/tmp/wgmesh_ipc.sock"
type ApiServer interface {
GetMeshes(c *gin.Context)
Run(addr string) error
@ -22,7 +18,7 @@ type ApiServer interface {
type SmegServer struct {
router *gin.Engine
client *ipcRpc.Client
client *ipc.SmegmeshIpc
words *what8words.What8Words
}
@ -65,6 +61,12 @@ func (s *SmegServer) meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNo
PublicKey: meshNode.PublicKey,
Alias: alias,
Services: meshNode.Services,
Stats: SmegStats{
TotalTransmit: meshNode.Stats.TransmitBytes,
TotalReceived: meshNode.Stats.ReceivedBytes,
KeepAliveInterval: meshNode.Stats.PersistentKeepAliveInterval,
AllowedIps: meshNode.Stats.AllowedIPs,
},
}
}
@ -93,12 +95,14 @@ func (s *SmegServer) CreateMesh(c *gin.Context) {
}
ipcRequest := ipc.NewMeshArgs{
WgPort: createMesh.WgPort,
WgArgs: ipc.WireGuardArgs{
WgPort: createMesh.WgPort,
},
}
var reply string
err := s.client.Call("IpcHandler.CreateMesh", &ipcRequest, &reply)
err := s.client.CreateMesh(&ipcRequest, &reply)
if err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
@ -124,14 +128,16 @@ func (s *SmegServer) JoinMesh(c *gin.Context) {
}
ipcRequest := ipc.JoinMeshArgs{
MeshId: joinMesh.MeshId,
IpAdress: joinMesh.Bootstrap,
Port: joinMesh.WgPort,
MeshId: joinMesh.MeshId,
IpAddress: joinMesh.Bootstrap,
WgArgs: ipc.WireGuardArgs{
WgPort: joinMesh.WgPort,
},
}
var reply string
err := s.client.Call("IpcHandler.JoinMesh", &ipcRequest, &reply)
err := s.client.JoinMesh(ipcRequest, &reply)
if err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
@ -154,7 +160,7 @@ func (s *SmegServer) GetMesh(c *gin.Context) {
getMeshReply := new(ipc.GetMeshReply)
err := s.client.Call("IpcHandler.GetMesh", &meshid, &getMeshReply)
err := s.client.GetMesh(meshid, getMeshReply)
if err != nil {
c.JSON(http.StatusNotFound,
@ -172,7 +178,7 @@ func (s *SmegServer) GetMesh(c *gin.Context) {
func (s *SmegServer) GetMeshes(c *gin.Context) {
listMeshesReply := new(ipc.ListMeshReply)
err := s.client.Call("IpcHandler.ListMeshes", "", &listMeshesReply)
err := s.client.ListMeshes(listMeshesReply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
@ -185,7 +191,7 @@ func (s *SmegServer) GetMeshes(c *gin.Context) {
for _, mesh := range listMeshesReply.Meshes {
getMeshReply := new(ipc.GetMeshReply)
err := s.client.Call("IpcHandler.GetMesh", &mesh, &getMeshReply)
err := s.client.GetMesh(mesh, getMeshReply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
@ -205,7 +211,7 @@ func (s *SmegServer) Run(addr string) error {
}
func NewSmegServer(conf ApiServerConf) (ApiServer, error) {
client, err := ipcRpc.DialHTTP("unix", SockAddr)
client, err := ipc.NewClientIpc()
if err != nil {
return nil, err
@ -229,9 +235,19 @@ func NewSmegServer(conf ApiServerConf) (ApiServer, error) {
words: words,
}
router.GET("/meshes", smegServer.GetMeshes)
router.GET("/mesh/:meshid", smegServer.GetMesh)
router.POST("/mesh/create", smegServer.CreateMesh)
router.POST("/mesh/join", smegServer.JoinMesh)
v1 := router.Group("/api/v1")
{
meshes := v1.Group("/meshes")
{
meshes.GET("/", smegServer.GetMeshes)
}
mesh := v1.Group("/mesh")
{
mesh.GET("/:meshid", smegServer.GetMesh)
mesh.POST("/create", smegServer.CreateMesh)
mesh.POST("/join", smegServer.JoinMesh)
}
}
return smegServer, nil
}

View File

@ -1,10 +1,19 @@
package api
import "time"
type Route struct {
Prefix string `json:"prefix"`
Path []string `json:"path"`
}
type SmegStats struct {
TotalTransmit int64 `json:"totalTransmit"`
TotalReceived int64 `json:"totalReceived"`
KeepAliveInterval time.Duration `json:"keepaliveInterval"`
AllowedIps []string `json:"allowedIps"`
}
type SmegNode struct {
Alias string `json:"alias"`
WgHost string `json:"wgHost"`
@ -15,6 +24,7 @@ type SmegNode struct {
PublicKey string `json:"publicKey"`
Routes []Route `json:"routes"`
Services map[string]string `json:"services"`
Stats SmegStats `json:"stats"`
}
type SmegMesh struct {

View File

@ -9,10 +9,10 @@ import (
"time"
"github.com/automerge/automerge-go"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@ -24,7 +24,7 @@ type CrdtMeshManager struct {
Client *wgctrl.Client
doc *automerge.Doc
LastHash automerge.ChangeHash
conf *conf.WgMeshConfiguration
conf *conf.WgConfiguration
cache *MeshCrdt
lastCacheHash automerge.ChangeHash
}
@ -40,7 +40,11 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
crdt.Services = make(map[string]string)
crdt.Timestamp = time.Now().Unix()
c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt)
err := c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt)
if err != nil {
logging.Log.WriteInfof("error")
}
}
func (c *CrdtMeshManager) isPeer(nodeId string) bool {
@ -74,8 +78,8 @@ func (c *CrdtMeshManager) isAlive(nodeId string) bool {
return false
}
keepAliveTime := timestamp.Int64()
return (time.Now().Unix() - keepAliveTime) < int64(c.conf.DeadTime)
return true
// return (time.Now().Unix() - keepAliveTime) < int64(c.conf.DeadTime)
}
func (c *CrdtMeshManager) GetPeers() []string {
@ -135,7 +139,7 @@ type NewCrdtNodeMangerParams struct {
MeshId string
DevName string
Port int
Conf conf.WgMeshConfiguration
Conf *conf.WgConfiguration
Client *wgctrl.Client
}
@ -146,7 +150,7 @@ func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, erro
manager.doc = automerge.New()
manager.IfName = params.DevName
manager.Client = params.Client
manager.conf = &params.Conf
manager.conf = params.Conf
manager.cache = nil
return &manager, nil
}
@ -161,7 +165,7 @@ func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
node, err := m.doc.Path("nodes").Map().Get(endpoint)
if node.Kind() != automerge.KindMap {
return nil, fmt.Errorf("GetNode: something went wrong %s is not a map type")
return nil, fmt.Errorf("getnode: node is not a map")
}
if err != nil {
@ -449,7 +453,7 @@ func (m *CrdtMeshManager) RemoveNode(nodeId string) error {
}
// DeleteRoutes deletes the specified routes
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
@ -467,65 +471,26 @@ func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
}
for _, route := range routes {
err = routeMap.Map().Delete(route)
err = routeMap.Map().Delete(route.GetDestination().String())
}
return err
}
// GetConfiguration: gets the configuration for this mesh network
func (m *CrdtMeshManager) GetConfiguration() *conf.WgConfiguration {
return m.conf
}
// Mark: mark the node as locally dead
func (m *CrdtMeshManager) Mark(nodeId string) {
}
func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer {
return NewAutomergeSync(m)
}
func (m *CrdtMeshManager) Prune(pruneTime int) error {
nodes, err := m.doc.Path("nodes").Get()
if err != nil {
return err
}
if nodes.Kind() != automerge.KindMap {
return errors.New("node must be a map")
}
values, err := nodes.Map().Values()
if err != nil {
return err
}
deletionNodes := make([]string, 0)
for nodeId, node := range values {
if node.Kind() != automerge.KindMap {
return errors.New("node must be a map")
}
nodeMap := node.Map()
timeStamp, err := nodeMap.Get("timestamp")
if err != nil {
return err
}
if timeStamp.Kind() != automerge.KindInt64 {
return errors.New("timestamp is not int64")
}
timeValue := timeStamp.Int64()
nowValue := time.Now().Unix()
if nowValue-timeValue >= int64(pruneTime) {
deletionNodes = append(deletionNodes, nodeId)
}
}
for _, node := range deletionNodes {
logging.Log.WriteInfof("Pruning %s", node)
nodes.Map().Delete(node)
}
func (m *CrdtMeshManager) Prune() error {
return nil
}

View File

@ -2,7 +2,7 @@ package automerge
import (
"github.com/automerge/automerge-go"
logging "github.com/tim-beatham/wgmesh/pkg/log"
logging "github.com/tim-beatham/smegmesh/pkg/log"
)
type AutomergeSync struct {

View File

@ -1,14 +1,14 @@
package automerge
import (
"slices"
"net"
"strings"
"testing"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@ -22,7 +22,7 @@ func setUpTests() *TestParams {
DevName: "wg0",
Port: 5000,
Client: nil,
Conf: conf.WgMeshConfiguration{},
Conf: &conf.WgConfiguration{},
})
return &TestParams{
@ -31,22 +31,26 @@ func setUpTests() *TestParams {
}
func getTestNode() mesh.MeshNode {
pubKey, _ := wgtypes.GeneratePrivateKey()
return &MeshNodeCrdt{
HostEndpoint: "public-endpoint:8080",
WgEndpoint: "public-endpoint:21906",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d218/128",
PublicKey: "AAAAAAAAAAAA",
PublicKey: pubKey.String(),
Timestamp: time.Now().Unix(),
Description: "A node that we are adding",
}
}
func getTestNode2() mesh.MeshNode {
pubKey, _ := wgtypes.GeneratePrivateKey()
return &MeshNodeCrdt{
HostEndpoint: "public-endpoint:8081",
WgEndpoint: "public-endpoint:21907",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d219/128",
PublicKey: "BBBBBBBBB",
PublicKey: pubKey.String(),
Timestamp: time.Now().Unix(),
Description: "A node that we are adding",
}
@ -54,9 +58,11 @@ func getTestNode2() mesh.MeshNode {
func TestAddNodeNodeExists(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getTestNode())
node := getTestNode()
testParams.manager.AddNode(node)
node, err := testParams.manager.GetNode("public-endpoint:8080")
pubKey, _ := node.GetPublicKey()
node, err := testParams.manager.GetNode(pubKey.String())
if err != nil {
t.Error(err)
@ -70,25 +76,28 @@ func TestAddNodeNodeExists(t *testing.T) {
func TestAddNodeAddRoute(t *testing.T) {
testParams := setUpTests()
testNode := getTestNode()
testParams.manager.AddNode(testNode)
testParams.manager.AddRoutes(testNode.GetHostEndpoint(), "fd:1c64:1d00::/48")
pubKey, _ := testNode.GetPublicKey()
updatedNode, err := testParams.manager.GetNode(testNode.GetHostEndpoint())
_, destination, _ := net.ParseCIDR("fd:1c64:1d00::/48")
testParams.manager.AddNode(testNode)
testParams.manager.AddRoutes(pubKey.String(), &mesh.RouteStub{
Destination: destination,
HopCount: 0,
Path: make([]string, 0),
})
updatedNode, err := testParams.manager.GetNode(pubKey.String())
if err != nil {
t.Error(err)
}
if updatedNode == nil {
t.Fatalf(`Node does not exist in the mesh`)
t.Fatalf(`node does not exist in the mesh`)
}
routes := updatedNode.GetRoutes()
if !slices.Contains(routes, "fd:1c64:1d00::/48") {
t.Fatal("Route node not added")
}
if len(routes) != 1 {
t.Fatal(`Route length mismatch`)
}
@ -253,7 +262,9 @@ func TestUpdateTimeStampNodeExists(t *testing.T) {
node := getTestNode()
testParams.manager.AddNode(node)
err := testParams.manager.UpdateTimeStamp(node.GetHostEndpoint())
pubKey, _ := node.GetPublicKey()
err := testParams.manager.UpdateTimeStamp(pubKey.String())
if err != nil {
t.Error(err)
@ -282,7 +293,13 @@ func TestSetDescriptionNodeExists(t *testing.T) {
func TestAddRoutesNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.AddRoutes("AAAAA", "fd:1c64:1d00::/48")
_, destination, _ := net.ParseCIDR("fd:1c64:1d00::/48")
err := testParams.manager.AddRoutes("AAAAA", &mesh.RouteStub{
Destination: destination,
HopCount: 0,
Path: make([]string, 0),
})
if err == nil {
t.Error(err)
@ -293,16 +310,11 @@ func TestCompareComparesByPublicKey(t *testing.T) {
node := getTestNode().(*MeshNodeCrdt)
node2 := getTestNode2().(*MeshNodeCrdt)
if node.Compare(node2) != -1 {
t.Fatalf(`node is alphabetically before node2`)
}
pubKey1, _ := node.GetPublicKey()
pubKey2, _ := node2.GetPublicKey()
if node2.Compare(node) != 1 {
t.Fatalf(`node is alphabetical;y before node2`)
}
if node.Compare(node) != 0 {
t.Fatalf(`node is equal to node`)
if node.Compare(node2) != strings.Compare(pubKey1.String(), pubKey2.String()) {
t.Fatalf(`compare failed`)
}
}

View File

@ -3,9 +3,9 @@ package automerge
import (
"fmt"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/mesh"
)
type CrdtProviderFactory struct{}
@ -14,13 +14,13 @@ func (f *CrdtProviderFactory) CreateMesh(params *mesh.MeshProviderFactoryParams)
return NewCrdtNodeManager(&NewCrdtNodeMangerParams{
MeshId: params.MeshId,
DevName: params.DevName,
Conf: *params.Conf,
Conf: params.Conf,
Client: params.Client,
})
}
type MeshNodeFactory struct {
Config conf.WgMeshConfiguration
Config conf.DaemonConfiguration
}
// Build builds the mesh node that represents the host machine to add
@ -28,9 +28,9 @@ type MeshNodeFactory struct {
func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode {
hostName := f.getAddress(params)
grpcEndpoint := fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort)
grpcEndpoint := fmt.Sprintf("%s:%d", hostName, f.Config.GrpcPort)
if f.Config.Role == conf.CLIENT_ROLE {
if *params.MeshConfig.Role == conf.CLIENT_ROLE {
grpcEndpoint = "-"
}
@ -44,7 +44,7 @@ func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNod
Routes: make(map[string]Route),
Description: "",
Alias: "",
Type: string(f.Config.Role),
Type: string(*params.MeshConfig.Role),
}
}
@ -54,12 +54,12 @@ func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string
if params.Endpoint != "" {
hostName = params.Endpoint
} else if len(f.Config.Endpoint) != 0 {
hostName = f.Config.Endpoint
} else if len(*params.MeshConfig.Endpoint) != 0 {
hostName = *params.MeshConfig.Endpoint
} else {
ipFunc := lib.GetPublicIP
if f.Config.IPDiscovery == conf.DNS_IP_DISCOVERY {
if *params.MeshConfig.IPDiscovery == conf.DNS_IP_DISCOVERY {
ipFunc = lib.GetOutboundIP
}

33
pkg/cmd/cmd.go Normal file
View File

@ -0,0 +1,33 @@
// cmd is a package for running commands in the different operating systems implementations
package cmd
import (
"os/exec"
"strings"
)
type CmdRunner interface {
RunCommands(commands ...string) error
}
type UnixCmdRunner struct{}
// RunCommand: runs the unix command. It splits the command into fields
// and then runs the command accordingly
func RunCommand(cmd string) error {
args := strings.Fields(cmd)
c := exec.Command(args[0], args[1:]...)
return c.Run()
}
func (l *UnixCmdRunner) RunCommands(commands ...string) error {
for _, cmd := range commands {
err := RunCommand(cmd)
if err != nil {
return err
}
}
return nil
}

View File

@ -4,7 +4,7 @@ package conf
import (
"os"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/go-playground/validator/v10"
"gopkg.in/yaml.v3"
)
@ -26,174 +26,177 @@ const (
type IPDiscovery string
const (
PUBLIC_IP_DISCOVERY = "public"
DNS_IP_DISCOVERY = "dns"
PUBLIC_IP_DISCOVERY IPDiscovery = "public"
DNS_IP_DISCOVERY IPDiscovery = "dns"
)
type WgMeshConfiguration struct {
// WgConfiguration contains per-mesh WireGuard configuration. Contains poitner types only so we can
// tell if the attribute is set
type WgConfiguration struct {
// IPDIscovery: how to discover your IP if not specified. Use your outgoing IP or use a public
// service for IPDiscoverability
IPDiscovery *IPDiscovery `yaml:"ipDiscovery" validate:"required,eq=public|eq=dns"`
// AdvertiseRoutes: specifies whether the node can act as a router routing packets between meshes
AdvertiseRoutes *bool `yaml:"advertiseRoute" validate:"required"`
// AdvertiseDefaultRoute: specifies whether or not this route should advertise a default route
// for all nodes to route their packets to
AdvertiseDefaultRoute *bool `yaml:"advertiseDefaults" validate:"required"`
// Endpoint contains what value should be set as the public endpoint of this node
Endpoint *string `yaml:"publicEndpoint"`
// Role specifies whether or not the user is globally accessible.
// If the user is globaly accessible they specify themselves as a client.
Role *NodeType `yaml:"role" validate:"required,eq=client|eq=peer"`
// KeepAliveWg configures the implementation so that we send keep alive packets to peers.
KeepAliveWg *int `yaml:"keepAliveWg" validate:"omitempty,gte=0"`
// PreUp are WireGuard commands to run before adding the WG interface
PreUp []string `yaml:"preUp"`
// PostUp are WireGuard commands to run after adding the WG interface
PostUp []string `yaml:"postUp"`
// PreDown are WireGuard commands to run prior to removing the WG interface
PreDown []string `yaml:"preDown"`
// PostDown are WireGuard command to run after removing the WG interface
PostDown []string `yaml:"postDown"`
}
type DaemonConfiguration struct {
// CertificatePath is the path to the certificate to use in mTLS
CertificatePath string `yaml:"certificatePath"`
CertificatePath string `yaml:"certificatePath" validate:"required"`
// PrivateKeypath is the path to the clients private key in mTLS
PrivateKeyPath string `yaml:"privateKeyPath"`
PrivateKeyPath string `yaml:"privateKeyPath" validate:"required"`
// CaCeritifcatePath path to the certificate of the trust certificate authority
CaCertificatePath string `yaml:"caCertificatePath"`
CaCertificatePath string `yaml:"caCertificatePath" validate:"required"`
// SkipCertVerification specify to skip certificate verification. Should only be used
// in test environments
SkipCertVerification bool `yaml:"skipCertVerification"`
// Port to run the GrpcServer on
GrpcPort string `yaml:"gRPCPort"`
// IPDIscovery: how to discover your IP if not specified. Use DNS server 8.8.8.8 or
// use public IP discovery library
IPDiscovery IPDiscovery `yaml:"ipDiscovery"`
// AdvertiseRoutes advertises other meshes if the node is in multiple meshes
AdvertiseRoutes bool `yaml:"advertiseRoutes"`
// Endpoint is the IP in which this computer is publicly reachable.
// usecase is when the node has multiple IP addresses
Endpoint string `yaml:"publicEndpoint"`
// ClusterSize size of the cluster to split on
ClusterSize int `yaml:"clusterSize"`
// SyncRate number of times per second to perform a sync
SyncRate float64 `yaml:"syncRate"`
// InterClusterChance proability of inter-cluster communication in a sync round
InterClusterChance float64 `yaml:"interClusterChance"`
// BranchRate number of nodes to randomly communicate with
BranchRate int `yaml:"branchRate"`
// InfectionCount number of times we sync before we can no longer catch the udpate
InfectionCount int `yaml:"infectionCount"`
// KeepAliveTime number of seconds before we update node indicating that we are still alive
KeepAliveTime int `yaml:"keepAliveTime"`
// Timeout number of seconds before we consider the node as dead
Timeout int `yaml:"timeout"`
// PruneTime number of seconds before we remove nodes that are likely to be dead
PruneTime int `yaml:"pruneTime"`
// DeadTime: number of seconds before we consider the node as dead and stop considering it
// when picking a random peer
DeadTime int `yaml:"deadTime"`
GrpcPort int `yaml:"gRPCPort" validate:"required"`
// Timeout number of seconds without response that a node is considered unreachable by gRPC
Timeout int `yaml:"timeout" validate:"required,gte=1"`
// Profile whether or not to include a http server that profiles the code
Profile bool `yaml:"profile"`
// StubWg whether or not to stub the WireGuard types
StubWg bool `yaml:"stubWg"`
// Role specifies whether or not the user is globally accessible.
// If the user is globaly accessible they specify themselves as a client.
Role NodeType `yaml:"role"`
// KeepAliveWg configures the implementation so that we send keep alive packets to peers.
// KeepAlive can only be set if role is type client
KeepAliveWg int `yaml:"keepAliveWg"`
// SyncTime specifies how long the minimum time should be between synchronisation
SyncTime int `yaml:"syncTime" validate:"required,gte=1"`
// PullTime specifies the interval between checking for configuration changes
PullTime int `yaml:"pullTime" validate:"gte=0"`
// HeartBeat: number of seconds before the leader of the mesh sends an update to
// send to every member in the mesh
HeartBeat int `yaml:"heartBeatTime" validate:"required,gte=1"`
// ClusterSize specifies how many neighbours you should synchronise with per round
ClusterSize int `yaml:"clusterSize" validate:"gte=1"`
// InterClusterChance specifies the probabilityof inter-cluster communication in a sync round
InterClusterChance float64 `yaml:"interClusterChance" validate:"gt=0"`
// BranchRate specifies the number of nodes to synchronise with when a node has
// new changes to send to the mesh
BranchRate int `yaml:"branchRate" validate:"required,gte=1"`
// InfectionCount: number of time to sync before an update can no longer be 'caught'
InfectionCount int `yaml:"infectionCount" validate:"required,gte=1"`
// BaseConfiguration base WireGuard configuration to use, this is used when none is provided
BaseConfiguration WgConfiguration `yaml:"baseConfiguration" validate:"required"`
}
func ValidateConfiguration(c *WgMeshConfiguration) error {
if len(c.CertificatePath) == 0 {
return &WgMeshConfigurationError{
msg: "A public certificate must be specified for mTLS",
}
// ValdiateMeshConfiguration: validates the mesh configuration
func ValidateMeshConfiguration(conf *WgConfiguration) error {
validate := validator.New(validator.WithRequiredStructEnabled())
err := validate.Struct(conf)
if conf.PostDown == nil {
conf.PostDown = make([]string, 0)
}
if len(c.PrivateKeyPath) == 0 {
return &WgMeshConfigurationError{
msg: "A private key must be specified for mTLS",
}
if conf.PostUp == nil {
conf.PostUp = make([]string, 0)
}
if len(c.CaCertificatePath) == 0 {
return &WgMeshConfigurationError{
msg: "A ca certificate must be specified for mTLS",
}
if conf.PreDown == nil {
conf.PreDown = make([]string, 0)
}
if len(c.GrpcPort) == 0 {
return &WgMeshConfigurationError{
msg: "A grpc port must be specified",
}
if conf.PreUp == nil {
conf.PreUp = make([]string, 0)
}
if c.ClusterSize <= 0 {
return &WgMeshConfigurationError{
msg: "A cluster size must not be 0",
}
}
if c.SyncRate <= 0 {
return &WgMeshConfigurationError{
msg: "SyncRate cannot be negative",
}
}
if c.BranchRate <= 0 {
return &WgMeshConfigurationError{
msg: "Branch rate cannot be negative",
}
}
if c.InfectionCount <= 0 {
return &WgMeshConfigurationError{
msg: "Infection count cannot be less than 1",
}
}
if c.KeepAliveTime <= 0 {
return &WgMeshConfigurationError{
msg: "KeepAliveRate cannot be less than negative",
}
}
if c.InterClusterChance <= 0 {
return &WgMeshConfigurationError{
msg: "Intercluster chance cannot be less than 0",
}
}
if c.Timeout < 1 {
return &WgMeshConfigurationError{
msg: "Timeout should be greater than or equal to 1",
}
}
if c.PruneTime < 1 {
return &WgMeshConfigurationError{
msg: "Prune time cannot be < 1",
}
}
if c.DeadTime < 1 {
return &WgMeshConfigurationError{
msg: "Dead time cannot be < 1",
}
}
if c.KeepAliveTime <= 1 {
return &WgMeshConfigurationError{
msg: "Prune time cannot be less than keep alive time",
}
}
if c.Role == "" {
c.Role = PEER_ROLE
}
if c.IPDiscovery == "" {
c.IPDiscovery = PUBLIC_IP_DISCOVERY
}
return nil
return err
}
// ParseConfiguration parses the mesh configuration
func ParseConfiguration(filePath string) (*WgMeshConfiguration, error) {
var conf WgMeshConfiguration
// ValidateDaemonConfiguration: validates the dameon configuration that is used.
func ValidateDaemonConfiguration(c *DaemonConfiguration) error {
validate := validator.New(validator.WithRequiredStructEnabled())
err := validate.Struct(c)
return err
}
// ParseDaemonConfiguration parses the mesh configuration and validates the configuration
func ParseDaemonConfiguration(filePath string) (*DaemonConfiguration, error) {
var conf DaemonConfiguration
yamlBytes, err := os.ReadFile(filePath)
if err != nil {
logging.Log.WriteErrorf("Read file error: %s\n", err.Error())
return nil, err
}
err = yaml.Unmarshal(yamlBytes, &conf)
if err != nil {
logging.Log.WriteErrorf("Unmarshal error: %s\n", err.Error())
return nil, err
}
return &conf, ValidateConfiguration(&conf)
if conf.BaseConfiguration.KeepAliveWg == nil {
var keepAlive int = 0
conf.BaseConfiguration.KeepAliveWg = &keepAlive
}
return &conf, ValidateDaemonConfiguration(&conf)
}
// MergemeshConfiguration: merges the configuration in precedence where the last
// element in the list takes the most and the first takes the least
func MergeMeshConfiguration(cfgs ...WgConfiguration) (WgConfiguration, error) {
var result WgConfiguration
for _, cfg := range cfgs {
if cfg.AdvertiseDefaultRoute != nil {
result.AdvertiseDefaultRoute = cfg.AdvertiseDefaultRoute
}
if cfg.AdvertiseRoutes != nil {
result.AdvertiseRoutes = cfg.AdvertiseRoutes
}
if cfg.Endpoint != nil {
result.Endpoint = cfg.Endpoint
}
if cfg.IPDiscovery != nil {
result.IPDiscovery = cfg.IPDiscovery
}
if cfg.KeepAliveWg != nil {
result.KeepAliveWg = cfg.KeepAliveWg
}
if cfg.PostDown != nil {
result.PostDown = cfg.PostDown
}
if cfg.PostUp != nil {
result.PostUp = cfg.PostUp
}
if cfg.PreDown != nil {
result.PreDown = cfg.PreDown
}
if cfg.PreUp != nil {
result.PreUp = cfg.PreUp
}
if cfg.Role != nil {
result.Role = cfg.Role
}
}
return result, ValidateMeshConfiguration(&result)
}

View File

@ -1,24 +1,41 @@
package conf
import "testing"
import (
"testing"
)
func getExampleConfiguration() *WgMeshConfiguration {
return &WgMeshConfiguration{
CertificatePath: "./cert/cert.pem",
PrivateKeyPath: "./cert/key.pem",
CaCertificatePath: "./cert/ca.pems",
func getExampleConfiguration() *DaemonConfiguration {
discovery := PUBLIC_IP_DISCOVERY
advertiseRoutes := false
advertiseDefaultRoute := false
endpoint := "abc.com:123"
nodeType := CLIENT_ROLE
keepAliveWg := 0
return &DaemonConfiguration{
CertificatePath: "../../../cert/cert.pem",
PrivateKeyPath: "../../../cert/priv.pem",
CaCertificatePath: "../../../cert/cacert.pem",
SkipCertVerification: true,
GrpcPort: "8080",
AdvertiseRoutes: true,
Endpoint: "localhost",
ClusterSize: 1,
SyncRate: 1,
InterClusterChance: 0.1,
BranchRate: 2,
KeepAliveTime: 4,
InfectionCount: 1,
Timeout: 2,
PruneTime: 20,
GrpcPort: 25,
Timeout: 5,
Profile: false,
StubWg: false,
SyncTime: 2,
HeartBeat: 2,
ClusterSize: 64,
InterClusterChance: 0.15,
BranchRate: 3,
PullTime: 0,
InfectionCount: 2,
BaseConfiguration: WgConfiguration{
IPDiscovery: &discovery,
AdvertiseRoutes: &advertiseRoutes,
AdvertiseDefaultRoute: &advertiseDefaultRoute,
Endpoint: &endpoint,
Role: &nodeType,
KeepAliveWg: &keepAliveWg,
},
}
}
@ -26,7 +43,7 @@ func TestConfigurationCertificatePathEmpty(t *testing.T) {
conf := getExampleConfiguration()
conf.CertificatePath = ""
err := ValidateConfiguration(conf)
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
@ -37,7 +54,7 @@ func TestConfigurationPrivateKeyPathEmpty(t *testing.T) {
conf := getExampleConfiguration()
conf.PrivateKeyPath = ""
err := ValidateConfiguration(conf)
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
@ -48,7 +65,7 @@ func TestConfigurationCaCertificatePathEmpty(t *testing.T) {
conf := getExampleConfiguration()
conf.CaCertificatePath = ""
err := ValidateConfiguration(conf)
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
@ -57,9 +74,110 @@ func TestConfigurationCaCertificatePathEmpty(t *testing.T) {
func TestConfigurationGrpcPortEmpty(t *testing.T) {
conf := getExampleConfiguration()
conf.GrpcPort = ""
conf.GrpcPort = 0
err := ValidateConfiguration(conf)
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestIPDiscoveryNotSet(t *testing.T) {
conf := getExampleConfiguration()
ipDiscovery := IPDiscovery("djdsjdskd")
conf.BaseConfiguration.IPDiscovery = &ipDiscovery
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestAdvertiseRoutesNotSet(t *testing.T) {
conf := getExampleConfiguration()
conf.BaseConfiguration.AdvertiseRoutes = nil
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestAdvertiseDefaultRouteNotSet(t *testing.T) {
conf := getExampleConfiguration()
conf.BaseConfiguration.AdvertiseDefaultRoute = nil
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestKeepAliveWgNegative(t *testing.T) {
conf := getExampleConfiguration()
keepAliveWg := -1
conf.BaseConfiguration.KeepAliveWg = &keepAliveWg
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestRoleTypeNotValid(t *testing.T) {
conf := getExampleConfiguration()
role := NodeType("bruhhh")
conf.BaseConfiguration.Role = &role
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestRoleTypeNotSpecified(t *testing.T) {
conf := getExampleConfiguration()
conf.BaseConfiguration.Role = nil
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`invalid role type`)
}
}
func TestBranchRateZero(t *testing.T) {
conf := getExampleConfiguration()
conf.BranchRate = 0
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestsyncTimeZero(t *testing.T) {
conf := getExampleConfiguration()
conf.SyncTime = 0
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestKeepAliveTimeZero(t *testing.T) {
conf := getExampleConfiguration()
conf.HeartBeat = 0
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
@ -69,97 +187,51 @@ func TestConfigurationGrpcPortEmpty(t *testing.T) {
func TestClusterSizeZero(t *testing.T) {
conf := getExampleConfiguration()
conf.ClusterSize = 0
err := ValidateConfiguration(conf)
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func SyncRateZero(t *testing.T) {
func TestInterClusterChanceZero(t *testing.T) {
conf := getExampleConfiguration()
conf.SyncRate = 0
conf.InterClusterChance = 0
err := ValidateConfiguration(conf)
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func BranchRateZero(t *testing.T) {
conf := getExampleConfiguration()
conf.BranchRate = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func InfectionCountZero(t *testing.T) {
func TestInfectionCountOne(t *testing.T) {
conf := getExampleConfiguration()
conf.InfectionCount = 0
err := ValidateConfiguration(conf)
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func KeepAliveRateZero(t *testing.T) {
func TestPullTimeNegative(t *testing.T) {
conf := getExampleConfiguration()
conf.KeepAliveTime = 0
conf.PullTime = -1
err := ValidateConfiguration(conf)
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestValidCOnfiguration(t *testing.T) {
func TestValidConfiguration(t *testing.T) {
conf := getExampleConfiguration()
err := ValidateConfiguration(conf)
err := ValidateDaemonConfiguration(conf)
if err != nil {
t.Error(err)
}
}
func TestTimeout(t *testing.T) {
conf := getExampleConfiguration()
conf.Timeout = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestPruneTimeZero(t *testing.T) {
conf := getExampleConfiguration()
conf.PruneTime = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatalf(`Error should be thrown`)
}
}
func TestPruneTimeLessThanKeepAliveTime(t *testing.T) {
conf := getExampleConfiguration()
conf.PruneTime = 1
err := ValidateConfiguration(conf)
if err == nil {
t.Fatalf(`Error should be thrown`)
}
}

View File

@ -23,9 +23,10 @@ func binarySearch(global []string, selfId string, groupSize int) (int, int) {
lower := 0
higher := len(global) - 1
mid := (lower + higher) / 2
for (higher+1)-lower > groupSize {
mid := (lower + higher) / 2
if global[mid] < selfId {
lower = mid + 1
} else if global[mid] > selfId {
@ -33,8 +34,6 @@ func binarySearch(global []string, selfId string, groupSize int) (int, int) {
} else {
break
}
mid = (lower + higher) / 2
}
return lower, int(math.Min(float64(lower+groupSize), float64(len(global))))
@ -55,12 +54,14 @@ func (i *ConnClusterImpl) GetNeighbours(global []string, selfId string) []string
// you will communicate with a random node that is not in your cluster.
func (i *ConnClusterImpl) GetInterCluster(global []string, selfId string) string {
// Doesn't matter if not in it. Get index of where the node 'should' be
slices.Sort(global)
index, _ := binarySearch(global, selfId, 1)
numClusters := math.Ceil(float64(len(global)) / float64(i.clusterSize))
randomCluster := rand.Intn(int(numClusters)-1) + 1
randomCluster := rand.Intn(2) + 1
neighbourIndex := (index + randomCluster) % len(global)
// cluster is considered a heap
neighbourIndex := (2*index + (randomCluster * i.clusterSize)) % len(global)
return global[neighbourIndex]
}

View File

@ -6,7 +6,7 @@ import (
"crypto/tls"
"errors"
logging "github.com/tim-beatham/wgmesh/pkg/log"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)

View File

@ -7,7 +7,7 @@ import (
"os"
"sync"
logging "github.com/tim-beatham/wgmesh/pkg/log"
logging "github.com/tim-beatham/smegmesh/pkg/log"
)
// ConnectionManager defines an interface for maintaining peer connections
@ -19,9 +19,11 @@ type ConnectionManager interface {
// If the endpoint does not exist then add the connection. Returns an error
// if something went wrong
GetConnection(endPoint string) (PeerConnection, error)
// HasConnections returns true if a client has already registered at the givne
// HasConnections returns true if a peer has already registered at the given
// endpoint or false otherwise.
HasConnection(endPoint string) bool
// Removes a connection if it exists
RemoveConnection(endPoint string) error
// Goes through all the connections and closes eachone
Close() error
}
@ -32,7 +34,6 @@ type ConnectionManagerImpl struct {
// clientConnections maps an endpoint to a connection
conLoc sync.RWMutex
clientConnections map[string]PeerConnection
serverConfig *tls.Config
clientConfig *tls.Config
connFactory PeerConnectionFactory
}
@ -61,37 +62,25 @@ func NewConnectionManager(params *NewConnectionManagerParams) (ConnectionManager
return nil, err
}
serverAuth := tls.RequireAndVerifyClientCert
if params.SkipCertVerification {
serverAuth = tls.RequireAnyClientCert
}
certPool := x509.NewCertPool()
if !params.SkipCertVerification {
if params.CaCert == "" {
return nil, errors.New("CA Cert is not specified")
}
caCert, err := os.ReadFile(params.CaCert)
if err != nil {
return nil, err
}
certPool.AppendCertsFromPEM(caCert)
if params.CaCert == "" {
return nil, errors.New("CA Cert is not specified")
}
serverConfig := &tls.Config{
ClientAuth: serverAuth,
Certificates: []tls.Certificate{cert},
caCert, err := os.ReadFile(params.CaCert)
if err != nil {
return nil, err
}
if ok := certPool.AppendCertsFromPEM(caCert); !ok {
return nil, errors.New("could not parse PEM")
}
clientConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: params.SkipCertVerification,
Certificates: []tls.Certificate{cert},
RootCAs: certPool,
}
@ -99,7 +88,6 @@ func NewConnectionManager(params *NewConnectionManagerParams) (ConnectionManager
connMgr := ConnectionManagerImpl{
sync.RWMutex{},
connections,
serverConfig,
clientConfig,
params.ConnFactory,
}
@ -150,6 +138,15 @@ func (m *ConnectionManagerImpl) HasConnection(endPoint string) bool {
return exists
}
// RemoveConnection removes the given connection if it exists
func (m *ConnectionManagerImpl) RemoveConnection(endPoint string) error {
m.conLoc.Lock()
err := m.clientConnections[endPoint].Close()
delete(m.clientConnections, endPoint)
m.conLoc.Unlock()
return err
}
func (m *ConnectionManagerImpl) Close() error {
for _, conn := range m.clientConnections {
if err := conn.Close(); err != nil {

View File

@ -53,13 +53,13 @@ func TestNewConnectionManagerCACertDoesNotExistAndVerify(t *testing.T) {
func TestNewConnectionManagerCACertDoesNotExistAndNotVerify(t *testing.T) {
params := getConnectionManagerParams()
params.CaCert = ""
params.CaCert = "./cert/sdjsdjsdjk.pem"
params.SkipCertVerification = true
_, err := NewConnectionManager(params)
if err != nil {
t.Fatal(`an error should not be thrown`)
if err == nil {
t.Fatalf(`an error should be thrown`)
}
}

View File

@ -2,32 +2,34 @@ package conn
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"os"
"github.com/tim-beatham/wgmesh/pkg/conf"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/rpc"
"github.com/tim-beatham/smegmesh/pkg/conf"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/rpc"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
// ConnectionServer manages gRPC server peer connections
type ConnectionServer struct {
// tlsConfiguration of the server
serverConfig *tls.Config
// server an instance of the grpc server
server *grpc.Server // the authentication service to authenticate nodes
server *grpc.Server
// the ctrl service to manage node
ctrlProvider rpc.MeshCtrlServerServer
// the sync service to synchronise nodes
syncProvider rpc.SyncServiceServer
Conf *conf.WgMeshConfiguration
Conf *conf.DaemonConfiguration
listener net.Listener
}
// NewConnectionServerParams contains params for creating a new connection server
type NewConnectionServerParams struct {
Conf *conf.WgMeshConfiguration
Conf *conf.DaemonConfiguration
CtrlProvider rpc.MeshCtrlServerServer
SyncProvider rpc.SyncServiceServer
}
@ -47,9 +49,26 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
serverAuth = tls.RequireAnyClientCert
}
certPool := x509.NewCertPool()
if params.Conf.CaCertificatePath == "" {
return nil, errors.New("CA Cert is not specified")
}
caCert, err := os.ReadFile(params.Conf.CaCertificatePath)
if err != nil {
return nil, err
}
if ok := certPool.AppendCertsFromPEM(caCert); !ok {
return nil, errors.New("could not parse PEM")
}
serverConfig := &tls.Config{
ClientAuth: serverAuth,
Certificates: []tls.Certificate{cert},
ClientCAs: certPool,
}
server := grpc.NewServer(
@ -60,7 +79,6 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
syncProvider := params.SyncProvider
connServer := ConnectionServer{
serverConfig: serverConfig,
server: server,
ctrlProvider: ctrlProvider,
syncProvider: syncProvider,
@ -73,13 +91,12 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
// Listen for incoming requests. Returns an error if something went wrong.
func (s *ConnectionServer) Listen() error {
rpc.RegisterMeshCtrlServerServer(s.server, s.ctrlProvider)
rpc.RegisterSyncServiceServer(s.server, s.syncProvider)
lis, err := net.Listen("tcp", ":"+s.Conf.GrpcPort)
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", s.Conf.GrpcPort))
s.listener = lis
logging.Log.WriteInfof("GRPC listening on %s\n", s.Conf.GrpcPort)
logging.Log.WriteInfof("GRPC listening on %d\n", s.Conf.GrpcPort)
if err != nil {
logging.Log.WriteErrorf(err.Error())

View File

@ -16,6 +16,11 @@ func (s *ConnectionManagerStub) AddConnection(endPoint string) (PeerConnection,
return mock, nil
}
func (s *ConnectionManagerStub) RemoveConnection(endPoint string) error {
delete(s.Endpoints, endPoint)
return nil
}
func (s *ConnectionManagerStub) GetConnection(endPoint string) (PeerConnection, error) {
endpoint, ok := s.Endpoints[endPoint]

View File

@ -1,84 +0,0 @@
package conn
import (
"errors"
"slices"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
// ConnectionWindow maintains a sliding window of connections between users
type ConnectionWindow interface {
// GetWindow is a list of connections to choose from
GetWindow() []string
// SlideConnection removes a node from the window and adds a random node
// not already in the window. connList represents the list of possible
// connections to choose from
SlideConnection(connList []string) error
// PushConneciton is used when connection list less than window size.
PutConnection(conn []string) error
// IsFull returns true if the window is full. In which case we must slide the window
IsFull() bool
}
type ConnectionWindowImpl struct {
window []string
windowSize int
}
// GetWindow gets the current list of active connections in
// the window
func (c *ConnectionWindowImpl) GetWindow() []string {
return c.window
}
// SlideConnection slides the connection window by one shuffling items
// in the windows
func (c *ConnectionWindowImpl) SlideConnection(connList []string) error {
// If the number of peer connections is less than the length of the window
// then exit early. Can't slide the window it should contain all nodes!
if len(c.window) < c.windowSize {
return nil
}
filter := func(node string) bool {
return !slices.Contains(c.window, node)
}
pool := lib.Filter(connList, filter)
newNode := lib.RandomSubsetOfLength(pool, 1)
if len(newNode) == 0 {
return errors.New("could not slide window")
}
for i := len(c.window) - 1; i >= 1; i-- {
c.window[i] = c.window[i-1]
}
c.window[0] = newNode[0]
return nil
}
// PutConnection put random connections in the connection
func (c *ConnectionWindowImpl) PutConnection(connList []string) error {
if len(c.window) >= c.windowSize {
return errors.New("cannot place connection. Window full need to slide")
}
c.window = lib.RandomSubsetOfLength(connList, c.windowSize)
return nil
}
func (c *ConnectionWindowImpl) IsFull() bool {
return len(c.window) >= c.windowSize
}
func NewConnectionWindow(windowLength int) ConnectionWindow {
window := &ConnectionWindowImpl{
window: make([]string, 0),
windowSize: windowLength,
}
return window
}

View File

@ -5,13 +5,14 @@ import (
"encoding/gob"
"fmt"
"net"
"slices"
"strings"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@ -48,6 +49,13 @@ type MeshNode struct {
Description string
Services map[string]string
Type string
Tombstone bool
}
// Mark: marks the node is unreachable. This is not broadcast on
// syncrhonisation
func (m *TwoPhaseStoreMeshManager) Mark(nodeId string) {
m.store.Mark(nodeId)
}
// GetHostEndpoint: gets the gRPC endpoint of the node
@ -146,12 +154,13 @@ func (m *MeshSnapshot) GetNodes() map[string]mesh.MeshNode {
}
type TwoPhaseStoreMeshManager struct {
MeshId string
IfName string
Client *wgctrl.Client
LastClock uint64
conf *conf.WgMeshConfiguration
store *TwoPhaseMap[string, MeshNode]
MeshId string
IfName string
Client *wgctrl.Client
LastClock uint64
Conf *conf.WgConfiguration
DaemonConf *conf.DaemonConfiguration
store *TwoPhaseMap[string, MeshNode]
}
// AddNode() adds a node to the mesh
@ -171,8 +180,16 @@ func (m *TwoPhaseStoreMeshManager) AddNode(node mesh.MeshNode) {
// GetMesh() returns a snapshot of the mesh provided by the mesh provider.
func (m *TwoPhaseStoreMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
nodes := m.store.AsList()
snapshot := make(map[string]MeshNode)
for _, node := range nodes {
snapshot[node.PublicKey] = node
}
return &MeshSnapshot{
Nodes: m.store.AsMap(),
Nodes: snapshot,
}, nil
}
@ -187,7 +204,6 @@ func (m *TwoPhaseStoreMeshManager) Save() []byte {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
err := enc.Encode(*snapshot)
if err != nil {
@ -200,11 +216,11 @@ func (m *TwoPhaseStoreMeshManager) Save() []byte {
// Load() loads a mesh network
func (m *TwoPhaseStoreMeshManager) Load(bs []byte) error {
buf := bytes.NewBuffer(bs)
dec := gob.NewDecoder(buf)
var snapshot TwoPhaseMapSnapshot[string, MeshNode]
err := dec.Decode(&snapshot)
m.store.Merge(snapshot)
return err
}
@ -238,6 +254,31 @@ func (m *TwoPhaseStoreMeshManager) UpdateTimeStamp(nodeId string) error {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
// Sort nodes by their public key
peers := m.GetPeers()
slices.Sort(peers)
if len(peers) == 0 {
return nil
}
peerToUpdate := peers[0]
if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.DaemonConf.HeartBeat) {
m.store.Mark(peerToUpdate)
if len(peers) < 2 {
return nil
}
peerToUpdate = peers[1]
}
if peerToUpdate != nodeId {
return nil
}
// Refresh causing node to update it's time stamp
node := m.store.Get(nodeId)
node.Timestamp = time.Now().Unix()
m.store.Put(nodeId, node)
@ -256,19 +297,30 @@ func (m *TwoPhaseStoreMeshManager) AddRoutes(nodeId string, routes ...mesh.Route
node := m.store.Get(nodeId)
changes := false
for _, route := range routes {
node.Routes[route.GetDestination().String()] = Route{
Destination: route.GetDestination().String(),
Path: route.GetPath(),
prevRoute, ok := node.Routes[route.GetDestination().String()]
if !ok || route.GetHopCount() < prevRoute.GetHopCount() {
changes = true
node.Routes[route.GetDestination().String()] = Route{
Destination: route.GetDestination().String(),
Path: route.GetPath(),
}
}
}
m.store.Put(nodeId, node)
if changes {
m.store.Put(nodeId, node)
}
return nil
}
// DeleteRoutes: deletes the routes from the node
func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
@ -279,8 +331,15 @@ func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...string)
node := m.store.Get(nodeId)
changes := false
for _, route := range routes {
delete(node.Routes, route)
changes = true
delete(node.Routes, route.GetDestination().String())
}
if changes {
m.store.Put(nodeId, node)
}
return nil
@ -326,7 +385,7 @@ func (m *TwoPhaseStoreMeshManager) SetAlias(nodeId string, alias string) error {
}
node := m.store.Get(nodeId)
node.Description = alias
node.Alias = alias
m.store.Put(nodeId, node)
return nil
@ -351,26 +410,38 @@ func (m *TwoPhaseStoreMeshManager) RemoveService(nodeId string, key string) erro
}
node := m.store.Get(nodeId)
if _, ok := node.Services[key]; !ok {
return fmt.Errorf("datastore: node does not contain service %s", key)
}
delete(node.Services, key)
m.store.Put(nodeId, node)
return nil
}
// Prune: prunes all nodes that have not updated their timestamp in
// pruneAmount seconds
func (m *TwoPhaseStoreMeshManager) Prune(pruneAmount int) error {
func (m *TwoPhaseStoreMeshManager) Prune() error {
m.store.Prune()
return nil
}
// GetPeers: get a list of contactable peers
func (m *TwoPhaseStoreMeshManager) GetPeers() []string {
nodes := lib.MapValues(m.store.AsMap())
nodes := m.store.AsList()
nodes = lib.Filter(nodes, func(mn MeshNode) bool {
if mn.Type != string(conf.PEER_ROLE) {
return false
}
return time.Now().Unix()-mn.Timestamp < int64(m.conf.DeadTime)
// If the node is marked as unreachable don't consider it a peer.
// this help to optimize convergence time for unreachable nodes.
// However advertising it to other nodes could result in flapping.
if m.store.IsMarked(mn.PublicKey) {
return false
}
return true
})
return lib.Map(nodes, func(mn MeshNode) string {
@ -440,3 +511,8 @@ func (m *TwoPhaseStoreMeshManager) RemoveNode(nodeId string) error {
m.store.Remove(nodeId)
return nil
}
// GetConfiguration implements mesh.MeshProvider.
func (m *TwoPhaseStoreMeshManager) GetConfiguration() *conf.WgConfiguration {
return m.Conf
}

442
pkg/crdt/datastore_test.go Normal file
View File

@ -0,0 +1,442 @@
package crdt
import (
"net"
"slices"
"testing"
"time"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type TestParams struct {
manager mesh.MeshProvider
publicKey *wgtypes.Key
}
func setUpTests() *TestParams {
advertiseRoutes := false
advertiseDefaultRoute := false
role := conf.PEER_ROLE
discovery := conf.DNS_IP_DISCOVERY
factory := &TwoPhaseMapFactory{
Config: &conf.DaemonConfiguration{
CertificatePath: "/somecertificatepath",
PrivateKeyPath: "/someprivatekeypath",
CaCertificatePath: "/somecacertificatepath",
SkipCertVerification: true,
GrpcPort: 0,
Timeout: 20,
Profile: false,
SyncTime: 2,
HeartBeat: 10,
ClusterSize: 32,
InterClusterChance: 0.15,
BranchRate: 3,
InfectionCount: 3,
BaseConfiguration: conf.WgConfiguration{
IPDiscovery: &discovery,
AdvertiseRoutes: &advertiseRoutes,
AdvertiseDefaultRoute: &advertiseDefaultRoute,
Role: &role,
},
},
}
key, _ := wgtypes.GeneratePrivateKey()
mesh, _ := factory.CreateMesh(&mesh.MeshProviderFactoryParams{
DevName: "bob",
MeshId: "meshid123",
Client: nil,
Conf: &factory.Config.BaseConfiguration,
DaemonConf: factory.Config,
NodeID: "bob",
})
publicKey := key.PublicKey()
return &TestParams{
manager: mesh,
publicKey: &publicKey,
}
}
func getOurNode(testParams *TestParams) *MeshNode {
return &MeshNode{
HostEndpoint: "public-endpoint:8080",
WgEndpoint: "public-endpoint:21906",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d218/128",
PublicKey: testParams.publicKey.String(),
Timestamp: time.Now().Unix(),
Description: "A node that we are adding",
Type: "peer",
}
}
func getRandomNode() *MeshNode {
key, _ := wgtypes.GeneratePrivateKey()
publicKey := key.PublicKey()
return &MeshNode{
HostEndpoint: "public-endpoint:8081",
WgEndpoint: "public-endpoint:21907",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d234/128",
PublicKey: publicKey.String(),
Timestamp: time.Now().Unix(),
Description: "A node that we are adding",
Type: "peer",
}
}
func TestAddNodeAddsTheNodesToTheStore(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
if !testParams.manager.NodeExists(testParams.publicKey.String()) {
t.Fatalf(`node %s should have been added to the mesh network`, testParams.publicKey.String())
}
}
func TestAddNodeNodeAlreadyExistsReplacesTheNode(t *testing.T) {
TestAddNodeAddsTheNodesToTheStore(t)
TestAddNodeAddsTheNodesToTheStore(t)
}
func TestSaveThenLoad(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
bytes := testParams.manager.Save()
if err := testParams.manager.Load(bytes); err != nil {
t.Fatalf(`error caused by loading datastore: %s`, err.Error())
}
}
func TestHasChangesReturnsTrueWhenThereAreChangesInTheMesh(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
if !testParams.manager.HasChanges() {
t.Fatalf(`mesh has change but HasChanges returned false`)
}
testParams.manager.SetDescription(testParams.publicKey.String(), "Bob marley")
if !testParams.manager.HasChanges() {
t.Fatalf(`mesh has change but HasChanges returned false`)
}
testParams.manager.SaveChanges()
}
func TestHasChangesWhenThereAreNoChangesInTheMeshReturnsFalse(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.SaveChanges()
if testParams.manager.HasChanges() {
t.Fatalf(`mesh has no changes but HasChanges was true`)
}
testParams.manager.SetDescription(testParams.publicKey.String(), "Bob marley")
testParams.manager.SaveChanges()
if testParams.manager.HasChanges() {
t.Fatalf(`mesh has no changes but HasChanges was true`)
}
}
func TestUpdateTimeStampUpdatesTheTimeStampOfTheGivenNodeIfItIsTheLeader(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
before, _ := testParams.manager.GetNode(testParams.publicKey.String())
time.Sleep(1 * time.Second)
testParams.manager.UpdateTimeStamp(testParams.publicKey.String())
after, _ := testParams.manager.GetNode(testParams.publicKey.String())
if before.GetTimeStamp() >= after.GetTimeStamp() {
t.Fatalf(`before should not be after after`)
}
}
func TestUpdateTimeStampUpdatesTheTimeStampOfTheGivenNodeIfItIsNotLeader(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
newNode := getRandomNode()
newNode.PublicKey = "aaaaaaaaaa"
testParams.manager.AddNode(newNode)
before, _ := testParams.manager.GetNode(testParams.publicKey.String())
time.Sleep(1 * time.Second)
after, _ := testParams.manager.GetNode(testParams.publicKey.String())
if before.GetTimeStamp() != after.GetTimeStamp() {
t.Fatalf(`before and after should be the same`)
}
}
func TestAddRoutesAddsARouteToTheGivenMesh(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
_, destination, _ := net.ParseCIDR("0353:1da7:7f33:acc0:7a3f:6e55:912b:bc1f/64")
testParams.manager.AddRoutes(testParams.publicKey.String(), &mesh.RouteStub{
Destination: destination,
HopCount: 0,
Path: make([]string, 0),
})
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
containsDestination := lib.Contains(node.GetRoutes(), func(r mesh.Route) bool {
return r.GetDestination().Contains(destination.IP)
})
if !containsDestination {
t.Fatalf(`route has not been added to the node`)
}
}
func TestRemoveRoutesWithdrawsRoutesFromTheMesh(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
_, destination, _ := net.ParseCIDR("0353:1da7:7f33:acc0:7a3f:6e55:912b:bc1f/64")
route := &mesh.RouteStub{
Destination: destination,
HopCount: 0,
Path: make([]string, 0),
}
testParams.manager.AddRoutes(testParams.publicKey.String(), route)
testParams.manager.RemoveRoutes(testParams.publicKey.String(), route)
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
containsDestination := lib.Contains(node.GetRoutes(), func(r mesh.Route) bool {
return r.GetDestination().Contains(destination.IP)
})
if containsDestination {
t.Fatalf(`route has not been removed from the node`)
}
}
func TestGetNodeGetsTheNodeWhenItExists(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
if node == nil {
t.Fatalf(`node not found returned nil`)
}
}
func TestGetNodeReturnsNilWhenItDoesNotExist(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.RemoveNode(testParams.publicKey.String())
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
if node != nil {
t.Fatalf(`node found but should be nil`)
}
}
func TestNodeExistsReturnsFalseWhenNotExists(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.RemoveNode(testParams.publicKey.String())
if testParams.manager.NodeExists(testParams.publicKey.String()) {
t.Fatalf(`nodeexists should be false`)
}
}
func TestSetDescriptionReturnsErrorWhenNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.SetDescription("djdjdj", "djdsjkd")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestSetDescriptionSetsTheDescription(t *testing.T) {
testParams := setUpTests()
descriptionToSet := "djdsjkd"
testParams.manager.AddNode(getOurNode(testParams))
err := testParams.manager.SetDescription(testParams.publicKey.String(), descriptionToSet)
if err != nil {
t.Fatalf(`error %s thrown`, err.Error())
}
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
description := node.GetDescription()
if description != descriptionToSet {
t.Fatalf(`description was %s should be %s`, description, descriptionToSet)
}
}
func TestAliasNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.SetAlias("djdjdj", "djdsjkd")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestSetAliasSetsAlias(t *testing.T) {
testParams := setUpTests()
aliasToSet := "djdsjkd"
testParams.manager.AddNode(getOurNode(testParams))
err := testParams.manager.SetAlias(testParams.publicKey.String(), aliasToSet)
if err != nil {
t.Fatalf(`error %s thrown`, err.Error())
}
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
alias := node.GetAlias()
if alias != aliasToSet {
t.Fatalf(`description was %s should be %s`, alias, aliasToSet)
}
}
func TestAddServiceNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.AddService("djdjdj", "djdsjkd", "sddsds")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestAddServiceNodeExists(t *testing.T) {
testParams := setUpTests()
service := "djdsjkd"
serviceValue := "dsdsds"
testParams.manager.AddNode(getOurNode(testParams))
err := testParams.manager.AddService(testParams.publicKey.String(), service, serviceValue)
if err != nil {
t.Fatalf(`error %s thrown`, err.Error())
}
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
services := node.GetServices()
if value, ok := services[service]; !ok || value != serviceValue {
t.Fatalf(`service not added to the data store`)
}
}
func TestRemoveServiceDoesNotExists(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.RemoveService("djdjdj", "dsdssd")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestRemoveServiceServiceDoesNotExist(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
if err := testParams.manager.RemoveService(testParams.publicKey.String(), "dhsdh"); err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestGetPeersReturnsAllPeersInTheMesh(t *testing.T) {
testParams := setUpTests()
peer1 := getRandomNode()
peer2 := getRandomNode()
client := getRandomNode()
client.Type = "client"
testParams.manager.AddNode(peer1)
testParams.manager.AddNode(peer2)
testParams.manager.AddNode(client)
peers := testParams.manager.GetPeers()
slices.Sort(peers)
if len(peers) != 2 {
t.Fatalf(`there should be two peers in the mesh`)
}
peer1Pub, _ := peer1.GetPublicKey()
if !slices.Contains(peers, peer1Pub.String()) {
t.Fatalf(`peer1 not in the list`)
}
peer2Pub, _ := peer2.GetPublicKey()
if !slices.Contains(peers, peer2Pub.String()) {
t.Fatalf(`peer2 not in the list`)
}
}
func TestRemoveNodeReturnsErrorIfNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.RemoveNode("dsjdssjk")
if err == nil {
t.Fatalf(`error should have returned`)
}
}

View File

@ -2,46 +2,56 @@ package crdt
import (
"fmt"
"hash/fnv"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/mesh"
)
type TwoPhaseMapFactory struct{}
type TwoPhaseMapFactory struct {
Config *conf.DaemonConfiguration
}
func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) {
return &TwoPhaseStoreMeshManager{
MeshId: params.MeshId,
IfName: params.DevName,
Client: params.Client,
conf: params.Conf,
store: NewTwoPhaseMap[string, MeshNode](params.NodeID),
MeshId: params.MeshId,
IfName: params.DevName,
Client: params.Client,
Conf: params.Conf,
DaemonConf: params.DaemonConf,
store: NewTwoPhaseMap[string, MeshNode](params.NodeID, func(s string) uint64 {
h := fnv.New64a()
h.Write([]byte(s))
return h.Sum64()
}, uint64(3*f.Config.HeartBeat)),
}, nil
}
type MeshNodeFactory struct {
Config conf.WgMeshConfiguration
Config conf.DaemonConfiguration
}
func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode {
hostName := f.getAddress(params)
grpcEndpoint := fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort)
grpcEndpoint := fmt.Sprintf("%s:%d", hostName, f.Config.GrpcPort)
wgEndpoint := fmt.Sprintf("%s:%d", hostName, params.WgPort)
if f.Config.Role == conf.CLIENT_ROLE {
if *params.MeshConfig.Role == conf.CLIENT_ROLE {
grpcEndpoint = "-"
wgEndpoint = "-"
}
return &MeshNode{
HostEndpoint: grpcEndpoint,
PublicKey: params.PublicKey.String(),
WgEndpoint: fmt.Sprintf("%s:%d", hostName, params.WgPort),
WgEndpoint: wgEndpoint,
WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()),
Routes: make(map[string]Route),
Description: "",
Alias: "",
Type: string(f.Config.Role),
Type: string(*params.MeshConfig.Role),
}
}
@ -51,12 +61,12 @@ func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string
if params.Endpoint != "" {
hostName = params.Endpoint
} else if len(f.Config.Endpoint) != 0 {
hostName = f.Config.Endpoint
} else if params.MeshConfig.Endpoint != nil && len(*params.MeshConfig.Endpoint) != 0 {
hostName = *params.MeshConfig.Endpoint
} else {
ipFunc := lib.GetPublicIP
if f.Config.IPDiscovery == conf.DNS_IP_DISCOVERY {
if *params.MeshConfig.IPDiscovery == conf.DNS_IP_DISCOVERY {
ipFunc = lib.GetOutboundIP
}

View File

@ -1,28 +1,30 @@
// crdt is a golang implementation of a crdt
// crdt provides go implementations for crdts
package crdt
import (
"cmp"
"sync"
)
type Bucket[D any] struct {
Vector uint64
Contents D
Vector uint64
Contents D
Gravestone bool
}
// GMap is a set that can only grow in size
type GMap[K comparable, D any] struct {
type GMap[K cmp.Ordered, D any] struct {
lock sync.RWMutex
contents map[K]Bucket[D]
getClock func() uint64
contents map[uint64]Bucket[D]
clock *VectorClock[K]
}
func (g *GMap[K, D]) Put(key K, value D) {
g.lock.Lock()
clock := g.getClock() + 1
clock := g.clock.IncrementClock()
g.contents[key] = Bucket[D]{
g.contents[g.clock.hashFunc(key)] = Bucket[D]{
Vector: clock,
Contents: value,
}
@ -31,6 +33,10 @@ func (g *GMap[K, D]) Put(key K, value D) {
}
func (g *GMap[K, D]) Contains(key K) bool {
return g.contains(g.clock.hashFunc(key))
}
func (g *GMap[K, D]) contains(key uint64) bool {
g.lock.RLock()
_, ok := g.contents[key]
@ -40,7 +46,7 @@ func (g *GMap[K, D]) Contains(key K) bool {
return ok
}
func (g *GMap[K, D]) put(key K, b Bucket[D]) {
func (g *GMap[K, D]) put(key uint64, b Bucket[D]) {
g.lock.Lock()
if g.contents[key].Vector < b.Vector {
@ -50,7 +56,7 @@ func (g *GMap[K, D]) put(key K, b Bucket[D]) {
g.lock.Unlock()
}
func (g *GMap[K, D]) get(key K) Bucket[D] {
func (g *GMap[K, D]) get(key uint64) Bucket[D] {
g.lock.RLock()
bucket := g.contents[key]
g.lock.RUnlock()
@ -59,13 +65,46 @@ func (g *GMap[K, D]) get(key K) Bucket[D] {
}
func (g *GMap[K, D]) Get(key K) D {
return g.get(key).Contents
if !g.Contains(key) {
var def D
return def
}
return g.get(g.clock.hashFunc(key)).Contents
}
func (g *GMap[K, D]) Keys() []K {
func (g *GMap[K, D]) Mark(key K) {
if !g.Contains(key) {
return
}
g.lock.Lock()
bucket := g.contents[g.clock.hashFunc(key)]
bucket.Gravestone = true
g.contents[g.clock.hashFunc(key)] = bucket
g.lock.Unlock()
}
// IsMarked: returns true if the node is marked
func (g *GMap[K, D]) IsMarked(key K) bool {
marked := false
g.lock.RLock()
contents := make([]K, len(g.contents))
bucket, ok := g.contents[g.clock.hashFunc(key)]
if ok {
marked = bucket.Gravestone
}
g.lock.RUnlock()
return marked
}
func (g *GMap[K, D]) Keys() []uint64 {
g.lock.RLock()
contents := make([]uint64, len(g.contents))
index := 0
for key := range g.contents {
@ -77,8 +116,8 @@ func (g *GMap[K, D]) Keys() []K {
return contents
}
func (g *GMap[K, D]) Save() map[K]Bucket[D] {
buckets := make(map[K]Bucket[D])
func (g *GMap[K, D]) Save() map[uint64]Bucket[D] {
buckets := make(map[uint64]Bucket[D])
g.lock.RLock()
for key, value := range g.contents {
@ -89,8 +128,8 @@ func (g *GMap[K, D]) Save() map[K]Bucket[D] {
return buckets
}
func (g *GMap[K, D]) SaveWithKeys(keys []K) map[K]Bucket[D] {
buckets := make(map[K]Bucket[D])
func (g *GMap[K, D]) SaveWithKeys(keys []uint64) map[uint64]Bucket[D] {
buckets := make(map[uint64]Bucket[D])
g.lock.RLock()
for _, key := range keys {
@ -101,8 +140,8 @@ func (g *GMap[K, D]) SaveWithKeys(keys []K) map[K]Bucket[D] {
return buckets
}
func (g *GMap[K, D]) GetClock() map[K]uint64 {
clock := make(map[K]uint64)
func (g *GMap[K, D]) GetClock() map[uint64]uint64 {
clock := make(map[uint64]uint64)
g.lock.RLock()
for key, bucket := range g.contents {
@ -113,9 +152,33 @@ func (g *GMap[K, D]) GetClock() map[K]uint64 {
return clock
}
func NewGMap[K comparable, D any](getClock func() uint64) *GMap[K, D] {
func (g *GMap[K, D]) GetHash() uint64 {
hash := uint64(0)
g.lock.RLock()
for _, value := range g.contents {
hash += value.Vector
}
g.lock.RUnlock()
return hash
}
func (g *GMap[K, D]) Prune() {
stale := g.clock.getStale()
g.lock.Lock()
for _, outlier := range stale {
delete(g.contents, outlier)
}
g.lock.Unlock()
}
func NewGMap[K cmp.Ordered, D any](clock *VectorClock[K]) *GMap[K, D] {
return &GMap[K, D]{
contents: make(map[K]Bucket[D]),
getClock: getClock,
contents: make(map[uint64]Bucket[D]),
clock: clock,
}
}

224
pkg/crdt/g_map_test.go Normal file
View File

@ -0,0 +1,224 @@
// crdt_test unit tests the crdt implementations
package crdt
import (
"hash/fnv"
"slices"
"testing"
"time"
"github.com/tim-beatham/smegmesh/pkg/lib"
)
func NewGmap() *GMap[string, bool] {
vectorClock := NewVectorClock("a", func(key string) uint64 {
hash := fnv.New64a()
hash.Write([]byte(key))
return hash.Sum64()
}, 1) // 1 second stale time
gMap := NewGMap[string, bool](vectorClock)
return gMap
}
func TestGMapPutInsertsItems(t *testing.T) {
gMap := NewGmap()
gMap.Put("bruh1234", true)
if !gMap.Contains("bruh1234") {
t.Fatalf(`value not added to map`)
}
}
func TestGMapPutReplacesItems(t *testing.T) {
gMap := NewGmap()
gMap.Put("bruh1234", true)
gMap.Put("bruh1234", false)
value := gMap.Get("bruh1234")
if value {
t.Fatalf(`value should ahve been replaced to false`)
}
}
func TestContainsValueNotPresent(t *testing.T) {
gMap := NewGmap()
if gMap.Contains("sdhjsdhsdj") {
t.Fatalf(`value should not be present in the map`)
}
}
func TestContainsValuePresent(t *testing.T) {
gMap := NewGmap()
key := "hehehehe"
gMap.Put(key, false)
if !gMap.Contains(key) {
t.Fatalf(`%s should not be present in the map`, key)
}
}
func TestGMapGetNotPresentReturnsError(t *testing.T) {
gMap := NewGmap()
value := gMap.Get("bruh123")
if value != false {
t.Fatalf(`value should be default type false`)
}
}
func TestGMapGetReturnsValue(t *testing.T) {
gMap := NewGmap()
gMap.Put("bobdylan", true)
value := gMap.Get("bobdylan")
if !value {
t.Fatalf("value should be true but was false")
}
}
func TestMarkMarksTheValue(t *testing.T) {
gMap := NewGmap()
gMap.Put("hello123", true)
gMap.Mark("hello123")
if !gMap.IsMarked("hello123") {
t.Fatal(`hello123 should be marked`)
}
}
func TestMarkValueNotPresent(t *testing.T) {
gMap := NewGmap()
gMap.Mark("ok123456")
}
func TestKeysMapEmpty(t *testing.T) {
gMap := NewGmap()
keys := gMap.Keys()
if len(keys) != 0 {
t.Fatal(`list of keys was not empty but should be empty`)
}
}
func TestKeysMapReturnsKeysInMap(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
keys := gMap.Keys()
if len(keys) != 3 {
t.Fatal(`key length should be 3`)
}
}
func TestSaveMapEmptyReturnsEmptyMap(t *testing.T) {
gMap := NewGmap()
saveMap := gMap.Save()
if len(saveMap) != 0 {
t.Fatal(`saves should be empty`)
}
}
func TestSaveMapReturnsMapOfBuckets(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
saveMap := gMap.Save()
if len(saveMap) != 3 {
t.Fatalf(`save length should be 3`)
}
}
func TestSaveWithKeysNoKeysReturnsEmptyBucket(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
saveMap := gMap.SaveWithKeys([]uint64{})
if len(saveMap) != 0 {
t.Fatalf(`save map should be empty`)
}
}
func TestSaveWithKeysReturnsIntersection(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
clock := lib.MapKeys(gMap.GetClock())
clock = clock[:len(clock)-1]
values := gMap.SaveWithKeys(clock)
if len(values) != len(clock) {
t.Fatalf(`intersection not returned`)
}
}
func TestGetClockMapEmptyReturnsEmptyClock(t *testing.T) {
gMap := NewGmap()
clocks := gMap.GetClock()
if len(clocks) != 0 {
t.Fatalf(`vector clock is not empty`)
}
}
func TestGetClockReturnsAllCLocks(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
clocks := lib.MapValues(gMap.GetClock())
slices.Sort(clocks)
if !slices.Equal([]uint64{0, 1, 2}, clocks) {
t.Fatalf(`clocks are invalid`)
}
}
func TestGetHashChangesHashOnValueAdded(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
prevHash := gMap.GetHash()
gMap.Put("b", true)
if prevHash == gMap.GetHash() {
t.Fatalf(`hash should be different`)
}
}
func TestPruneGarbageCollectsValuesThatHaveNotBeenUpdated(t *testing.T) {
gMap := NewGmap()
gMap.clock.Put("c", 12)
gMap.Put("c", false)
gMap.Put("a", false)
time.Sleep(4 * time.Second)
gMap.Put("a", true)
gMap.Prune()
if gMap.Contains("c") {
t.Fatalf(`a should have been pruned`)
}
}

View File

@ -1,33 +1,37 @@
package crdt
import (
"sync"
"cmp"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/lib"
)
type TwoPhaseMap[K comparable, D any] struct {
type TwoPhaseMap[K cmp.Ordered, D any] struct {
addMap *GMap[K, D]
removeMap *GMap[K, bool]
vectors map[K]uint64
Clock *VectorClock[K]
processId K
lock sync.RWMutex
}
type TwoPhaseMapSnapshot[K comparable, D any] struct {
Add map[K]Bucket[D]
Remove map[K]Bucket[bool]
type TwoPhaseMapSnapshot[K cmp.Ordered, D any] struct {
Add map[uint64]Bucket[D]
Remove map[uint64]Bucket[bool]
}
// Contains checks whether the value exists in the map
func (m *TwoPhaseMap[K, D]) Contains(key K) bool {
if !m.addMap.Contains(key) {
return m.contains(m.Clock.hashFunc(key))
}
// Contains checks whether the value exists in the map
func (m *TwoPhaseMap[K, D]) contains(key uint64) bool {
if !m.addMap.contains(key) {
return false
}
addValue := m.addMap.get(key)
if !m.removeMap.Contains(key) {
if !m.removeMap.contains(key) {
return true
}
@ -46,32 +50,39 @@ func (m *TwoPhaseMap[K, D]) Get(key K) D {
return m.addMap.Get(key)
}
// Put places the key K in the map
func (m *TwoPhaseMap[K, D]) Put(key K, data D) {
msgSequence := m.incrementClock()
func (m *TwoPhaseMap[K, D]) get(key uint64) D {
var result D
m.lock.Lock()
if _, ok := m.vectors[key]; !ok {
m.vectors[key] = msgSequence
if !m.contains(key) {
return result
}
m.lock.Unlock()
return m.addMap.get(key).Contents
}
// Put places the key K in the map
func (m *TwoPhaseMap[K, D]) Put(key K, data D) {
msgSequence := m.Clock.IncrementClock()
m.Clock.Put(key, msgSequence)
m.addMap.Put(key, data)
}
func (m *TwoPhaseMap[K, D]) Mark(key K) {
m.addMap.Mark(key)
}
// Remove removes the value from the map
func (m *TwoPhaseMap[K, D]) Remove(key K) {
m.removeMap.Put(key, true)
}
func (m *TwoPhaseMap[K, D]) Keys() []K {
keys := make([]K, 0)
func (m *TwoPhaseMap[K, D]) keys() []uint64 {
keys := make([]uint64, 0)
addKeys := m.addMap.Keys()
for _, key := range addKeys {
if !m.Contains(key) {
if !m.contains(key) {
continue
}
@ -81,16 +92,16 @@ func (m *TwoPhaseMap[K, D]) Keys() []K {
return keys
}
func (m *TwoPhaseMap[K, D]) AsMap() map[K]D {
theMap := make(map[K]D)
func (m *TwoPhaseMap[K, D]) AsList() []D {
theList := make([]D, 0)
keys := m.Keys()
keys := m.keys()
for _, key := range keys {
theMap[key] = m.Get(key)
theList = append(theList, m.get(key))
}
return theMap
return theList
}
func (m *TwoPhaseMap[K, D]) Snapshot() *TwoPhaseMapSnapshot[K, D] {
@ -110,37 +121,21 @@ func (m *TwoPhaseMap[K, D]) SnapShotFromState(state *TwoPhaseMapState[K]) *TwoPh
}
}
type TwoPhaseMapState[K comparable] struct {
AddContents map[K]uint64
RemoveContents map[K]uint64
type TwoPhaseMapState[K cmp.Ordered] struct {
Vectors map[uint64]uint64
AddContents map[uint64]uint64
RemoveContents map[uint64]uint64
}
func (m *TwoPhaseMap[K, D]) incrementClock() uint64 {
maxClock := uint64(0)
m.lock.Lock()
for _, value := range m.vectors {
maxClock = max(maxClock, value)
}
m.vectors[m.processId] = maxClock + 1
m.lock.Unlock()
return maxClock
func (m *TwoPhaseMap[K, D]) IsMarked(key K) bool {
return m.addMap.IsMarked(key)
}
// GetHash: Get the hash of the current state of the map
// Sums the current values of the vectors. Provides good approximation
// of increasing numbers
func (m *TwoPhaseMap[K, D]) GetHash() uint64 {
m.lock.RLock()
sum := lib.Reduce(uint64(0), lib.MapValues(m.vectors), func(sum uint64, current uint64) uint64 {
return current + sum
})
m.lock.RUnlock()
return sum
return (m.addMap.GetHash() + 1) * (m.removeMap.GetHash() + 1)
}
// GetState: get the current vector clock of the add and remove
@ -150,29 +145,30 @@ func (m *TwoPhaseMap[K, D]) GenerateMessage() *TwoPhaseMapState[K] {
removeContents := m.removeMap.GetClock()
return &TwoPhaseMapState[K]{
Vectors: m.Clock.GetClock(),
AddContents: addContents,
RemoveContents: removeContents,
}
}
func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMapState[K] {
func (m *TwoPhaseMapState[K]) Difference(highestStale uint64, state *TwoPhaseMapState[K]) *TwoPhaseMapState[K] {
mapState := &TwoPhaseMapState[K]{
AddContents: make(map[K]uint64),
RemoveContents: make(map[K]uint64),
AddContents: make(map[uint64]uint64),
RemoveContents: make(map[uint64]uint64),
}
for key, value := range state.AddContents {
otherValue, ok := m.AddContents[key]
if !ok || otherValue < value {
if value > highestStale && (!ok || otherValue < value) {
mapState.AddContents[key] = value
}
}
for key, value := range state.AddContents {
for key, value := range state.RemoveContents {
otherValue, ok := m.RemoveContents[key]
if !ok || otherValue < value {
if value > highestStale && (!ok || otherValue < value) {
mapState.RemoveContents[key] = value
}
}
@ -181,31 +177,35 @@ func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMa
}
func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) {
m.lock.Lock()
for key, value := range snapshot.Add {
// Gravestone is local only to that node.
// Discover ourselves if the node is alive
m.addMap.put(key, value)
m.vectors[key] = max(value.Vector, m.vectors[key])
m.Clock.put(key, value.Vector)
}
for key, value := range snapshot.Remove {
m.removeMap.put(key, value)
m.vectors[key] = max(value.Vector, m.vectors[key])
m.Clock.put(key, value.Vector)
}
}
m.lock.Unlock()
func (m *TwoPhaseMap[K, D]) Prune() {
m.addMap.Prune()
m.removeMap.Prune()
m.Clock.Prune()
}
// NewTwoPhaseMap: create a new two phase map. Consists of two maps
// a grow map and a remove map. If both timestamps equal then favour keeping
// it in the map
func NewTwoPhaseMap[K comparable, D any](processId K) *TwoPhaseMap[K, D] {
func NewTwoPhaseMap[K cmp.Ordered, D any](processId K, hashKey func(K) uint64, staleTime uint64) *TwoPhaseMap[K, D] {
m := TwoPhaseMap[K, D]{
vectors: make(map[K]uint64),
processId: processId,
Clock: NewVectorClock(processId, hashKey, staleTime),
}
m.addMap = NewGMap[K, D](m.incrementClock)
m.removeMap = NewGMap[K, bool](m.incrementClock)
m.addMap = NewGMap[K, D](m.Clock)
m.removeMap = NewGMap[K, bool](m.Clock)
return &m
}

View File

@ -4,13 +4,14 @@ import (
"bytes"
"encoding/gob"
logging "github.com/tim-beatham/wgmesh/pkg/log"
logging "github.com/tim-beatham/smegmesh/pkg/log"
)
type SyncState int
const (
PREPARE SyncState = iota
HASH SyncState = iota
PREPARE
PRESENT
EXCHANGE
MERGE
@ -26,16 +27,61 @@ type TwoPhaseSyncer struct {
peerMsg []byte
}
type TwoPhaseHash struct {
Hash uint64
}
type SyncFSM map[SyncState]func(*TwoPhaseSyncer) ([]byte, bool)
func prepare(syncer *TwoPhaseSyncer) ([]byte, bool) {
func hash(syncer *TwoPhaseSyncer) ([]byte, bool) {
hash := TwoPhaseHash{
Hash: syncer.manager.store.Clock.GetHash(),
}
var buffer bytes.Buffer
enc := gob.NewEncoder(&buffer)
err := enc.Encode(*syncer.mapState)
err := enc.Encode(hash)
if err != nil {
logging.Log.WriteInfof(err.Error())
logging.Log.WriteErrorf(err.Error())
}
syncer.IncrementState()
return buffer.Bytes(), true
}
func prepare(syncer *TwoPhaseSyncer) ([]byte, bool) {
var recvBuffer = bytes.NewBuffer(syncer.peerMsg)
dec := gob.NewDecoder(recvBuffer)
var hash TwoPhaseHash
err := dec.Decode(&hash)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
// If vector clocks are equal then no need to merge state
// Helps to reduce bandwidth by detecting early
if hash.Hash == syncer.manager.store.Clock.GetHash() {
return nil, false
}
// Increment the clock here so the clock gets
// distributed to everyone else in the mesh
syncer.manager.store.Clock.IncrementClock()
var buffer bytes.Buffer
enc := gob.NewEncoder(&buffer)
mapState := syncer.manager.store.GenerateMessage()
syncer.mapState = mapState
err = enc.Encode(*syncer.mapState)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
syncer.IncrementState()
@ -54,10 +100,11 @@ func present(syncer *TwoPhaseSyncer) ([]byte, bool) {
err := dec.Decode(&mapState)
if err != nil {
logging.Log.WriteInfof(err.Error())
logging.Log.WriteErrorf(err.Error())
}
difference := syncer.mapState.Difference(&mapState)
difference := syncer.mapState.Difference(syncer.manager.store.Clock.GetStaleCount(), &mapState)
syncer.manager.store.Clock.Merge(mapState.Vectors)
var sendBuffer bytes.Buffer
enc := gob.NewEncoder(&sendBuffer)
@ -100,7 +147,6 @@ func merge(syncer *TwoPhaseSyncer) ([]byte, bool) {
dec.Decode(&snapshot)
syncer.manager.store.Merge(snapshot)
return nil, false
}
@ -129,6 +175,7 @@ func (t *TwoPhaseSyncer) Complete() {
func NewTwoPhaseSyncer(manager *TwoPhaseStoreMeshManager) *TwoPhaseSyncer {
var generateMessageFsm SyncFSM = SyncFSM{
HASH: hash,
PREPARE: prepare,
PRESENT: present,
EXCHANGE: exchange,
@ -137,8 +184,7 @@ func NewTwoPhaseSyncer(manager *TwoPhaseStoreMeshManager) *TwoPhaseSyncer {
return &TwoPhaseSyncer{
manager: manager,
state: PREPARE,
mapState: manager.store.GenerateMessage(),
state: HASH,
generateMessageFSM: generateMessageFsm,
}
}

View File

@ -0,0 +1,214 @@
package crdt
import (
"hash/fnv"
"slices"
"testing"
)
func NewMap(processId string) *TwoPhaseMap[string, string] {
theMap := NewTwoPhaseMap[string, string](processId, func(key string) uint64 {
hash := fnv.New64a()
hash.Write([]byte(key))
return hash.Sum64()
}, 1)
return theMap
}
func TestTwoPhaseMapEmpty(t *testing.T) {
theMap := NewMap("a")
if theMap.Contains("a") {
t.Fatalf(`a should not be present in the map`)
}
}
func TestTwoPhaseMapValuePresent(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "")
if !theMap.Contains("a") {
t.Fatalf(`should be present within the map`)
}
}
func TestTwoPhaseMapValueNotPresent(t *testing.T) {
theMap := NewMap("a")
theMap.Put("b", "")
if theMap.Contains("a") {
t.Fatalf(`a should not be present in the map`)
}
}
func TestTwoPhaseMapPutThenRemove(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "")
theMap.Remove("a")
if theMap.Contains("a") {
t.Fatalf(`a should not be present within the map`)
}
}
func TestTwoPhaseMapPutThenRemoveThenPut(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "")
theMap.Remove("a")
theMap.Put("a", "")
if !theMap.Contains("a") {
t.Fatalf(`a should be present within the map`)
}
}
func TestMarkMarksTheValueIn2PMap(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "")
theMap.Mark("a")
if !theMap.IsMarked("a") {
t.Fatalf(`a should be marked`)
}
}
func TestAsListReturnsItemsInList(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "bob")
theMap.Put("b", "dylan")
keys := theMap.AsList()
slices.Sort(keys)
if !slices.Equal([]string{"bob", "dylan"}, keys) {
t.Fatalf(`values should be bob, dylan`)
}
}
func TestSnapShotRemoveMapEmpty(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "bob")
theMap.Put("b", "dylan")
snapshot := theMap.Snapshot()
if len(snapshot.Add) != 2 {
t.Fatalf(`add values length should be 2`)
}
if len(snapshot.Remove) != 0 {
t.Fatalf(`remove map length should be 0`)
}
}
func TestSnapshotMapEmpty(t *testing.T) {
theMap := NewMap("a")
snapshot := theMap.Snapshot()
if len(snapshot.Add) != 0 || len(snapshot.Remove) != 0 {
t.Fatalf(`snapshot length should be 0`)
}
}
func TestSnapShotFromStateReturnsIntersection(t *testing.T) {
map1 := NewMap("a")
map1.Put("a", "heyy")
map2 := NewMap("b")
map2.Put("b", "hmmm")
message := map2.GenerateMessage()
snapShot := map1.SnapShotFromState(message)
if len(snapShot.Add) != 1 {
t.Fatalf(`add length should be 1`)
}
if len(snapShot.Remove) != 0 {
t.Fatalf(`remove length should be 0`)
}
}
func TestGetHashDifferentOnChange(t *testing.T) {
theMap := NewMap("a")
prevHash := theMap.GetHash()
theMap.Put("b", "hmmhmhmh")
if prevHash == theMap.GetHash() {
t.Fatalf(`hashes should not be the same`)
}
}
func TestGenerateMessageReturnsClocks(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "hmm")
theMap.Put("b", "hmm")
theMap.Remove("a")
message := theMap.GenerateMessage()
if len(message.AddContents) != 2 {
t.Fatalf(`two items added add should be 2`)
}
if len(message.RemoveContents) != 1 {
t.Fatalf(`a was removed remove map should be length 1`)
}
}
func TestDifferenceReturnsDifferenceOfMaps(t *testing.T) {
map1 := NewMap("a")
map1.Put("a", "ssms")
map1.Put("b", "sdmdsmd")
map2 := NewMap("b")
map2.Put("d", "eek")
map2.Put("c", "meh")
message1 := map1.GenerateMessage()
message2 := map2.GenerateMessage()
difference := message1.Difference(0, message2)
if len(difference.AddContents) != 2 {
t.Fatalf(`d and c are not in map1 they should be in add contents`)
}
if len(difference.RemoveContents) != 0 {
t.Fatalf(`remove should be empty`)
}
}
func TestMergeMergesValuesThatAreGreaterThanCurrentClock(t *testing.T) {
map1 := NewMap("a")
map1.Put("a", "ssms")
map1.Put("b", "sdmdsmd")
map2 := NewMap("b")
map2.Put("d", "eek")
map2.Put("c", "meh")
message1 := map1.GenerateMessage()
message2 := map2.GenerateMessage()
difference := message1.Difference(0, message2)
state := map2.SnapShotFromState(difference)
map1.Merge(*state)
if !map1.Contains("d") {
t.Fatalf(`d should be in the map`)
}
if !map2.Contains("c") {
t.Fatalf(`c should be in the map`)
}
}

168
pkg/crdt/vector_clock.go Normal file
View File

@ -0,0 +1,168 @@
package crdt
import (
"cmp"
"sync"
"time"
"github.com/tim-beatham/smegmesh/pkg/lib"
)
type VectorBucket struct {
// clock current value of the node's clock
clock uint64
// lastUpdate we've seen
lastUpdate uint64
}
// Vector clock defines an abstract data type
// for a vector clock implementation
type VectorClock[K cmp.Ordered] struct {
vectors map[uint64]*VectorBucket
lock sync.RWMutex
processID K
staleTime uint64
hashFunc func(K) uint64
// highest update that's been garbage collected
highestStale uint64
}
// IncrementClock: increments the node's value in the vector clock
func (m *VectorClock[K]) IncrementClock() uint64 {
maxClock := uint64(0)
m.lock.Lock()
for _, value := range m.vectors {
maxClock = max(maxClock, value.clock)
}
newBucket := VectorBucket{
clock: maxClock + 1,
lastUpdate: uint64(time.Now().Unix()),
}
m.vectors[m.hashFunc(m.processID)] = &newBucket
m.lock.Unlock()
return maxClock
}
// GetHash: gets the hash of the vector clock used to determine if there
// are any changes
func (m *VectorClock[K]) GetHash() uint64 {
m.lock.RLock()
hash := uint64(0)
for key, bucket := range m.vectors {
hash += key * (bucket.clock + 1)
}
m.lock.RUnlock()
return hash
}
func (m *VectorClock[K]) Merge(vectors map[uint64]uint64) {
for key, value := range vectors {
m.put(key, value)
}
}
// getStale: get all entries that are stale within the mesh
func (m *VectorClock[K]) getStale() []uint64 {
m.lock.RLock()
maxTimeStamp := lib.Reduce(0, lib.MapValues(m.vectors), func(i uint64, vb *VectorBucket) uint64 {
return max(i, vb.lastUpdate)
})
toRemove := make([]uint64, 0)
for key, bucket := range m.vectors {
if maxTimeStamp-bucket.lastUpdate > m.staleTime {
toRemove = append(toRemove, key)
m.highestStale = max(bucket.clock, m.highestStale)
}
}
m.lock.RUnlock()
return toRemove
}
// GetStaleCount: returns a vector clock which is considered to be stale.
// all updates must be greater than this
func (m *VectorClock[K]) GetStaleCount() uint64 {
m.lock.RLock()
staleCount := m.highestStale
m.lock.RUnlock()
return staleCount
}
func (m *VectorClock[K]) Prune() {
stale := m.getStale()
m.lock.Lock()
for _, key := range stale {
delete(m.vectors, key)
}
m.lock.Unlock()
}
func (m *VectorClock[K]) GetTimestamp(processId K) uint64 {
m.lock.RLock()
lastUpdate := m.vectors[m.hashFunc(m.processID)].lastUpdate
m.lock.RUnlock()
return lastUpdate
}
func (m *VectorClock[K]) Put(key K, value uint64) {
m.put(m.hashFunc(key), value)
}
func (m *VectorClock[K]) put(key uint64, value uint64) {
clockValue := uint64(0)
m.lock.Lock()
bucket, ok := m.vectors[key]
if ok {
clockValue = bucket.clock
}
// Make sure that entries that were garbage collected don't get
// addded back
if value > clockValue && value > m.highestStale {
newBucket := VectorBucket{
clock: value,
lastUpdate: uint64(time.Now().Unix()),
}
m.vectors[key] = &newBucket
}
m.lock.Unlock()
}
func (m *VectorClock[K]) GetClock() map[uint64]uint64 {
clock := make(map[uint64]uint64)
m.lock.RLock()
for key, value := range m.vectors {
clock[key] = value.clock
}
m.lock.RUnlock()
return clock
}
func NewVectorClock[K cmp.Ordered](processID K, hashFunc func(K) uint64, staleTime uint64) *VectorClock[K] {
return &VectorClock[K]{
vectors: make(map[uint64]*VectorBucket),
processID: processID,
staleTime: staleTime,
hashFunc: hashFunc,
}
}

View File

@ -1,22 +1,22 @@
package ctrlserver
import (
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/crdt"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query"
"github.com/tim-beatham/wgmesh/pkg/rpc"
"github.com/tim-beatham/wgmesh/pkg/wg"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/smegmesh/pkg/crdt"
"github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/query"
"github.com/tim-beatham/smegmesh/pkg/rpc"
"github.com/tim-beatham/smegmesh/pkg/wg"
"golang.zx2c4.com/wireguard/wgctrl"
)
// NewCtrlServerParams are the params requried to create a new ctrl server
type NewCtrlServerParams struct {
Conf *conf.WgMeshConfiguration
Conf *conf.DaemonConfiguration
Client *wgctrl.Client
CtrlProvider rpc.MeshCtrlServerServer
SyncProvider rpc.SyncServiceServer
@ -28,15 +28,17 @@ type NewCtrlServerParams struct {
// operation failed
func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
ctrlServer := new(MeshCtrlServer)
meshFactory := &crdt.TwoPhaseMapFactory{}
meshFactory := &crdt.TwoPhaseMapFactory{
Config: params.Conf,
}
nodeFactory := &crdt.MeshNodeFactory{
Config: *params.Conf,
}
idGenerator := &lib.IDNameGenerator{}
idGenerator := &lib.ShortIDGenerator{}
ipAllocator := &ip.ULABuilder{}
interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client)
configApplyer := mesh.NewWgMeshConfigApplyer(params.Conf)
configApplyer := mesh.NewWgMeshConfigApplyer()
meshManagerParams := &mesh.NewMeshManagerParams{
Conf: *params.Conf,
@ -87,7 +89,7 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
return ctrlServer, nil
}
func (s *MeshCtrlServer) GetConfiguration() *conf.WgMeshConfiguration {
func (s *MeshCtrlServer) GetConfiguration() *conf.DaemonConfiguration {
return s.Conf
}

View File

@ -1,10 +1,14 @@
package ctrlserver
import (
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query"
"net"
"time"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/query"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@ -14,6 +18,14 @@ type MeshRoute struct {
Path []string
}
// Represents the WireGuard configuration attached to the node
type WireGuardStats struct {
AllowedIPs []string
TransmitBytes int64
ReceivedBytes int64
PersistentKeepAliveInterval time.Duration
}
// Represents a WireGuard MeshNode
type MeshNode struct {
HostEndpoint string
@ -25,6 +37,7 @@ type MeshNode struct {
Description string
Alias string
Services map[string]string
Stats WireGuardStats
}
// Represents a WireGuard Mesh
@ -34,7 +47,7 @@ type Mesh struct {
}
type CtrlServer interface {
GetConfiguration() *conf.WgMeshConfiguration
GetConfiguration() *conf.DaemonConfiguration
GetClient() *wgctrl.Client
GetQuerier() query.Querier
GetMeshManager() mesh.MeshManager
@ -48,6 +61,56 @@ type MeshCtrlServer struct {
MeshManager mesh.MeshManager
ConnectionManager conn.ConnectionManager
ConnectionServer *conn.ConnectionServer
Conf *conf.WgMeshConfiguration
Conf *conf.DaemonConfiguration
Querier query.Querier
}
// NewCtrlNode create an instance of a ctrl node to send over an
// IPC call
func NewCtrlNode(provider mesh.MeshProvider, node mesh.MeshNode) *MeshNode {
pubKey, _ := node.GetPublicKey()
ctrlNode := MeshNode{
HostEndpoint: node.GetHostEndpoint(),
WgEndpoint: node.GetWgEndpoint(),
PublicKey: pubKey.String(),
WgHost: node.GetWgHost().String(),
Timestamp: node.GetTimeStamp(),
Routes: lib.Map(node.GetRoutes(), func(r mesh.Route) MeshRoute {
return MeshRoute{
Destination: r.GetDestination().String(),
Path: r.GetPath(),
}
}),
Description: node.GetDescription(),
Alias: node.GetAlias(),
Services: node.GetServices(),
}
device, err := provider.GetDevice()
if err != nil {
return &ctrlNode
}
peers := lib.Filter(device.Peers, func(p wgtypes.Peer) bool {
return p.PublicKey.String() == pubKey.String()
})
if len(peers) > 0 {
peer := peers[0]
stats := WireGuardStats{
AllowedIPs: lib.Map(peer.AllowedIPs, func(i net.IPNet) string {
return i.String()
}),
TransmitBytes: peer.TransmitBytes,
ReceivedBytes: peer.ReceiveBytes,
PersistentKeepAliveInterval: peer.PersistentKeepaliveInterval,
}
ctrlNode.Stats = stats
}
return &ctrlNode
}

View File

@ -1,10 +1,10 @@
package ctrlserver
import (
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/query"
"golang.zx2c4.com/wireguard/wgctrl"
)
@ -23,10 +23,10 @@ func NewCtrlServerStub() *CtrlServerStub {
}
}
func (c *CtrlServerStub) GetConfiguration() *conf.WgMeshConfiguration {
return &conf.WgMeshConfiguration{
GrpcPort: "8080",
Endpoint: "abc.com",
func (c *CtrlServerStub) GetConfiguration() *conf.DaemonConfiguration {
return &conf.DaemonConfiguration{
GrpcPort: 8080,
BaseConfiguration: conf.WgConfiguration{},
}
}

View File

@ -4,21 +4,18 @@ import (
"encoding/json"
"fmt"
"net"
"net/rpc"
"github.com/miekg/dns"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/query"
"github.com/tim-beatham/smegmesh/pkg/ipc"
"github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/query"
)
const SockAddr = "/tmp/wgmesh_ipc.sock"
const MeshRegularExpression = `(?P<meshId>.+)\.(?P<alias>.+)\.smeg\.`
type DNSHandler struct {
client *rpc.Client
client *ipc.SmegmeshIpc
server *dns.Server
}
@ -27,7 +24,7 @@ type DNSHandler struct {
func (d *DNSHandler) queryMesh(meshId, alias string) net.IP {
var reply string
err := d.client.Call("IpcHandler.Query", &ipc.QueryMesh{
err := d.client.Query(ipc.QueryMesh{
MeshId: meshId,
Query: fmt.Sprintf("[?alias == '%s'] | [0]", alias),
}, &reply)
@ -97,7 +94,7 @@ func (h *DNSHandler) Close() error {
}
func NewDns(udpPort int) (*DNSHandler, error) {
client, err := rpc.DialHTTP("unix", SockAddr)
client, err := ipc.NewClientIpc()
if err != nil {
return nil, err

227
pkg/dot/dot.go Normal file
View File

@ -0,0 +1,227 @@
// Graph allows the definition of a DOT graph in golang
package graph
import (
"fmt"
"hash/fnv"
"strings"
"github.com/tim-beatham/smegmesh/pkg/lib"
)
type GraphType string
type Shape string
const (
GRAPH GraphType = "graph"
DIGRAPH GraphType = "digraph"
)
const (
CIRCLE Shape = "circle"
STAR Shape = "star"
HEXAGON Shape = "hexagon"
PARALLELOGRAM Shape = "parallelogram"
)
type Graph interface {
Dottable
GetType() GraphType
}
type Cluster struct {
Type GraphType
Name string
Label string
nodes map[string]*Node
edges map[string]Edge
}
type RootGraph struct {
Type GraphType
Label string
nodes map[string]*Node
clusters map[string]*Cluster
edges map[string]Edge
}
type Node struct {
Name string
Label string
Shape Shape
Size int
}
type Edge interface {
Dottable
}
type DirectedEdge struct {
Name string
Label string
From string
To string
}
type UndirectedEdge struct {
Name string
Label string
From string
To string
}
// Dottable means an implementer can convert the struct to DOT representation
type Dottable interface {
GetDOT() (string, error)
}
func NewGraph(label string, graphType GraphType) *RootGraph {
return &RootGraph{Type: graphType, Label: label, clusters: map[string]*Cluster{}, nodes: make(map[string]*Node), edges: make(map[string]Edge)}
}
// PutNode: puts a node in the graph
func (g *RootGraph) PutNode(name, label string, size int, shape Shape) error {
_, exists := g.nodes[name]
if exists {
// If exists no need to add the ndoe
return nil
}
g.nodes[name] = &Node{Name: name, Label: label, Size: size, Shape: shape}
return nil
}
func (g *RootGraph) PutCluster(graph *Cluster) {
g.clusters[graph.Label] = graph
}
func writeContituents[D Dottable](result *strings.Builder, elements ...D) error {
for _, node := range elements {
dot, err := node.GetDOT()
if err != nil {
return err
}
_, err = result.WriteString(dot)
if err != nil {
return err
}
}
return nil
}
func (g *RootGraph) GetDOT() (string, error) {
var result strings.Builder
result.WriteString(fmt.Sprintf("%s {\n", g.Type))
result.WriteString("node [colorscheme=set312];\n")
result.WriteString("layout = fdp;\n")
nodes := lib.MapValues(g.nodes)
edges := lib.MapValues(g.edges)
writeContituents(&result, nodes...)
writeContituents(&result, edges...)
for _, cluster := range g.clusters {
clusterDOT, err := cluster.GetDOT()
if err != nil {
return "", err
}
result.WriteString(clusterDOT)
}
result.WriteString("}")
return result.String(), nil
}
// GetType implements Graph.
func (r *RootGraph) GetType() GraphType {
return r.Type
}
func constructEdge(graph Graph, name, label, from, to string) Edge {
switch graph.GetType() {
case DIGRAPH:
return &DirectedEdge{Name: name, Label: label, From: from, To: to}
default:
return &UndirectedEdge{Name: name, Label: label, From: from, To: to}
}
}
// AddEdge: adds an edge between two nodes in the graph
func (g *RootGraph) AddEdge(name string, label string, from string, to string) error {
g.edges[name] = constructEdge(g, name, label, from, to)
return nil
}
const numColours = 12
func (n *Node) hash() int {
h := fnv.New32a()
h.Write([]byte(n.Name))
return (int(h.Sum32()) % numColours) + 1
}
func (n *Node) GetDOT() (string, error) {
return fmt.Sprintf("node[label=\"%s\",shape=%s, style=\"filled\", fillcolor=%d, width=%d, height=%d, fixedsize=true] \"%s\";\n",
n.Label, n.Shape, n.hash(), n.Size, n.Size, n.Name), nil
}
func (e *DirectedEdge) GetDOT() (string, error) {
return fmt.Sprintf("\"%s\" -> \"%s\" [label=\"%s\"];\n", e.From, e.To, e.Label), nil
}
func (e *UndirectedEdge) GetDOT() (string, error) {
return fmt.Sprintf("\"%s\" -- \"%s\" [label=\"%s\"];\n", e.From, e.To, e.Label), nil
}
// AddEdge: adds an edge between two nodes in the graph
func (g *Cluster) AddEdge(name string, label string, from string, to string) error {
g.edges[name] = constructEdge(g, name, label, from, to)
return nil
}
// PutNode: puts a node in the graph
func (g *Cluster) PutNode(name, label string, size int, shape Shape) error {
_, exists := g.nodes[name]
if exists {
// If exists no need to add the ndoe
return nil
}
g.nodes[name] = &Node{Name: name, Label: label, Shape: shape, Size: size}
return nil
}
func (g *Cluster) GetDOT() (string, error) {
var builder strings.Builder
builder.WriteString(fmt.Sprintf("subgraph \"cluster%s\" {\n", g.Label))
builder.WriteString(fmt.Sprintf("label = \"%s\"\n", g.Label))
nodes := lib.MapValues(g.nodes)
edges := lib.MapValues(g.edges)
writeContituents(&builder, nodes...)
writeContituents(&builder, edges...)
builder.WriteString("}\n")
return builder.String(), nil
}
func (g *Cluster) GetType() GraphType {
return g.Type
}
func NewSubGraph(name string, label string, graphType GraphType) *Cluster {
return &Cluster{
Label: name,
Type: graphType,
Name: name,
nodes: make(map[string]*Node),
edges: make(map[string]Edge),
}
}

116
pkg/dot/wg.go Normal file
View File

@ -0,0 +1,116 @@
package graph
import (
"fmt"
"slices"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
)
// MeshGraphConverter converts a mesh to a graph
type MeshGraphConverter interface {
// convert the mesh to textual form
Generate() (string, error)
}
type MeshDOTConverter struct {
meshes map[string][]ctrlserver.MeshNode
destinations map[string]interface{}
}
func (c *MeshDOTConverter) Generate() (string, error) {
g := NewGraph("Smegmesh", GRAPH)
for meshId := range c.meshes {
err := c.generateMesh(g, meshId)
if err != nil {
return "", err
}
}
for mesh := range c.meshes {
g.PutNode(mesh, mesh, 1, CIRCLE)
}
for destination := range c.destinations {
g.PutNode(destination, destination, 1, HEXAGON)
}
return g.GetDOT()
}
func (c *MeshDOTConverter) generateMesh(g *RootGraph, meshId string) error {
nodes := c.meshes[meshId]
g.PutNode(meshId, meshId, 1, CIRCLE)
for _, node := range nodes {
c.graphNode(g, node, meshId)
}
for _, node := range nodes {
g.AddEdge(fmt.Sprintf("%s to %s", node.PublicKey, meshId), "", node.PublicKey, meshId)
}
return nil
}
// graphNode: graphs a node within the mesh
func (c *MeshDOTConverter) graphNode(g *RootGraph, node ctrlserver.MeshNode, meshId string) {
alias := node.Alias
if alias == "" {
alias = node.WgHost[1:len(node.WgHost)-20] + "\\n" + node.WgHost[len(node.WgHost)-20:len(node.WgHost)]
}
g.PutNode(node.PublicKey, alias, 2, CIRCLE)
for _, route := range node.Routes {
if len(route.Path) == 0 {
g.AddEdge(route.Destination, "", node.PublicKey, route.Destination)
continue
}
reversedPath := slices.Clone(route.Path)
slices.Reverse(reversedPath)
g.AddEdge(fmt.Sprintf("%s to %s", node.PublicKey, reversedPath[0]), "", node.PublicKey, reversedPath[0])
for _, mesh := range route.Path {
if _, ok := c.meshes[mesh]; !ok {
c.destinations[mesh] = struct{}{}
}
}
for index := range reversedPath[0 : len(reversedPath)-1] {
routeID := fmt.Sprintf("%s to %s", reversedPath[index], reversedPath[index+1])
g.AddEdge(routeID, "", reversedPath[index], reversedPath[index+1])
}
if route.Destination == "::/0" {
c.destinations[route.Destination] = struct{}{}
lastMesh := reversedPath[len(reversedPath)-1]
routeID := fmt.Sprintf("%s to %s", lastMesh, route.Destination)
g.AddEdge(routeID, "", lastMesh, route.Destination)
}
}
for service := range node.Services {
c.putService(g, service, meshId, node)
}
}
// putService: construct a service node and a link between the nodes
func (c *MeshDOTConverter) putService(g *RootGraph, key, meshId string, node ctrlserver.MeshNode) {
serviceID := fmt.Sprintf("%s%s%s", key, node.PublicKey, meshId)
g.PutNode(serviceID, key, 1, PARALLELOGRAM)
g.AddEdge(fmt.Sprintf("%s to %s", node.PublicKey, serviceID), "", node.PublicKey, serviceID)
}
func NewMeshGraphConverter(meshes map[string][]ctrlserver.MeshNode) MeshGraphConverter {
return &MeshDOTConverter{
meshes: meshes,
destinations: make(map[string]interface{}),
}
}

View File

@ -1,178 +0,0 @@
// Graph allows the definition of a DOT graph in golang
package graph
import (
"errors"
"fmt"
"hash/fnv"
"strings"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
type GraphType string
type Shape string
const (
GRAPH GraphType = "graph"
DIGRAPH = "digraph"
)
const (
CIRCLE Shape = "circle"
STAR Shape = "star"
HEXAGON Shape = "hexagon"
)
type Graph struct {
Type GraphType
Label string
nodes map[string]*Node
edges []Edge
}
type Node struct {
Name string
Shape Shape
}
type Edge interface {
Dottable
}
type DirectedEdge struct {
Label string
From *Node
To *Node
}
type UndirectedEdge struct {
Label string
From *Node
To *Node
}
// Dottable means an implementer can convert the struct to DOT representation
type Dottable interface {
GetDOT() (string, error)
}
func NewGraph(label string, graphType GraphType) *Graph {
return &Graph{Type: graphType, Label: label, nodes: make(map[string]*Node), edges: make([]Edge, 0)}
}
// PutNode: puts a node in the graph
func (g *Graph) PutNode(label string, shape Shape) error {
_, exists := g.nodes[label]
if exists {
// If exists no need to add the ndoe
return nil
}
g.nodes[label] = &Node{Name: label, Shape: shape}
return nil
}
func writeContituents[D Dottable](result *strings.Builder, elements ...D) error {
for _, node := range elements {
dot, err := node.GetDOT()
if err != nil {
return err
}
_, err = result.WriteString(dot)
if err != nil {
return err
}
}
return nil
}
func (g *Graph) GetDOT() (string, error) {
var result strings.Builder
_, err := result.WriteString(fmt.Sprintf("%s {\n", g.Type))
if err != nil {
return "", err
}
_, err = result.WriteString("node [colorscheme=set312];\n")
if err != nil {
return "", err
}
nodes := lib.MapValues(g.nodes)
err = writeContituents(&result, nodes...)
if err != nil {
return "", err
}
err = writeContituents(&result, g.edges...)
if err != nil {
return "", err
}
_, err = result.WriteString("}")
if err != nil {
return "", err
}
return result.String(), nil
}
func (g *Graph) constructEdge(label string, from *Node, to *Node) Edge {
switch g.Type {
case DIGRAPH:
return &DirectedEdge{Label: label, From: from, To: to}
default:
return &UndirectedEdge{Label: label, From: from, To: to}
}
}
// AddEdge: adds an edge between two nodes in the graph
func (g *Graph) AddEdge(label string, from string, to string) error {
fromNode, exists := g.nodes[from]
if !exists {
return errors.New(fmt.Sprintf("Node %s does not exist", from))
}
toNode, exists := g.nodes[to]
if !exists {
return errors.New(fmt.Sprintf("Node %s does not exist", to))
}
g.edges = append(g.edges, g.constructEdge(label, fromNode, toNode))
return nil
}
const numColours = 12
func (n *Node) hash() int {
h := fnv.New32a()
h.Write([]byte(n.Name))
return (int(h.Sum32()) % numColours) + 1
}
func (n *Node) GetDOT() (string, error) {
return fmt.Sprintf("node[shape=%s, style=\"filled\", fillcolor=%d] %s;\n",
n.Shape, n.hash(), n.Name), nil
}
func (e *DirectedEdge) GetDOT() (string, error) {
return fmt.Sprintf("%s -> %s;\n", e.From.Name, e.To.Name), nil
}
func (e *UndirectedEdge) GetDOT() (string, error) {
return fmt.Sprintf("%s -- %s;\n", e.From.Name, e.To.Name), nil
}

View File

@ -1,132 +0,0 @@
// hosts: utility for modifying the /etc/hosts file
package hosts
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"os"
"strings"
)
// HOSTS_FILE is the hosts file location
const HOSTS_FILE = "/etc/hosts"
const DOMAIN_HEADER = "#WG AUTO GENERATED HOSTS"
const DOMAIN_TRAILER = "#WG AUTO GENERATED HOSTS END"
type HostsEntry struct {
Alias string
Ip net.IP
}
// Generic interface to manipulate /etc/hosts file
type HostsManipulator interface {
// AddrAddr associates an aliasd with a given IP address
AddAddr(hosts ...HostsEntry)
// Remove deletes the entry from /etc/hosts
Remove(hosts ...HostsEntry)
// Writes the changes to /etc/hosts file
Write() error
}
type HostsManipulatorImpl struct {
hosts map[string]HostsEntry
}
// AddAddr implements HostsManipulator.
func (m *HostsManipulatorImpl) AddAddr(hosts ...HostsEntry) {
changed := false
for _, host := range hosts {
prev, ok := m.hosts[host.Ip.String()]
if !ok || prev.Alias != host.Alias {
changed = true
}
m.hosts[host.Ip.String()] = host
}
if changed {
m.Write()
}
}
// Remove implements HostsManipulator.
func (m *HostsManipulatorImpl) Remove(hosts ...HostsEntry) {
lenBefore := len(m.hosts)
for _, host := range hosts {
delete(m.hosts, host.Alias)
}
if lenBefore != len(m.hosts) {
m.Write()
}
}
func (m *HostsManipulatorImpl) removeHosts() string {
hostsFile, err := os.ReadFile(HOSTS_FILE)
if err != nil {
return ""
}
var contents strings.Builder
scanner := bufio.NewScanner(bytes.NewReader(hostsFile))
hostsSection := false
for scanner.Scan() {
line := scanner.Text()
if err == io.EOF {
break
} else if err != nil {
return ""
}
if !hostsSection && strings.Contains(line, DOMAIN_HEADER) {
hostsSection = true
}
if !hostsSection {
contents.WriteString(line + "\n")
}
if hostsSection && strings.Contains(line, DOMAIN_TRAILER) {
hostsSection = false
}
}
if scanner.Err() != nil && scanner.Err() != io.EOF {
return ""
}
return contents.String()
}
// Write implements HostsManipulator
func (m *HostsManipulatorImpl) Write() error {
contents := m.removeHosts()
var nextHosts strings.Builder
nextHosts.WriteString(contents)
nextHosts.WriteString(DOMAIN_HEADER + "\n")
for _, host := range m.hosts {
nextHosts.WriteString(fmt.Sprintf("%s\t%s\n", host.Ip.String(), host.Alias))
}
nextHosts.WriteString(DOMAIN_TRAILER + "\n")
return os.WriteFile(HOSTS_FILE, []byte(nextHosts.String()), 0644)
}
func NewHostsManipulator() HostsManipulator {
return &HostsManipulatorImpl{hosts: make(map[string]HostsEntry)}
}

View File

@ -34,7 +34,7 @@ type CgaParameters struct {
flag byte
}
func NewCga(key wgtypes.Key, subnetPrefix [2 * InterfaceIdLen]byte) (*CgaParameters, error) {
func NewCga(key wgtypes.Key, collisionCount uint8, subnetPrefix [2 * InterfaceIdLen]byte) (*CgaParameters, error) {
var params CgaParameters
_, err := rand.Read(params.Modifier[:])
@ -45,6 +45,7 @@ func NewCga(key wgtypes.Key, subnetPrefix [2 * InterfaceIdLen]byte) (*CgaParamet
params.PublicKey = key
params.SubnetPrefix = subnetPrefix
params.CollisionCount = collisionCount
return &params, nil
}
@ -78,7 +79,6 @@ func (c *CgaParameters) generateHash1() []byte {
byteVal[hash1Length-1] = c.CollisionCount
hash := sha1.Sum(byteVal[:])
return hash[:Hash1Prefix]
}
@ -90,9 +90,6 @@ func clearBit(num, pos int) byte {
}
func (c *CgaParameters) generateInterface() []byte {
// TODO: On duplicate address detection increment collision.
// Also incorporate SEC
hash1 := c.generateHash1()
var interfaceId []byte = make([]byte, InterfaceIdLen)

View File

@ -7,5 +7,5 @@ import (
)
type IPAllocator interface {
GetIP(key wgtypes.Key, meshId string) (net.IP, error)
GetIP(key wgtypes.Key, meshId string, collisionCount uint8) (net.IP, error)
}

View File

@ -39,10 +39,10 @@ func (u *ULABuilder) GetIPNet(meshId string) (*net.IPNet, error) {
return net, nil
}
func (u *ULABuilder) GetIP(key wgtypes.Key, meshId string) (net.IP, error) {
func (u *ULABuilder) GetIP(key wgtypes.Key, meshId string, collisionCount uint8) (net.IP, error) {
ulaPrefix := getMeshPrefix(meshId)
c, err := NewCga(key, ulaPrefix)
c, err := NewCga(key, collisionCount, ulaPrefix)
if err != nil {
return nil, err

View File

@ -5,37 +5,80 @@ import (
"net"
"net/http"
"net/rpc"
ipcRpc "net/rpc"
"os"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
)
type NewMeshArgs struct {
const SockAddr = "/tmp/wgmesh_sock"
type MeshIpc interface {
CreateMesh(args *NewMeshArgs, reply *string) error
ListMeshes(name string, reply *ListMeshReply) error
JoinMesh(args *JoinMeshArgs, reply *string) error
LeaveMesh(meshId string, reply *string) error
GetMesh(meshId string, reply *GetMeshReply) error
Query(query QueryMesh, reply *string) error
PutDescription(args PutDescriptionArgs, reply *string) error
PutAlias(args PutAliasArgs, reply *string) error
PutService(args PutServiceArgs, reply *string) error
DeleteService(args DeleteServiceArgs, reply *string) error
}
// WireGuardArgs are provided args specific to WireGuard
type WireGuardArgs struct {
// WgPort is the WireGuard port to expose
WgPort int
// KeepAliveWg is the number of seconds to keep alive
// for WireGuard NAT/firewall traversal
KeepAliveWg int
// AdvertiseRoutes whether or not to advertise routes to and from the
// mesh network
AdvertiseRoutes bool
// AdvertiseDefaultRoute whether or not to advertise the default route
// into the mesh network
AdvertiseDefaultRoute bool
// Endpoint is the routable alias of the machine. Can be an IP
// or DNS entry
Endpoint string
// Role is the role of the individual in the mesh
Role string
}
type NewMeshArgs struct {
// WgArgs are specific WireGuard args to use
WgArgs WireGuardArgs
}
type JoinMeshArgs struct {
// MeshId is the ID of the mesh to join
MeshId string
// IpAddress is a routable IP in another mesh
IpAdress string
// Port is the WireGuard port to expose
Port int
// Endpoint is the routable address of this machine. If not provided
// defaults to the default address
Endpoint string
// Client specifies whether we should join as a client of the peer
// we are connecting to
Client bool
IpAddress string
// WgArgs is the WireGuard parameters to use.
WgArgs WireGuardArgs
}
type PutServiceArgs struct {
Service string
Value string
MeshId string
}
type DeleteServiceArgs struct {
Service string
MeshId string
}
type PutAliasArgs struct {
Alias string
MeshId string
}
type PutDescriptionArgs struct {
Description string
MeshId string
}
type GetMeshReply struct {
@ -51,27 +94,78 @@ type QueryMesh struct {
Query string
}
type GetNodeArgs struct {
NodeId string
MeshId string
}
type MeshIpc interface {
type ClientIpc interface {
CreateMesh(args *NewMeshArgs, reply *string) error
ListMeshes(name string, reply *ListMeshReply) error
ListMeshes(args *ListMeshReply, reply *string) error
JoinMesh(args JoinMeshArgs, reply *string) error
LeaveMesh(meshId string, reply *string) error
GetMesh(meshId string, reply *GetMeshReply) error
GetDOT(meshId string, reply *string) error
Query(query QueryMesh, reply *string) error
PutDescription(description string, reply *string) error
PutAlias(alias string, reply *string) error
PutDescription(args PutDescriptionArgs, reply *string) error
PutAlias(args PutAliasArgs, reply *string) error
PutService(args PutServiceArgs, reply *string) error
GetNode(args GetNodeArgs, reply *string) error
DeleteService(service string, reply *string) error
DeleteService(args DeleteServiceArgs, reply *string) error
}
const SockAddr = "/tmp/wgmesh_ipc.sock"
type SmegmeshIpc struct {
client *ipcRpc.Client
}
func NewClientIpc() (*SmegmeshIpc, error) {
client, err := ipcRpc.DialHTTP("unix", SockAddr)
if err != nil {
return nil, err
}
return &SmegmeshIpc{
client: client,
}, nil
}
func (c *SmegmeshIpc) CreateMesh(args *NewMeshArgs, reply *string) error {
return c.client.Call("IpcHandler.CreateMesh", args, reply)
}
func (c *SmegmeshIpc) ListMeshes(reply *ListMeshReply) error {
return c.client.Call("IpcHandler.ListMeshes", "", reply)
}
func (c *SmegmeshIpc) JoinMesh(args JoinMeshArgs, reply *string) error {
return c.client.Call("IpcHandler.JoinMesh", &args, reply)
}
func (c *SmegmeshIpc) LeaveMesh(meshId string, reply *string) error {
return c.client.Call("IpcHandler.LeaveMesh", &meshId, reply)
}
func (c *SmegmeshIpc) GetMesh(meshId string, reply *GetMeshReply) error {
return c.client.Call("IpcHandler.GetMesh", &meshId, reply)
}
func (c *SmegmeshIpc) Query(query QueryMesh, reply *string) error {
return c.client.Call("IpcHandler.Query", &query, reply)
}
func (c *SmegmeshIpc) PutDescription(args PutDescriptionArgs, reply *string) error {
return c.client.Call("IpcHandler.PutDescription", &args, reply)
}
func (c *SmegmeshIpc) PutAlias(args PutAliasArgs, reply *string) error {
return c.client.Call("IpcHandler.PutAlias", &args, reply)
}
func (c *SmegmeshIpc) PutService(args PutServiceArgs, reply *string) error {
return c.client.Call("IpcHandler.PutService", &args, reply)
}
func (c *SmegmeshIpc) DeleteService(args DeleteServiceArgs, reply *string) error {
return c.client.Call("IpcHandler.DeleteService", &args, reply)
}
func (c *SmegmeshIpc) Close() error {
return c.Close()
}
func RunIpcHandler(server MeshIpc) error {
if err := os.RemoveAll(SockAddr); err != nil {

View File

@ -1,11 +1,34 @@
package lib
import "cmp"
// MapToSlice converts a map to a slice in go
func MapValues[K comparable, V any](m map[K]V) []V {
func MapValues[K cmp.Ordered, V any](m map[K]V) []V {
return MapValuesWithExclude(m, map[K]struct{}{})
}
func MapValuesWithExclude[K comparable, V any](m map[K]V, exclude map[K]struct{}) []V {
type MapItemsEntry[K cmp.Ordered, V any] struct {
Key K
Value V
}
func MapItems[K cmp.Ordered, V any](m map[K]V) []MapItemsEntry[K, V] {
keys := MapKeys(m)
values := MapValues(m)
vs := make([]MapItemsEntry[K, V], len(keys))
for index, _ := range keys {
vs[index] = MapItemsEntry[K, V]{
Key: keys[index],
Value: values[index],
}
}
return vs
}
func MapValuesWithExclude[K cmp.Ordered, V any](m map[K]V, exclude map[K]struct{}) []V {
values := make([]V, len(m)-len(exclude))
i := 0
@ -26,7 +49,7 @@ func MapValuesWithExclude[K comparable, V any](m map[K]V, exclude map[K]struct{}
return values
}
func MapKeys[K comparable, V any](m map[K]V) []K {
func MapKeys[K cmp.Ordered, V any](m map[K]V) []K {
values := make([]K, len(m))
i := 0

View File

@ -3,6 +3,7 @@ package lib
import (
"github.com/anandvarma/namegen"
"github.com/google/uuid"
"github.com/lithammer/shortuuid"
)
// IdGenerator generates unique ids
@ -19,6 +20,14 @@ func (g *UUIDGenerator) GetId() (string, error) {
return id.String(), nil
}
type ShortIDGenerator struct {
}
func (g *ShortIDGenerator) GetId() (string, error) {
id := shortuuid.New()
return id, nil
}
type IDNameGenerator struct {
}

View File

@ -6,7 +6,7 @@ import (
"net"
"github.com/jsimonetti/rtnetlink"
logging "github.com/tim-beatham/wgmesh/pkg/log"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"golang.org/x/sys/unix"
)
@ -140,26 +140,38 @@ func (c *RtNetlinkConfig) AddRoute(ifName string, route Route) error {
family = unix.AF_INET
}
attr := rtnetlink.RouteAttributes{
Dst: dst.IP,
OutIface: uint32(iface.Index),
Gateway: gw,
}
ones, _ := dst.Mask.Size()
err = c.conn.Route.Replace(&rtnetlink.RouteMessage{
Family: family,
Table: unix.RT_TABLE_MAIN,
Protocol: unix.RTPROT_BOOT,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
DstLength: uint8(ones),
Attributes: attr,
})
routes, err := c.listRoutes(ifName, family)
if err != nil {
return fmt.Errorf("failed to add route %w", err)
return err
}
// If it already exists no need to add the route
if !Contains(routes, func(prevRoute rtnetlink.RouteMessage) bool {
return prevRoute.Attributes.Dst.Equal(route.Destination.IP) &&
prevRoute.Attributes.Gateway.Equal(route.Gateway)
}) {
attr := rtnetlink.RouteAttributes{
Dst: dst.IP,
OutIface: uint32(iface.Index),
Gateway: gw,
}
ones, _ := dst.Mask.Size()
err = c.conn.Route.Replace(&rtnetlink.RouteMessage{
Family: family,
Table: unix.RT_TABLE_MAIN,
Protocol: unix.RTPROT_BOOT,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
DstLength: uint8(ones),
Attributes: attr,
})
if err != nil {
return fmt.Errorf("failed to add route %w", err)
}
}
return nil
@ -213,8 +225,11 @@ type Route struct {
}
func (r1 Route) equal(r2 Route) bool {
mask1Ones, _ := r1.Destination.Mask.Size()
mask2Ones, _ := r2.Destination.Mask.Size()
return r1.Gateway.String() == r2.Gateway.String() &&
r1.Destination.String() == r2.Destination.String()
(mask1Ones == 0 && mask2Ones == 0 || r1.Destination.IP.Equal(r2.Destination.IP))
}
// DeleteRoutes deletes all routes not in exclude
@ -245,17 +260,18 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
shouldExclude := func(r Route) bool {
for _, route := range exclude {
if route.equal(r) {
if r.equal(route) {
return false
}
}
return true
}
toDelete := Filter(ifRoutes, shouldExclude)
for _, route := range toDelete {
logging.Log.WriteInfof("Deleting route: %s", route.Gateway.String())
logging.Log.WriteInfof("Deleting route: %s", route.Destination.String())
err := c.DeleteRoute(ifName, route)
if err != nil {

View File

@ -1,46 +0,0 @@
package mesh
import (
"fmt"
"github.com/tim-beatham/wgmesh/pkg/hosts"
)
type MeshAliasManager interface {
AddAliases(nodes []MeshNode)
RemoveAliases(node []MeshNode)
}
type AliasManager struct {
hosts hosts.HostsManipulator
}
// AddAliases: on node update or change add aliases to the hosts file
func (a *AliasManager) AddAliases(nodes []MeshNode) {
for _, node := range nodes {
if node.GetAlias() != "" {
a.hosts.AddAddr(hosts.HostsEntry{
Alias: fmt.Sprintf("%s.smeg", node.GetAlias()),
Ip: node.GetWgHost().IP,
})
}
}
}
// RemoveAliases: on node remove remove aliases from the hosts file
func (a *AliasManager) RemoveAliases(nodes []MeshNode) {
for _, node := range nodes {
if node.GetAlias() != "" {
a.hosts.Remove(hosts.HostsEntry{
Alias: fmt.Sprintf("%s.smeg", node.GetAlias()),
Ip: node.GetWgHost().IP,
})
}
}
}
func NewAliasManager() MeshAliasManager {
return &AliasManager{
hosts: hosts.NewHostsManipulator(),
}
}

View File

@ -7,11 +7,10 @@ import (
"strings"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/route"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/route"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@ -25,8 +24,8 @@ type MeshConfigApplyer interface {
// WgMeshConfigApplyer applies WireGuard configuration
type WgMeshConfigApplyer struct {
meshManager MeshManager
config *conf.WgMeshConfiguration
routeInstaller route.RouteInstaller
hashFunc func(MeshNode) int
}
type routeNode struct {
@ -34,49 +33,44 @@ type routeNode struct {
route Route
}
func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Device,
peerToClients map[string][]net.IPNet,
routes map[string][]routeNode) (*wgtypes.PeerConfig, error) {
type convertMeshNodeParams struct {
node MeshNode
self MeshNode
mesh MeshProvider
device *wgtypes.Device
peerToClients map[string][]net.IPNet
routes map[string][]routeNode
}
endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint())
if err != nil {
return nil, err
}
pubKey, err := node.GetPublicKey()
func (m *WgMeshConfigApplyer) convertMeshNode(params convertMeshNodeParams) (*wgtypes.PeerConfig, error) {
pubKey, err := params.node.GetPublicKey()
if err != nil {
return nil, err
}
allowedips := make([]net.IPNet, 1)
allowedips[0] = *node.GetWgHost()
allowedips[0] = *params.node.GetWgHost()
clients, ok := peerToClients[pubKey.String()]
clients, ok := params.peerToClients[pubKey.String()]
if ok {
allowedips = append(allowedips, clients...)
}
for _, route := range node.GetRoutes() {
bestRoutes := routes[route.GetDestination().String()]
for _, route := range params.node.GetRoutes() {
bestRoutes := params.routes[route.GetDestination().String()]
var pickedRoute routeNode
if len(bestRoutes) == 1 {
pickedRoute = bestRoutes[0]
} else if len(bestRoutes) > 1 {
keyFunc := func(mn MeshNode) int {
pubKey, _ := mn.GetPublicKey()
return lib.HashString(pubKey.String())
}
bucketFunc := func(rn routeNode) int {
return lib.HashString(rn.gateway)
}
// Else there is more than one candidate so consistently hash
pickedRoute = lib.ConsistentHash(bestRoutes, node, bucketFunc, keyFunc)
pickedRoute = lib.ConsistentHash(bestRoutes, params.self, bucketFunc, m.hashFunc)
}
if pickedRoute.gateway == pubKey.String() {
@ -84,15 +78,32 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev
}
}
keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second
config := params.mesh.GetConfiguration()
existing := slices.IndexFunc(device.Peers, func(p wgtypes.Peer) bool {
pubKey, _ := node.GetPublicKey()
var keepAlive time.Duration = time.Duration(0)
if config.KeepAliveWg != nil {
keepAlive = time.Duration(*config.KeepAliveWg) * time.Second
}
existing := slices.IndexFunc(params.device.Peers, func(p wgtypes.Peer) bool {
pubKey, _ := params.node.GetPublicKey()
return p.PublicKey.String() == pubKey.String()
})
var endpoint *net.UDPAddr = nil
if params.node.GetType() == conf.PEER_ROLE {
endpoint, err = net.ResolveUDPAddr("udp", params.node.GetWgEndpoint())
}
if err != nil {
return nil, err
}
// Don't override the existing IP in case it already exists
if existing != -1 {
endpoint = device.Peers[existing].Endpoint
endpoint = params.device.Peers[existing].Endpoint
}
peerConfig := wgtypes.PeerConfig{
@ -108,15 +119,17 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev
// getRoutes: finds the routes with the least hop distance. If more than one route exists
// consistently hash to evenly spread the distribution of traffic
func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]routeNode {
func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) (map[string][]routeNode, error) {
mesh, _ := meshProvider.GetMesh()
routes := make(map[string][]routeNode)
peers := lib.Filter(lib.MapValues(mesh.GetNodes()), func(p MeshNode) bool {
return p.GetType() == conf.PEER_ROLE
})
meshPrefixes := lib.Map(lib.MapValues(m.meshManager.GetMeshes()), func(mesh MeshProvider) *net.IPNet {
ula := &ip.ULABuilder{}
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
return ipNet
})
@ -125,6 +138,10 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
for _, route := range node.GetRoutes() {
if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool {
if prefix.IP.Equal(net.IPv6zero) && *meshProvider.GetConfiguration().AdvertiseDefaultRoute {
return true
}
return prefix.Contains(route.GetDestination().IP)
}) {
continue
@ -138,6 +155,26 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
route: route,
}
// Client's only acessible by another peer
if node.GetType() == conf.CLIENT_ROLE {
peer := m.getCorrespondingPeer(peers, node)
self, err := meshProvider.GetNode(m.meshManager.GetPublicKey().String())
if err != nil {
return nil, err
}
if !NodeEquals(peer, self) {
peerPub, _ := peer.GetPublicKey()
rn.gateway = peerPub.String()
rn.route = &RouteStub{
Destination: rn.route.GetDestination(),
HopCount: rn.route.GetHopCount() + 1,
Path: append(rn.route.GetPath(), peer.GetWgHost().IP.String()),
}
}
}
if !ok {
otherRoute = make([]routeNode, 1)
otherRoute[0] = rn
@ -145,79 +182,138 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
} else if route.GetHopCount() < otherRoute[0].route.GetHopCount() {
otherRoute[0] = rn
} else if otherRoute[0].route.GetHopCount() == route.GetHopCount() {
logging.Log.WriteInfof("Other Route Hop: %d", otherRoute[0].route.GetHopCount())
logging.Log.WriteInfof("Route gateway %s, route hop %d", rn.gateway, route.GetHopCount())
routes[destination] = append(otherRoute, rn)
}
}
}
return routes
return routes, nil
}
// getCorrespondignPeer: gets the peer corresponding to the client
func (m *WgMeshConfigApplyer) getCorrespondingPeer(peers []MeshNode, client MeshNode) MeshNode {
hashFunc := func(mn MeshNode) int {
pubKey, _ := mn.GetPublicKey()
return lib.HashString(pubKey.String())
}
peer := lib.ConsistentHash(peers, client, hashFunc, hashFunc)
peer := lib.ConsistentHash(peers, client, m.hashFunc, m.hashFunc)
return peer
}
func (m *WgMeshConfigApplyer) getClientConfig(mesh MeshProvider, peers []MeshNode, clients []MeshNode) (*wgtypes.Config, error) {
self, err := m.meshManager.GetSelf(mesh.GetMeshId())
func (m *WgMeshConfigApplyer) getPeerCfgsToRemove(dev *wgtypes.Device, newPeers []wgtypes.PeerConfig) []wgtypes.PeerConfig {
peers := dev.Peers
peers = lib.Filter(peers, func(p1 wgtypes.Peer) bool {
return !lib.Contains(newPeers, func(p2 wgtypes.PeerConfig) bool {
return p1.PublicKey.String() == p2.PublicKey.String()
})
})
return lib.Map(peers, func(p wgtypes.Peer) wgtypes.PeerConfig {
return wgtypes.PeerConfig{
PublicKey: p.PublicKey,
Remove: true,
}
})
}
type GetConfigParams struct {
mesh MeshProvider
peers []MeshNode
clients []MeshNode
dev *wgtypes.Device
routes map[string][]routeNode
}
func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes.Config, error) {
ula := &ip.ULABuilder{}
meshNet, _ := ula.GetIPNet(params.mesh.GetMeshId())
routesForMesh := lib.Map(lib.MapValues(params.routes), func(rns []routeNode) []routeNode {
return lib.Filter(rns, func(rn routeNode) bool {
ip, _, _ := net.ParseCIDR(rn.gateway)
return meshNet.Contains(ip)
})
})
routes := lib.Map(routesForMesh, func(rs []routeNode) net.IPNet {
return *rs[0].route.GetDestination()
})
routes = append(routes, *meshNet)
self, err := params.mesh.GetNode(m.meshManager.GetPublicKey().String())
if err != nil {
return nil, err
}
peer := m.getCorrespondingPeer(peers, self)
peer := m.getCorrespondingPeer(params.peers, self)
pubKey, _ := peer.GetPublicKey()
keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second
config := params.mesh.GetConfiguration()
keepAlive := time.Duration(*config.KeepAliveWg) * time.Second
endpoint, err := net.ResolveUDPAddr("udp", peer.GetWgEndpoint())
if err != nil {
return nil, err
}
allowedips := make([]net.IPNet, 1)
_, ipnet, _ := net.ParseCIDR("::/0")
allowedips[0] = *ipnet
peerCfgs := make([]wgtypes.PeerConfig, 1)
peerCfgs[0] = wgtypes.PeerConfig{
PublicKey: pubKey,
Endpoint: endpoint,
PersistentKeepaliveInterval: &keepAlive,
AllowedIPs: allowedips,
AllowedIPs: routes,
ReplaceAllowedIPs: true,
}
installedRoutes := make([]lib.Route, 0)
for _, route := range peerCfgs[0].AllowedIPs {
installedRoutes = append(installedRoutes, lib.Route{
Gateway: peer.GetWgHost().IP,
Destination: route,
})
}
cfg := wgtypes.Config{
Peers: peerCfgs,
}
m.routeInstaller.InstallRoutes(params.dev.Name, installedRoutes...)
return &cfg, err
}
func (m *WgMeshConfigApplyer) getPeerConfig(mesh MeshProvider, peers []MeshNode, clients []MeshNode, dev *wgtypes.Device) (*wgtypes.Config, error) {
func (m *WgMeshConfigApplyer) getRoutesToInstall(wgNode *wgtypes.PeerConfig, mesh MeshProvider, node MeshNode) []lib.Route {
routes := make([]lib.Route, 0)
for _, route := range wgNode.AllowedIPs {
ula := &ip.ULABuilder{}
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
_, defaultRoute, _ := net.ParseCIDR("::/0")
if !ipNet.Contains(route.IP) && !ipNet.IP.Equal(defaultRoute.IP) {
routes = append(routes, lib.Route{
Gateway: node.GetWgHost().IP,
Destination: route,
})
}
}
return routes
}
func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.Config, error) {
peerToClients := make(map[string][]net.IPNet)
routes := m.getRoutes(mesh)
installedRoutes := make([]lib.Route, 0)
peerConfigs := make([]wgtypes.PeerConfig, 0)
self, err := m.meshManager.GetSelf(mesh.GetMeshId())
self, err := params.mesh.GetNode(m.meshManager.GetPublicKey().String())
if err != nil {
return nil, err
}
for _, n := range clients {
if len(peers) > 0 {
peer := m.getCorrespondingPeer(peers, n)
for _, n := range params.clients {
if len(params.peers) > 0 {
peer := m.getCorrespondingPeer(params.peers, n)
pubKey, _ := peer.GetPublicKey()
clients, ok := peerToClients[pubKey.String()]
@ -229,53 +325,56 @@ func (m *WgMeshConfigApplyer) getPeerConfig(mesh MeshProvider, peers []MeshNode,
peerToClients[pubKey.String()] = append(clients, *n.GetWgHost())
if NodeEquals(self, peer) {
cfg, err := m.convertMeshNode(n, dev, peerToClients, routes)
cfg, err := m.convertMeshNode(convertMeshNodeParams{
node: n,
self: self,
mesh: params.mesh,
device: params.dev,
peerToClients: peerToClients,
routes: params.routes,
})
if err != nil {
return nil, err
}
installedRoutes = append(installedRoutes, m.getRoutesToInstall(cfg, params.mesh, n)...)
peerConfigs = append(peerConfigs, *cfg)
}
}
}
for _, n := range peers {
for _, n := range params.peers {
if NodeEquals(n, self) {
continue
}
peer, err := m.convertMeshNode(n, dev, peerToClients, routes)
peer, err := m.convertMeshNode(convertMeshNodeParams{
node: n,
self: self,
mesh: params.mesh,
peerToClients: peerToClients,
routes: params.routes,
device: params.dev,
})
if err != nil {
return nil, err
}
for _, route := range peer.AllowedIPs {
ula := &ip.ULABuilder{}
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
if !ipNet.Contains(route.IP) {
installedRoutes = append(installedRoutes, lib.Route{
Gateway: n.GetWgHost().IP,
Destination: route,
})
}
}
installedRoutes = append(installedRoutes, m.getRoutesToInstall(peer, params.mesh, n)...)
peerConfigs = append(peerConfigs, *peer)
}
cfg := wgtypes.Config{
Peers: peerConfigs,
ReplacePeers: true,
Peers: peerConfigs,
}
err = m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...)
err = m.routeInstaller.InstallRoutes(params.dev.Name, installedRoutes...)
return &cfg, err
}
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string][]routeNode) error {
snap, err := mesh.GetMesh()
if err != nil {
@ -297,7 +396,7 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
return mn.GetType() == conf.CLIENT_ROLE
})
self, err := m.meshManager.GetSelf(mesh.GetMeshId())
self, err := mesh.GetNode(m.meshManager.GetPublicKey().String())
if err != nil {
return err
@ -305,17 +404,28 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
var cfg *wgtypes.Config = nil
configParams := &GetConfigParams{
mesh: mesh,
peers: peers,
clients: clients,
dev: dev,
routes: routes,
}
switch self.GetType() {
case conf.PEER_ROLE:
cfg, err = m.getPeerConfig(mesh, peers, clients, dev)
cfg, err = m.getPeerConfig(configParams)
case conf.CLIENT_ROLE:
cfg, err = m.getClientConfig(mesh, peers, clients)
cfg, err = m.getClientConfig(configParams)
}
if err != nil {
return err
}
toRemove := m.getPeerCfgsToRemove(dev, cfg.Peers)
cfg.Peers = append(cfg.Peers, toRemove...)
err = m.meshManager.GetClient().ConfigureDevice(dev.Name, *cfg)
if err != nil {
@ -325,9 +435,44 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
return nil
}
func (m *WgMeshConfigApplyer) ApplyConfig() error {
func (m *WgMeshConfigApplyer) getAllRoutes() (map[string][]routeNode, error) {
allRoutes := make(map[string][]routeNode)
for _, mesh := range m.meshManager.GetMeshes() {
err := m.updateWgConf(mesh)
routes, err := m.getRoutes(mesh)
if err != nil {
return nil, err
}
for destination, route := range routes {
_, ok := allRoutes[destination]
if !ok {
allRoutes[destination] = route
continue
}
if allRoutes[destination][0].route.GetHopCount() == route[0].route.GetHopCount() {
allRoutes[destination] = append(allRoutes[destination], route...)
} else if route[0].route.GetHopCount() < allRoutes[destination][0].route.GetHopCount() {
allRoutes[destination] = route
}
}
}
return allRoutes, nil
}
func (m *WgMeshConfigApplyer) ApplyConfig() error {
allRoutes, err := m.getAllRoutes()
if err != nil {
return err
}
for _, mesh := range m.meshManager.GetMeshes() {
err := m.updateWgConf(mesh, allRoutes)
if err != nil {
return err
@ -362,9 +507,12 @@ func (m *WgMeshConfigApplyer) SetMeshManager(manager MeshManager) {
m.meshManager = manager
}
func NewWgMeshConfigApplyer(config *conf.WgMeshConfiguration) MeshConfigApplyer {
func NewWgMeshConfigApplyer() MeshConfigApplyer {
return &WgMeshConfigApplyer{
config: config,
routeInstaller: route.NewRouteInstaller(),
hashFunc: func(mn MeshNode) int {
pubKey, _ := mn.GetPublicKey()
return lib.HashString(pubKey.String())
},
}
}

View File

@ -1,77 +0,0 @@
package mesh
import (
"errors"
"fmt"
"github.com/tim-beatham/wgmesh/pkg/graph"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
// MeshGraphConverter converts a mesh to a graph
type MeshGraphConverter interface {
// convert the mesh to textual form
Generate(meshId string) (string, error)
}
type MeshDOTConverter struct {
manager MeshManager
}
func (c *MeshDOTConverter) Generate(meshId string) (string, error) {
mesh := c.manager.GetMesh(meshId)
if mesh == nil {
return "", errors.New("mesh does not exist")
}
g := graph.NewGraph(meshId, graph.GRAPH)
snapshot, err := mesh.GetMesh()
if err != nil {
return "", err
}
for _, node := range snapshot.GetNodes() {
c.graphNode(g, node, meshId)
}
nodes := lib.MapValues(snapshot.GetNodes())
for i, node1 := range nodes[:len(nodes)-1] {
for _, node2 := range nodes[i+1:] {
if node1.GetWgEndpoint() == node2.GetWgEndpoint() {
continue
}
node1Id := fmt.Sprintf("\"%s\"", node1.GetIdentifier())
node2Id := fmt.Sprintf("\"%s\"", node2.GetIdentifier())
g.AddEdge(fmt.Sprintf("%s to %s", node1Id, node2Id), node1Id, node2Id)
}
}
return g.GetDOT()
}
// graphNode: graphs a node within the mesh
func (c *MeshDOTConverter) graphNode(g *graph.Graph, node MeshNode, meshId string) {
nodeId := fmt.Sprintf("\"%s\"", node.GetIdentifier())
g.PutNode(nodeId, graph.CIRCLE)
self, _ := c.manager.GetSelf(meshId)
if NodeEquals(self, node) {
return
}
for _, route := range node.GetRoutes() {
routeId := fmt.Sprintf("\"%s\"", route)
g.PutNode(routeId, graph.HEXAGON)
g.AddEdge(fmt.Sprintf("%s to %s", nodeId, routeId), nodeId, routeId)
}
}
func NewMeshDotConverter(m MeshManager) MeshGraphConverter {
return &MeshDOTConverter{manager: m}
}

View File

@ -3,19 +3,21 @@ package mesh
import (
"errors"
"fmt"
"net"
"sync"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/wg"
"github.com/tim-beatham/smegmesh/pkg/cmd"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/wg"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type MeshManager interface {
CreateMesh(port int) (string, error)
CreateMesh(params *CreateMeshParams) (string, error)
AddMesh(params *AddMeshParams) error
HasChanges(meshid string) bool
GetMesh(meshId string) MeshProvider
@ -24,16 +26,14 @@ type MeshManager interface {
LeaveMesh(meshId string) error
GetSelf(meshId string) (MeshNode, error)
ApplyConfig() error
SetDescription(description string) error
SetAlias(alias string) error
SetService(service string, value string) error
RemoveService(service string) error
SetDescription(meshId, description string) error
SetAlias(meshId, alias string) error
SetService(meshId, service, value string) error
RemoveService(meshId, service string) error
UpdateTimeStamp() error
GetClient() *wgctrl.Client
GetMeshes() map[string]MeshProvider
Prune() error
Close() error
GetMonitor() MeshMonitor
GetNode(string, string) MeshNode
GetRouteManager() RouteManager
}
@ -46,14 +46,14 @@ type MeshManagerImpl struct {
// HostParameters contains information that uniquely locates
// the node in the mesh network.
HostParameters *HostParameters
conf *conf.WgMeshConfiguration
conf *conf.DaemonConfiguration
meshProviderFactory MeshProviderFactory
nodeFactory MeshNodeFactory
configApplyer MeshConfigApplyer
idGenerator lib.IdGenerator
ipAllocator ip.IPAllocator
interfaceManipulator wg.WgInterfaceManipulator
Monitor MeshMonitor
cmdRunner cmd.CmdRunner
OnDelete func(MeshProvider)
}
@ -63,29 +63,33 @@ func (m *MeshManagerImpl) GetRouteManager() RouteManager {
}
// RemoveService implements MeshManager.
func (m *MeshManagerImpl) RemoveService(service string) error {
for _, mesh := range m.Meshes {
err := mesh.RemoveService(m.HostParameters.GetPublicKey(), service)
func (m *MeshManagerImpl) RemoveService(meshId, service string) error {
mesh := m.GetMesh(meshId)
if err != nil {
return err
}
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
return nil
if !mesh.NodeExists(m.HostParameters.GetPublicKey()) {
return fmt.Errorf("node %s does not exist in the mesh", meshId)
}
return mesh.RemoveService(m.HostParameters.GetPublicKey(), service)
}
// SetService implements MeshManager.
func (m *MeshManagerImpl) SetService(service string, value string) error {
for _, mesh := range m.Meshes {
err := mesh.AddService(m.HostParameters.GetPublicKey(), service, value)
func (m *MeshManagerImpl) SetService(meshId, service, value string) error {
mesh := m.GetMesh(meshId)
if err != nil {
return err
}
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
return nil
if !mesh.NodeExists(m.HostParameters.GetPublicKey()) {
return fmt.Errorf("node %s does not exist in the mesh", meshId)
}
return mesh.AddService(m.HostParameters.GetPublicKey(), service, value)
}
func (m *MeshManagerImpl) GetNode(meshid, nodeId string) MeshNode {
@ -104,26 +108,42 @@ func (m *MeshManagerImpl) GetNode(meshid, nodeId string) MeshNode {
return node
}
// GetMonitor implements MeshManager.
func (m *MeshManagerImpl) GetMonitor() MeshMonitor {
return m.Monitor
// CreateMeshParams contains the parameters required to create a mesh
type CreateMeshParams struct {
Port int
Conf *conf.WgConfiguration
}
// Prune implements MeshManager.
func (m *MeshManagerImpl) Prune() error {
for _, mesh := range m.Meshes {
err := mesh.Prune(m.conf.PruneTime)
// getConf: gets the new configuration with the base configuration overriden
// from the recent
func (m *MeshManagerImpl) getConf(override *conf.WgConfiguration) (*conf.WgConfiguration, error) {
meshConfiguration := m.conf.BaseConfiguration
if override != nil {
newConf, err := conf.MergeMeshConfiguration(meshConfiguration, *override)
if err != nil {
return err
return nil, err
}
meshConfiguration = newConf
}
return nil
return &meshConfiguration, nil
}
// CreateMesh: Creates a new mesh, stores it and returns the mesh id
func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
func (m *MeshManagerImpl) CreateMesh(args *CreateMeshParams) (string, error) {
meshConfiguration, err := m.getConf(args.Conf)
if err != nil {
return "", err
}
if *meshConfiguration.Role == conf.CLIENT_ROLE {
return "", fmt.Errorf("cannot create mesh as a client")
}
meshId, err := m.idGenerator.GetId()
var ifName string = ""
@ -132,8 +152,10 @@ func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
return "", err
}
m.cmdRunner.RunCommands(m.conf.BaseConfiguration.PreUp...)
if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(port, m.HostParameters.PrivateKey)
ifName, err = m.interfaceManipulator.CreateInterface(args.Port, m.HostParameters.PrivateKey)
if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err)
@ -141,12 +163,13 @@ func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
}
nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: ifName,
Port: port,
Conf: m.conf,
Client: m.Client,
MeshId: meshId,
NodeID: m.HostParameters.GetPublicKey(),
DevName: ifName,
Port: args.Port,
Conf: meshConfiguration,
Client: m.Client,
MeshId: meshId,
DaemonConf: m.conf,
NodeID: m.HostParameters.GetPublicKey(),
})
if err != nil {
@ -156,6 +179,9 @@ func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
m.lock.Lock()
m.Meshes[meshId] = nodeManager
m.lock.Unlock()
m.cmdRunner.RunCommands(m.conf.BaseConfiguration.PostUp...)
return meshId, nil
}
@ -163,6 +189,7 @@ type AddMeshParams struct {
MeshId string
WgPort int
MeshBytes []byte
Conf *conf.WgConfiguration
}
// AddMesh: Add the mesh to the list of meshes
@ -170,6 +197,14 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
var ifName string
var err error
meshConfiguration, err := m.getConf(params.Conf)
if err != nil {
return err
}
m.cmdRunner.RunCommands(meshConfiguration.PreUp...)
if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(params.WgPort, m.HostParameters.PrivateKey)
@ -179,14 +214,17 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
}
meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: ifName,
Port: params.WgPort,
Conf: m.conf,
Client: m.Client,
MeshId: params.MeshId,
NodeID: m.HostParameters.GetPublicKey(),
DevName: ifName,
Port: params.WgPort,
Conf: meshConfiguration,
Client: m.Client,
MeshId: params.MeshId,
DaemonConf: m.conf,
NodeID: m.HostParameters.GetPublicKey(),
})
m.cmdRunner.RunCommands(meshConfiguration.PostUp...)
if err != nil {
return err
}
@ -249,17 +287,44 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
pubKey := s.HostParameters.PrivateKey.PublicKey()
nodeIP, err := s.ipAllocator.GetIP(pubKey, params.MeshId)
collisionCount := uint8(0)
if err != nil {
return err
var nodeIP net.IP
// Perform Duplicate Address Detection with the nodes
// that are already in the network
for {
generatedIP, err := s.ipAllocator.GetIP(pubKey, params.MeshId, collisionCount)
if err != nil {
return err
}
snapshot, err := mesh.GetMesh()
if err != nil {
return err
}
proposition := func(node MeshNode) bool {
ipNet := node.GetWgHost()
return ipNet.IP.Equal(nodeIP)
}
if lib.Contains(lib.MapValues(snapshot.GetNodes()), proposition) {
collisionCount++
} else {
nodeIP = generatedIP
break
}
}
node := s.nodeFactory.Build(&MeshNodeFactoryParams{
PublicKey: &pubKey,
NodeIP: nodeIP,
WgPort: params.WgPort,
Endpoint: params.Endpoint,
PublicKey: &pubKey,
NodeIP: nodeIP,
WgPort: params.WgPort,
Endpoint: params.Endpoint,
MeshConfig: mesh.GetConfiguration(),
})
if !s.conf.StubWg {
@ -277,7 +342,7 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
}
s.Meshes[params.MeshId].AddNode(node)
return s.RouteManager.UpdateRoutes()
return nil
}
// LeaveMesh leaves the mesh network
@ -288,13 +353,10 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
return fmt.Errorf("mesh %s does not exist", meshId)
}
var err error
s.RouteManager.RemoveRoutes(meshId)
err = mesh.RemoveNode(s.HostParameters.GetPublicKey())
err := mesh.RemoveNode(s.HostParameters.GetPublicKey())
if err != nil {
return err
logging.Log.WriteErrorf(err.Error())
}
if s.OnDelete != nil {
@ -305,6 +367,8 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
delete(s.Meshes, meshId)
s.lock.Unlock()
s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PreDown...)
if !s.conf.StubWg {
device, err := mesh.GetDevice()
@ -319,6 +383,7 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
}
}
s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PostDown...)
return err
}
@ -329,7 +394,6 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
return nil, fmt.Errorf("mesh %s does not exist", meshId)
}
logging.Log.WriteInfof(s.HostParameters.GetPublicKey())
node, err := meshInstance.GetNode(s.HostParameters.GetPublicKey())
if err != nil {
@ -344,43 +408,36 @@ func (s *MeshManagerImpl) ApplyConfig() error {
return nil
}
err := s.configApplyer.ApplyConfig()
if err != nil {
return err
}
return nil
return s.configApplyer.ApplyConfig()
}
func (s *MeshManagerImpl) SetDescription(description string) error {
meshes := s.GetMeshes()
for _, mesh := range meshes {
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetDescription(s.HostParameters.GetPublicKey(), description)
func (s *MeshManagerImpl) SetDescription(meshId, description string) error {
mesh := s.GetMesh(meshId)
if err != nil {
return err
}
}
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
return nil
if !mesh.NodeExists(s.HostParameters.GetPublicKey()) {
return fmt.Errorf("node %s does not exist in the mesh", meshId)
}
return mesh.SetDescription(s.HostParameters.GetPublicKey(), description)
}
// SetAlias implements MeshManager.
func (s *MeshManagerImpl) SetAlias(alias string) error {
meshes := s.GetMeshes()
for _, mesh := range meshes {
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetAlias(s.HostParameters.GetPublicKey(), alias)
func (s *MeshManagerImpl) SetAlias(meshId, alias string) error {
mesh := s.GetMesh(meshId)
if err != nil {
return err
}
}
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
return nil
if !mesh.NodeExists(s.HostParameters.GetPublicKey()) {
return fmt.Errorf("node %s does not exist in the mesh", meshId)
}
return mesh.SetAlias(s.HostParameters.GetPublicKey(), alias)
}
// UpdateTimeStamp updates the timestamp of this node in all meshes
@ -441,7 +498,7 @@ func (s *MeshManagerImpl) Close() error {
// NewMeshManagerParams params required to create an instance of a mesh manager
type NewMeshManagerParams struct {
Conf conf.WgMeshConfiguration
Conf conf.DaemonConfiguration
Client *wgctrl.Client
MeshProvider MeshProviderFactory
NodeFactory MeshNodeFactory
@ -450,6 +507,7 @@ type NewMeshManagerParams struct {
InterfaceManipulator wg.WgInterfaceManipulator
ConfigApplyer MeshConfigApplyer
RouteManager RouteManager
CommandRunner cmd.CmdRunner
OnDelete func(MeshProvider)
}
@ -476,15 +534,14 @@ func NewMeshManager(params *NewMeshManagerParams) MeshManager {
m.RouteManager = NewRouteManager(m)
}
if params.CommandRunner == nil {
m.cmdRunner = &cmd.UnixCmdRunner{}
}
m.idGenerator = params.IdGenerator
m.ipAllocator = params.IPAllocator
m.interfaceManipulator = params.InterfaceManipulator
m.Monitor = NewMeshMonitor(m)
aliasManager := NewAliasManager()
m.Monitor.AddUpdateCallback(aliasManager.AddAliases)
m.Monitor.AddRemoveCallback(aliasManager.RemoveAliases)
m.OnDelete = params.OnDelete
return m
}

View File

@ -3,22 +3,39 @@ package mesh
import (
"testing"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/wg"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/wg"
)
func getMeshConfiguration() *conf.WgMeshConfiguration {
return &conf.WgMeshConfiguration{
GrpcPort: "8080",
Endpoint: "abc.com",
ClusterSize: 64,
SyncRate: 4,
BranchRate: 3,
InterClusterChance: 0.15,
InfectionCount: 2,
KeepAliveTime: 60,
func getMeshConfiguration() *conf.DaemonConfiguration {
advertiseRoutes := true
advertiseDefaultRoute := true
ipDiscovery := conf.PUBLIC_IP_DISCOVERY
role := conf.PEER_ROLE
return &conf.DaemonConfiguration{
GrpcPort: 8080,
CertificatePath: "./somecertificatepath",
PrivateKeyPath: "./someprivatekeypath",
CaCertificatePath: "./somecacertificatepath",
SkipCertVerification: true,
Timeout: 5,
Profile: false,
StubWg: true,
SyncTime: 2,
HeartBeat: 60,
ClusterSize: 64,
InterClusterChance: 0.15,
BranchRate: 3,
InfectionCount: 3,
BaseConfiguration: conf.WgConfiguration{
IPDiscovery: &ipDiscovery,
AdvertiseRoutes: &advertiseRoutes,
AdvertiseDefaultRoute: &advertiseDefaultRoute,
Role: &role,
},
}
}
@ -41,7 +58,10 @@ func getMeshManager() MeshManager {
func TestCreateMeshCreatesANewMeshProvider(t *testing.T) {
manager := getMeshManager()
meshId, err := manager.CreateMesh("wg0", 5000)
meshId, err := manager.CreateMesh(&CreateMeshParams{
Port: 0,
Conf: &conf.WgConfiguration{},
})
if err != nil {
t.Error(err)
@ -128,7 +148,7 @@ func TestAddSelfAddsSelfToTheMesh(t *testing.T) {
t.Error(err)
}
_, ok := mesh.GetNodes()["abc.com"]
_, ok := mesh.GetNodes()[manager.GetPublicKey().String()]
if !ok {
t.Fatalf(`node has not been added`)
@ -193,36 +213,80 @@ func TestLeaveMeshDeletesMesh(t *testing.T) {
}
}
func TestSetDescription(t *testing.T) {
func TestSetAliasUpdatesAliasOfNode(t *testing.T) {
manager := getMeshManager()
alias := "Firpo"
meshId, _ := manager.CreateMesh(&CreateMeshParams{
Port: 5000,
Conf: &conf.WgConfiguration{},
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
err := manager.SetAlias(meshId, alias)
if err != nil {
t.Fatalf(`failed to set the alias`)
}
self, err := manager.GetSelf(meshId)
if err != nil {
t.Fatalf(`failed to set the alias err: %s`, err.Error())
}
if alias != self.GetAlias() {
t.Fatalf(`alias should be %s was %s`, alias, self.GetAlias())
}
}
func TestSetDescriptionSetsTheDescriptionOfTheNode(t *testing.T) {
manager := getMeshManager()
description := "wooooo"
meshId1, _ := manager.CreateMesh(5000)
meshId2, _ := manager.CreateMesh(5001)
meshId1, _ := manager.CreateMesh(&CreateMeshParams{
Port: 5000,
Conf: &conf.WgConfiguration{},
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId1,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId2,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
err := manager.SetDescription(description)
err := manager.SetDescription(meshId1, description)
if err != nil {
t.Fatalf(`failed to set the descriptions`)
}
}
self1, err := manager.GetSelf(meshId1)
if err != nil {
t.Fatalf(`failed to set the description`)
}
if description != self1.GetDescription() {
t.Fatalf(`description should be %s was %s`, description, self1.GetDescription())
}
}
func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) {
manager := getMeshManager()
meshId1, _ := manager.CreateMesh(5000)
meshId2, _ := manager.CreateMesh(5001)
meshId1, _ := manager.CreateMesh(&CreateMeshParams{
Port: 5000,
Conf: &conf.WgConfiguration{},
})
meshId2, _ := manager.CreateMesh(&CreateMeshParams{
Port: 5001,
Conf: &conf.WgConfiguration{},
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId1,
@ -241,3 +305,68 @@ func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) {
t.Fatalf(`failed to update the timestamp`)
}
}
func TestAddServiceAddsServiceToTheMesh(t *testing.T) {
manager := getMeshManager()
meshId1, _ := manager.CreateMesh(&CreateMeshParams{
Port: 5000,
Conf: &conf.WgConfiguration{},
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId1,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
serviceName := "hello"
manager.SetService(meshId1, serviceName, "dave")
self, err := manager.GetSelf(meshId1)
if err != nil {
t.Fatalf(`error thrown %s:`, err.Error())
}
if _, ok := self.GetServices()[serviceName]; !ok {
t.Fatalf(`service not added`)
}
}
func TestRemoveServiceRemovesTheServiceFromTheMesh(t *testing.T) {
manager := getMeshManager()
meshId1, _ := manager.CreateMesh(&CreateMeshParams{
Port: 5000,
Conf: &conf.WgConfiguration{},
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId1,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
serviceName := "hello"
manager.SetService(meshId1, serviceName, "dave")
self, err := manager.GetSelf(meshId1)
if err != nil {
t.Fatalf(`error thrown %s:`, err.Error())
}
if _, ok := self.GetServices()[serviceName]; !ok {
t.Fatalf(`service not added`)
}
manager.RemoveService(meshId1, serviceName)
self, err = manager.GetSelf(meshId1)
if err != nil {
t.Fatalf(`error thrown %s:`, err.Error())
}
if _, ok := self.GetServices()[serviceName]; ok {
t.Fatalf(`service still exists`)
}
}

View File

@ -1,81 +0,0 @@
package mesh
type OnChange = func([]MeshNode)
type MeshMonitor interface {
AddUpdateCallback(cb OnChange)
AddRemoveCallback(cb OnChange)
Trigger() error
}
type MeshMonitorImpl struct {
updateCbs []OnChange
removeCbs []OnChange
nodes map[string]MeshNode
manager MeshManager
}
// Trigger causes the mesh monitor to trigger all of
// the callbacks.
func (m *MeshMonitorImpl) Trigger() error {
changedNodes := make([]MeshNode, 0)
removedNodes := make([]MeshNode, 0)
nodes := make(map[string]MeshNode)
for _, mesh := range m.manager.GetMeshes() {
snapshot, err := mesh.GetMesh()
if err != nil {
return err
}
for _, node := range snapshot.GetNodes() {
previous, exists := m.nodes[node.GetWgHost().String()]
if !exists || !NodeEquals(previous, node) {
changedNodes = append(changedNodes, node)
}
nodes[node.GetWgHost().String()] = node
}
}
for _, previous := range m.nodes {
_, ok := nodes[previous.GetWgHost().String()]
if !ok {
removedNodes = append(removedNodes, previous)
}
}
if len(removedNodes) > 0 {
for _, cb := range m.removeCbs {
cb(removedNodes)
}
}
if len(changedNodes) > 0 {
for _, cb := range m.updateCbs {
cb(changedNodes)
}
}
return nil
}
func (m *MeshMonitorImpl) AddUpdateCallback(cb OnChange) {
m.updateCbs = append(m.updateCbs, cb)
}
func (m *MeshMonitorImpl) AddRemoveCallback(cb OnChange) {
m.removeCbs = append(m.removeCbs, cb)
}
func NewMeshMonitor(manager MeshManager) MeshMonitor {
return &MeshMonitorImpl{
updateCbs: make([]OnChange, 0),
nodes: make(map[string]MeshNode),
manager: manager,
}
}

View File

@ -1,16 +0,0 @@
package mesh
import (
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
func pruneFunction(m MeshManager) lib.TimerFunc {
return func() error {
return m.Prune()
}
}
func NewPruner(m MeshManager, conf conf.WgMeshConfiguration) *lib.Timer {
return lib.NewTimer(pruneFunction(m), conf.PruneTime/2)
}

View File

@ -1,14 +1,14 @@
package mesh
import (
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"net"
"github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/smegmesh/pkg/lib"
)
type RouteManager interface {
UpdateRoutes() error
RemoveRoutes(meshId string) error
}
type RouteManagerImpl struct {
@ -17,74 +17,106 @@ type RouteManagerImpl struct {
func (r *RouteManagerImpl) UpdateRoutes() error {
meshes := r.meshManager.GetMeshes()
ulaBuilder := new(ip.ULABuilder)
routes := make(map[string][]Route)
for _, mesh1 := range meshes {
self, err := r.meshManager.GetSelf(mesh1.GetMeshId())
if !*mesh1.GetConfiguration().AdvertiseRoutes {
continue
}
self, err := mesh1.GetNode(r.meshManager.GetPublicKey().String())
if err != nil {
return err
}
pubKey, err := self.GetPublicKey()
if err != nil {
return err
if _, ok := routes[mesh1.GetMeshId()]; !ok {
routes[mesh1.GetMeshId()] = make([]Route, 0)
}
routes, err := mesh1.GetRoutes(pubKey.String())
if *mesh1.GetConfiguration().AdvertiseDefaultRoute {
_, ipv6Default, _ := net.ParseCIDR("::/0")
defaultRoute := &RouteStub{
Destination: ipv6Default,
HopCount: 0,
Path: []string{mesh1.GetMeshId()},
}
mesh1.AddRoutes(NodeID(self), defaultRoute)
routes[mesh1.GetMeshId()] = append(routes[mesh1.GetMeshId()], defaultRoute)
}
routeMap, err := mesh1.GetRoutes(NodeID(self))
if err != nil {
return err
}
for _, mesh2 := range meshes {
routeValues, ok := routes[mesh2.GetMeshId()]
if !ok {
routeValues = make([]Route, 0)
}
if mesh1 == mesh2 {
continue
}
ipNet, err := ulaBuilder.GetIPNet(mesh2.GetMeshId())
mesh1IpNet, _ := (&ip.ULABuilder{}).GetIPNet(mesh1.GetMeshId())
if err != nil {
logging.Log.WriteErrorf(err.Error())
return err
}
err = mesh2.AddRoutes(NodeID(self), append(lib.MapValues(routes), &RouteStub{
Destination: ipNet,
routeValues = append(routeValues, &RouteStub{
Destination: mesh1IpNet,
HopCount: 0,
Path: make([]string, 0),
})...)
Path: []string{mesh1.GetMeshId()},
})
if err != nil {
return err
}
routeValues = append(routeValues, lib.MapValues(routeMap)...)
mesh2IpNet, _ := (&ip.ULABuilder{}).GetIPNet(mesh2.GetMeshId())
routeValues = lib.Filter(routeValues, func(r Route) bool {
pathNotMesh := func(s string) bool {
return s == mesh2.GetMeshId()
}
// Remove any potential routing loops
return !r.GetDestination().IP.Equal(mesh2IpNet.IP) &&
!lib.Contains(r.GetPath()[1:], pathNotMesh)
})
routes[mesh2.GetMeshId()] = routeValues
}
}
return nil
}
// Calculate the set different of each, working out routes to remove and to keep.
for meshId, meshRoutes := range routes {
mesh := meshes[meshId]
// removeRoutes: removes all meshes we are no longer a part of
func (r *RouteManagerImpl) RemoveRoutes(meshId string) error {
ulaBuilder := new(ip.ULABuilder)
meshes := r.meshManager.GetMeshes()
ipNet, err := ulaBuilder.GetIPNet(meshId)
if err != nil {
return err
}
for _, mesh1 := range meshes {
self, err := r.meshManager.GetSelf(meshId)
self, err := mesh.GetNode(r.meshManager.GetPublicKey().String())
if err != nil {
return err
}
mesh1.RemoveRoutes(NodeID(self), ipNet.String())
toRemove := make([]Route, 0)
prevRoutes, err := mesh.GetRoutes(NodeID(self))
if err != nil {
return err
}
for _, route := range prevRoutes {
if !lib.Contains(meshRoutes, func(r Route) bool {
return RouteEquals(r, route)
}) {
toRemove = append(toRemove, route)
}
}
mesh.RemoveRoutes(NodeID(self), toRemove...)
mesh.AddRoutes(NodeID(self), meshRoutes...)
}
return nil
}

View File

@ -5,7 +5,8 @@ import (
"net"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/lib"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@ -19,6 +20,8 @@ type MeshNodeStub struct {
routes []Route
identifier string
description string
alias string
services map[string]string
}
// GetType implements MeshNode.
@ -27,13 +30,13 @@ func (*MeshNodeStub) GetType() conf.NodeType {
}
// GetServices implements MeshNode.
func (*MeshNodeStub) GetServices() map[string]string {
return make(map[string]string)
func (m *MeshNodeStub) GetServices() map[string]string {
return m.services
}
// GetAlias implements MeshNode.
func (*MeshNodeStub) GetAlias() string {
return ""
func (s *MeshNodeStub) GetAlias() string {
return s.alias
}
func (m *MeshNodeStub) GetHostEndpoint() string {
@ -81,9 +84,28 @@ type MeshProviderStub struct {
snapshot *MeshSnapshotStub
}
// GetConfiguration implements MeshProvider.
func (*MeshProviderStub) GetConfiguration() *conf.WgConfiguration {
advertiseRoutes := true
advertiseDefaultRoute := true
ipDiscovery := conf.PUBLIC_IP_DISCOVERY
role := conf.PEER_ROLE
return &conf.WgConfiguration{
IPDiscovery: &ipDiscovery,
AdvertiseRoutes: &advertiseRoutes,
AdvertiseDefaultRoute: &advertiseDefaultRoute,
Role: &role,
}
}
// Mark implements MeshProvider.
func (*MeshProviderStub) Mark(nodeId string) {
}
// RemoveNode implements MeshProvider.
func (*MeshProviderStub) RemoveNode(nodeId string) error {
panic("unimplemented")
return nil
}
func (*MeshProviderStub) GetRoutes(targetId string) (map[string]Route, error) {
@ -96,47 +118,71 @@ func (*MeshProviderStub) GetPeers() []string {
}
// GetNode implements MeshProvider.
func (*MeshProviderStub) GetNode(string) (MeshNode, error) {
return nil, nil
func (m *MeshProviderStub) GetNode(nodeId string) (MeshNode, error) {
return m.snapshot.nodes[nodeId], nil
}
// NodeExists implements MeshProvider.
func (*MeshProviderStub) NodeExists(string) bool {
return false
func (m *MeshProviderStub) NodeExists(nodeId string) bool {
return m.snapshot.nodes[nodeId] != nil
}
// AddService implements MeshProvider.
func (*MeshProviderStub) AddService(nodeId string, key string, value string) error {
func (m *MeshProviderStub) AddService(nodeId string, key string, value string) error {
node := (m.snapshot.nodes[nodeId]).(*MeshNodeStub)
node.services[key] = value
return nil
}
// RemoveService implements MeshProvider.
func (*MeshProviderStub) RemoveService(nodeId string, key string) error {
func (m *MeshProviderStub) RemoveService(nodeId string, key string) error {
node := (m.snapshot.nodes[nodeId]).(*MeshNodeStub)
delete(node.services, key)
return nil
}
// SetAlias implements MeshProvider.
func (*MeshProviderStub) SetAlias(nodeId string, alias string) error {
panic("unimplemented")
func (m *MeshProviderStub) SetAlias(nodeId string, alias string) error {
node := (m.snapshot.nodes[nodeId]).(*MeshNodeStub)
node.alias = alias
return nil
}
// AddRoutes implements
func (m *MeshProviderStub) AddRoutes(nodeId string, route ...Route) error {
node := (m.snapshot.nodes[nodeId]).(*MeshNodeStub)
node.routes = append(node.routes, route...)
return nil
}
// RemoveRoutes implements MeshProvider.
func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error {
func (m *MeshProviderStub) RemoveRoutes(nodeId string, route ...Route) error {
node := (m.snapshot.nodes[nodeId]).(*MeshNodeStub)
newRoutes := lib.Filter(node.routes, func(r1 Route) bool {
return !lib.Contains(route, func(r2 Route) bool {
return RouteEqual(r1, r2)
})
})
node.routes = newRoutes
return nil
}
// Prune implements MeshProvider.
func (*MeshProviderStub) Prune(pruneAmount int) error {
func (*MeshProviderStub) Prune() error {
return nil
}
// UpdateTimeStamp implements MeshProvider.
func (*MeshProviderStub) UpdateTimeStamp(nodeId string) error {
func (m *MeshProviderStub) UpdateTimeStamp(nodeId string) error {
node := (m.snapshot.nodes[nodeId]).(*MeshNodeStub)
node.timeStamp = time.Now().Unix()
return nil
}
func (s *MeshProviderStub) AddNode(node MeshNode) {
s.snapshot.nodes[node.GetHostEndpoint()] = node
pubKey, _ := node.GetPublicKey()
s.snapshot.nodes[pubKey.String()] = node
}
func (s *MeshProviderStub) GetMesh() (MeshSnapshot, error) {
@ -168,15 +214,13 @@ func (s *MeshProviderStub) HasChanges() bool {
return false
}
func (s *MeshProviderStub) AddRoutes(nodeId string, route ...Route) error {
return nil
}
func (s *MeshProviderStub) GetSyncer() MeshSyncer {
return nil
}
func (s *MeshProviderStub) SetDescription(nodeId string, description string) error {
meshNode := (s.snapshot.nodes[nodeId]).(*MeshNodeStub)
meshNode.description = description
return nil
}
@ -190,7 +234,7 @@ func (s *StubMeshProviderFactory) CreateMesh(params *MeshProviderFactoryParams)
}
type StubNodeFactory struct {
Config *conf.WgMeshConfiguration
Config *conf.DaemonConfiguration
}
func (s *StubNodeFactory) Build(params *MeshNodeFactoryParams) MeshNode {
@ -199,12 +243,13 @@ func (s *StubNodeFactory) Build(params *MeshNodeFactoryParams) MeshNode {
return &MeshNodeStub{
hostEndpoint: params.Endpoint,
publicKey: *params.PublicKey,
wgEndpoint: fmt.Sprintf("%s:%s", params.Endpoint, s.Config.GrpcPort),
wgEndpoint: fmt.Sprintf("%s:%d", params.Endpoint, s.Config.GrpcPort),
wgHost: wgHost,
timeStamp: time.Now().Unix(),
routes: make([]Route, 0),
identifier: "abc",
description: "A Mesh Node Stub",
services: make(map[string]string),
}
}
@ -227,37 +272,32 @@ type MeshManagerStub struct {
// GetRouteManager implements MeshManager.
func (*MeshManagerStub) GetRouteManager() RouteManager {
panic("unimplemented")
return nil
}
// GetNode implements MeshManager.
func (*MeshManagerStub) GetNode(string, string) MeshNode {
panic("unimplemented")
func (*MeshManagerStub) GetNode(meshId, nodeId string) MeshNode {
return nil
}
// RemoveService implements MeshManager.
func (*MeshManagerStub) RemoveService(service string) error {
panic("unimplemented")
func (*MeshManagerStub) RemoveService(meshId, service string) error {
return nil
}
// SetService implements MeshManager.
func (*MeshManagerStub) SetService(service string, value string) error {
panic("unimplemented")
}
// GetMonitor implements MeshManager.
func (*MeshManagerStub) GetMonitor() MeshMonitor {
panic("unimplemented")
func (*MeshManagerStub) SetService(meshId, service, value string) error {
return nil
}
// SetAlias implements MeshManager.
func (*MeshManagerStub) SetAlias(alias string) error {
panic("unimplemented")
func (*MeshManagerStub) SetAlias(meshId, alias string) error {
return nil
}
// Close implements MeshManager.
func (*MeshManagerStub) Close() error {
panic("unimplemented")
return nil
}
// Prune implements MeshManager.
@ -269,7 +309,7 @@ func NewMeshManagerStub() MeshManager {
return &MeshManagerStub{meshes: make(map[string]MeshProvider)}
}
func (m *MeshManagerStub) CreateMesh(port int) (string, error) {
func (m *MeshManagerStub) CreateMesh(*CreateMeshParams) (string, error) {
return "tim123", nil
}
@ -309,7 +349,7 @@ func (m *MeshManagerStub) ApplyConfig() error {
return nil
}
func (m *MeshManagerStub) SetDescription(description string) error {
func (m *MeshManagerStub) SetDescription(meshId, description string) error {
return nil
}

View File

@ -4,8 +4,9 @@ package mesh
import (
"net"
"slices"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/conf"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@ -19,6 +20,18 @@ type Route interface {
GetPath() []string
}
func RouteEqual(r1 Route, r2 Route) bool {
return r1.GetDestination().IP.Equal(r2.GetDestination().IP) &&
r1.GetHopCount() == r2.GetHopCount() &&
slices.Equal(r1.GetPath(), r2.GetPath())
}
func RouteEquals(r1, r2 Route) bool {
return r1.GetDestination().String() == r2.GetDestination().String() &&
r1.GetHopCount() == r2.GetHopCount() &&
slices.Equal(r1.GetPath(), r2.GetPath())
}
type RouteStub struct {
Destination *net.IPNet
HopCount int
@ -68,12 +81,11 @@ func NodeEquals(node1, node2 MeshNode) bool {
key1, _ := node1.GetPublicKey()
key2, _ := node2.GetPublicKey()
return key1.String() == key2.String()
}
if node1 == nil || node2 == nil {
return false
}
func RouteEquals(route1, route2 Route) bool {
return route1.GetDestination().String() == route2.GetDestination().String() &&
route1.GetHopCount() == route2.GetHopCount()
return key1.String() == key2.String()
}
func NodeID(node MeshNode) string {
@ -116,7 +128,7 @@ type MeshProvider interface {
// AddRoutes: adds routes to the given node
AddRoutes(nodeId string, route ...Route) error
// DeleteRoutes: deletes the routes from the node
RemoveRoutes(nodeId string, route ...string) error
RemoveRoutes(nodeId string, route ...Route) error
// GetSyncer: returns the automerge syncer for sync
GetSyncer() MeshSyncer
// GetNode get a particular not within the mesh
@ -131,15 +143,21 @@ type MeshProvider interface {
AddService(nodeId, key, value string) error
// RemoveService: removes the service form the node. throws an error if the service does not exist
RemoveService(nodeId, key string) error
// Prune: prunes all nodes that have not updated their timestamp in
// pruneAmount seconds
Prune(pruneAmount int) error
// Prune: prunes all nodes that have not updated their
// vector clock
Prune() error
// GetPeers: get a list of contactable peers
GetPeers() []string
// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen
GetRoutes(targetNode string) (map[string]Route, error)
// RemoveNode(): remove the node from the mesh
RemoveNode(nodeId string) error
// Mark: marks the node as unreachable. This is not broadcast to the entire
// this is not considered when syncing node state
Mark(nodeId string)
// GetConfiguration: gets the configuration parameters specific for this
// mesh network
GetConfiguration() *conf.WgConfiguration
}
// HostParameters contains the IDs of a node
@ -154,12 +172,13 @@ func (h *HostParameters) GetPublicKey() string {
// MeshProviderFactoryParams parameters required to build a mesh provider
type MeshProviderFactoryParams struct {
DevName string
MeshId string
Port int
Conf *conf.WgMeshConfiguration
Client *wgctrl.Client
NodeID string
DevName string
MeshId string
Port int
Conf *conf.WgConfiguration
DaemonConf *conf.DaemonConfiguration
Client *wgctrl.Client
NodeID string
}
// MeshProviderFactory creates an instance of a mesh provider
@ -170,10 +189,11 @@ type MeshProviderFactory interface {
// MeshNodeFactoryParams are the parameters required to construct
// a mesh node
type MeshNodeFactoryParams struct {
PublicKey *wgtypes.Key
NodeIP net.IP
WgPort int
Endpoint string
PublicKey *wgtypes.Key
NodeIP net.IP
WgPort int
Endpoint string
MeshConfig *conf.WgConfiguration
}
// MeshBuilder build the hosts mesh node for it to be added to the mesh

View File

@ -6,9 +6,9 @@ import (
"strings"
"github.com/jmespath/go-jmespath"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/mesh"
)
// Querier queries a data store for the given data

View File

@ -2,26 +2,51 @@ package robin
import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
"time"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query"
"github.com/tim-beatham/wgmesh/pkg/rpc"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/smegmesh/pkg/ipc"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/rpc"
)
type IpcHandler struct {
Server ctrlserver.CtrlServer
}
func getOverrideConfiguration(args *ipc.WireGuardArgs) conf.WgConfiguration {
overrideConf := conf.WgConfiguration{}
if args.Role != "" {
role := conf.NodeType(args.Role)
overrideConf.Role = &role
}
if args.Endpoint != "" {
overrideConf.Endpoint = &args.Endpoint
}
if args.KeepAliveWg != 0 {
keepAliveWg := args.KeepAliveWg
overrideConf.KeepAliveWg = &keepAliveWg
}
overrideConf.AdvertiseRoutes = &args.AdvertiseRoutes
overrideConf.AdvertiseDefaultRoute = &args.AdvertiseDefaultRoute
return overrideConf
}
func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
meshId, err := n.Server.GetMeshManager().CreateMesh(args.WgPort)
overrideConf := getOverrideConfiguration(&args.WgArgs)
meshId, err := n.Server.GetMeshManager().CreateMesh(&mesh.CreateMeshParams{
Port: args.WgArgs.WgPort,
Conf: &overrideConf,
})
if err != nil {
return err
@ -29,8 +54,8 @@ func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
err = n.Server.GetMeshManager().AddSelf(&mesh.AddSelfParams{
MeshId: meshId,
WgPort: args.WgPort,
Endpoint: args.Endpoint,
WgPort: args.WgArgs.WgPort,
Endpoint: args.WgArgs.Endpoint,
})
if err != nil {
@ -45,7 +70,7 @@ func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
meshNames := make([]string, len(n.Server.GetMeshManager().GetMeshes()))
i := 0
for meshId, _ := range n.Server.GetMeshManager().GetMeshes() {
for meshId := range n.Server.GetMeshManager().GetMeshes() {
meshNames[i] = meshId
i++
}
@ -54,8 +79,14 @@ func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
return nil
}
func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
peerConnection, err := n.Server.GetConnectionManager().GetConnection(args.IpAdress)
func (n *IpcHandler) JoinMesh(args *ipc.JoinMeshArgs, reply *string) error {
overrideConf := getOverrideConfiguration(&args.WgArgs)
if n.Server.GetMeshManager().GetMesh(args.MeshId) != nil {
return fmt.Errorf("user is already apart of the mesh")
}
peerConnection, err := n.Server.GetConnectionManager().GetConnection(args.IpAddress)
if err != nil {
return err
@ -86,8 +117,9 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{
MeshId: args.MeshId,
WgPort: args.Port,
WgPort: args.WgArgs.WgPort,
MeshBytes: meshReply.Mesh,
Conf: &overrideConf,
})
if err != nil {
@ -96,8 +128,8 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
err = n.Server.GetMeshManager().AddSelf(&mesh.AddSelfParams{
MeshId: args.MeshId,
WgPort: args.Port,
Endpoint: args.Endpoint,
WgPort: args.WgArgs.WgPort,
Endpoint: args.WgArgs.Endpoint,
})
if err != nil {
@ -115,7 +147,6 @@ func (n *IpcHandler) LeaveMesh(meshId string, reply *string) error {
if err == nil {
*reply = fmt.Sprintf("Left Mesh %s", meshId)
}
return err
}
@ -140,30 +171,9 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
i := 0
for _, node := range meshSnapshot.GetNodes() {
pubKey, _ := node.GetPublicKey()
node := ctrlserver.NewCtrlNode(theMesh, node)
if err != nil {
return err
}
node := ctrlserver.MeshNode{
HostEndpoint: node.GetHostEndpoint(),
WgEndpoint: node.GetWgEndpoint(),
PublicKey: pubKey.String(),
WgHost: node.GetWgHost().String(),
Timestamp: node.GetTimeStamp(),
Routes: lib.Map(node.GetRoutes(), func(r mesh.Route) ctrlserver.MeshRoute {
return ctrlserver.MeshRoute{
Destination: r.GetDestination().String(),
Path: r.GetPath(),
}
}),
Description: node.GetDescription(),
Alias: node.GetAlias(),
Services: node.GetServices(),
}
nodes[i] = node
nodes[i] = *node
i += 1
}
@ -171,19 +181,6 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
return nil
}
func (n *IpcHandler) GetDOT(meshId string, reply *string) error {
g := mesh.NewMeshDotConverter(n.Server.GetMeshManager())
result, err := g.Generate(meshId)
if err != nil {
return err
}
*reply = result
return nil
}
func (n *IpcHandler) Query(params ipc.QueryMesh, reply *string) error {
queryResponse, err := n.Server.GetQuerier().Query(params.MeshId, params.Query)
@ -195,30 +192,34 @@ func (n *IpcHandler) Query(params ipc.QueryMesh, reply *string) error {
return nil
}
func (n *IpcHandler) PutDescription(description string, reply *string) error {
err := n.Server.GetMeshManager().SetDescription(description)
func (n *IpcHandler) PutDescription(args ipc.PutDescriptionArgs, reply *string) error {
err := n.Server.GetMeshManager().SetDescription(args.MeshId, args.Description)
if err != nil {
return err
}
*reply = fmt.Sprintf("Set description to %s", description)
*reply = fmt.Sprintf("set description to %s for %s", args.Description, args.MeshId)
return nil
}
func (n *IpcHandler) PutAlias(alias string, reply *string) error {
err := n.Server.GetMeshManager().SetAlias(alias)
func (n *IpcHandler) PutAlias(args ipc.PutAliasArgs, reply *string) error {
if args.Alias == "" {
return fmt.Errorf("alias not provided")
}
err := n.Server.GetMeshManager().SetAlias(args.MeshId, args.Alias)
if err != nil {
return err
}
*reply = fmt.Sprintf("Set alias to %s", alias)
*reply = fmt.Sprintf("Set alias to %s", args.Alias)
return nil
}
func (n *IpcHandler) PutService(service ipc.PutServiceArgs, reply *string) error {
err := n.Server.GetMeshManager().SetService(service.Service, service.Value)
err := n.Server.GetMeshManager().SetService(service.MeshId, service.Service, service.Value)
if err != nil {
return err
@ -228,8 +229,8 @@ func (n *IpcHandler) PutService(service ipc.PutServiceArgs, reply *string) error
return nil
}
func (n *IpcHandler) DeleteService(service string, reply *string) error {
err := n.Server.GetMeshManager().RemoveService(service)
func (n *IpcHandler) DeleteService(service ipc.DeleteServiceArgs, reply *string) error {
err := n.Server.GetMeshManager().RemoveService(service.MeshId, service.Service)
if err != nil {
return err
@ -239,27 +240,6 @@ func (n *IpcHandler) DeleteService(service string, reply *string) error {
return nil
}
func (n *IpcHandler) GetNode(args ipc.GetNodeArgs, reply *string) error {
node := n.Server.GetMeshManager().GetNode(args.MeshId, args.NodeId)
if node == nil {
*reply = "nil"
return nil
}
queryNode := query.MeshNodeToQueryNode(node)
bytes, err := json.Marshal(queryNode)
if err != nil {
*reply = err.Error()
return nil
}
*reply = string(bytes)
return nil
}
type RobinIpcParams struct {
CtrlServer ctrlserver.CtrlServer
}

View File

@ -3,9 +3,10 @@ package robin
import (
"testing"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/smegmesh/pkg/ipc"
"github.com/tim-beatham/smegmesh/pkg/mesh"
)
func getRequester() *IpcHandler {
@ -17,9 +18,11 @@ func TestCreateMeshRepliesMeshId(t *testing.T) {
requester := getRequester()
err := requester.CreateMesh(&ipc.NewMeshArgs{
IfName: "wg0",
WgPort: 5000,
Endpoint: "abc.com",
WgArgs: ipc.WireGuardArgs{
WgPort: 500,
Endpoint: "abc.com:1234",
Role: "peer",
},
}, &reply)
if err != nil {
@ -52,9 +55,8 @@ func TestListMeshesMeshesNotEmpty(t *testing.T) {
requester.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{
MeshId: "tim123",
DevName: "wg0",
WgPort: 5000,
MeshBytes: make([]byte, 0),
Conf: &conf.WgConfiguration{},
})
err := requester.ListMeshes("", &reply)

View File

@ -4,8 +4,8 @@ import (
"context"
"errors"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/rpc"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/smegmesh/pkg/rpc"
)
type WgRpc struct {

View File

@ -1 +0,0 @@
package robin

View File

@ -1,7 +1,7 @@
package route
import (
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/lib"
"golang.org/x/sys/unix"
)

View File

@ -1,20 +1,22 @@
package sync
import (
"fmt"
"io"
"math/rand"
"sync"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/mesh"
)
// Syncer: picks random nodes from the mesh
// Syncer: picks random nodes from the meshs
type Syncer interface {
Sync(meshId string) error
Sync(theMesh mesh.MeshProvider) error
SyncMeshes() error
}
@ -24,89 +26,186 @@ type SyncerImpl struct {
infectionCount int
syncCount int
cluster conn.ConnCluster
conf *conf.WgMeshConfiguration
conf *conf.DaemonConfiguration
lastSync map[string]int64
lock sync.RWMutex
}
// Sync: Sync random nodes
func (s *SyncerImpl) Sync(meshId string) error {
if !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
logging.Log.WriteInfof("No changes for %s", meshId)
// Sync: Sync with random nodes
func (s *SyncerImpl) Sync(correspondingMesh mesh.MeshProvider) error {
if correspondingMesh == nil {
return fmt.Errorf("mesh provided was nil cannot sync nil mesh")
}
// Self can be nil if the node is removed
selfID := s.manager.GetPublicKey()
self, err := correspondingMesh.GetNode(selfID.String())
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
correspondingMesh.Prune()
if correspondingMesh.HasChanges() {
logging.Log.WriteInfof("meshes %s has changes", correspondingMesh.GetMeshId())
}
// If removed sync with other nodes to gossip the node is removed
if self != nil && self.GetType() == conf.PEER_ROLE && !correspondingMesh.HasChanges() && s.infectionCount == 0 {
logging.Log.WriteInfof("no changes for %s", correspondingMesh.GetMeshId())
// If not synchronised in certain time pull from random neighbour
if s.conf.PullTime != 0 && time.Now().Unix()-s.lastSync[correspondingMesh.GetMeshId()] > int64(s.conf.PullTime) {
return s.Pull(self, correspondingMesh)
}
return nil
}
logging.Log.WriteInfof("UPDATING WG CONF")
s.manager.GetRouteManager().UpdateRoutes()
err := s.manager.ApplyConfig()
before := time.Now()
err = s.manager.GetRouteManager().UpdateRoutes()
if err != nil {
logging.Log.WriteInfof("Failed to update config %w", err)
logging.Log.WriteErrorf(err.Error())
}
publicKey := s.manager.GetPublicKey()
nodeNames := correspondingMesh.GetPeers()
logging.Log.WriteInfof(publicKey.String())
nodeNames = lib.Filter(nodeNames, func(s string) bool {
// Filter our only public key out so we dont sync with ourself
return s != publicKey.String()
})
nodeNames := s.manager.GetMesh(meshId).GetPeers()
neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
var gossipNodes []string
for _, node := range randomSubset {
logging.Log.WriteInfof("Random node: %s", node)
// Clients always pings its peer for configuration
if self != nil && self.GetType() == conf.CLIENT_ROLE && len(nodeNames) > 1 {
neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
if len(neighbours) == 0 {
return nil
}
// Peer with 2 nodes so that there is redundnacy in
// the situation the node leaves pre-emptively
redundancyLength := min(len(neighbours), 2)
gossipNodes = neighbours[:redundancyLength]
} else {
neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
gossipNodes = lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
gossipNodes[len(gossipNodes)-1] = s.cluster.GetInterCluster(nodeNames, publicKey.String())
}
}
before := time.Now()
var succeeded bool = false
if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
logging.Log.WriteInfof("Sending to random cluster")
interCluster := s.cluster.GetInterCluster(nodeNames, publicKey.String())
randomSubset = append(randomSubset, interCluster)
// Do this synchronously to conserve bandwidth
for _, node := range gossipNodes {
correspondingPeer, err := correspondingMesh.GetNode(node)
if correspondingPeer == nil || err != nil {
logging.Log.WriteErrorf("node %s does not exist", node)
continue
}
err = s.requester.SyncMesh(correspondingMesh.GetMeshId(), correspondingPeer)
if err == nil || err == io.EOF {
succeeded = true
}
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
}
var waitGroup sync.WaitGroup
for index := range randomSubset {
waitGroup.Add(1)
go func(i int) error {
defer waitGroup.Done()
correspondingPeer := s.manager.GetNode(meshId, randomSubset[i])
if correspondingPeer == nil {
logging.Log.WriteErrorf("node %s does not exist", randomSubset[i])
}
err := s.requester.SyncMesh(meshId, correspondingPeer.GetHostEndpoint())
return err
}(index)
}
waitGroup.Wait()
s.syncCount++
logging.Log.WriteInfof("SYNC TIME: %v", time.Since(before))
logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount)
logging.Log.WriteInfof("sync time: %v", time.Since(before))
logging.Log.WriteInfof("number of syncs: %d", s.syncCount)
s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount)
s.manager.GetMesh(meshId).SaveChanges()
if !succeeded {
s.infectionCount++
}
correspondingMesh.SaveChanges()
s.lock.Lock()
s.lastSync[correspondingMesh.GetMeshId()] = time.Now().Unix()
s.lock.Unlock()
return nil
}
// Pull one node in the cluster, if there has not been message dissemination
// in a certain period of time pull a random node within the cluster
func (s *SyncerImpl) Pull(self mesh.MeshNode, mesh mesh.MeshProvider) error {
peers := mesh.GetPeers()
pubKey, _ := self.GetPublicKey()
neighbours := s.cluster.GetNeighbours(peers, pubKey.String())
neighbour := lib.RandomSubsetOfLength(neighbours, 1)
if len(neighbour) == 0 {
logging.Log.WriteInfof("no neighbours")
return nil
}
logging.Log.WriteInfof("pulling from node %s", neighbour[0])
pullNode, err := mesh.GetNode(neighbour[0])
if err != nil || pullNode == nil {
return fmt.Errorf("node %s does not exist in the mesh", neighbour[0])
}
err = s.requester.SyncMesh(mesh.GetMeshId(), pullNode)
if err == nil || err == io.EOF {
s.lastSync[mesh.GetMeshId()] = time.Now().Unix()
} else {
return err
}
s.syncCount++
return nil
}
// SyncMeshes: Sync all meshes
func (s *SyncerImpl) SyncMeshes() error {
for meshId := range s.manager.GetMeshes() {
err := s.Sync(meshId)
var wg sync.WaitGroup
if err != nil {
return err
for _, mesh := range s.manager.GetMeshes() {
wg.Add(1)
sync := func() {
defer wg.Done()
err := s.Sync(mesh)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
}
}
go sync()
}
wg.Wait()
logging.Log.WriteInfof("updating the WireGuard configuration")
err := s.manager.ApplyConfig()
if err != nil {
logging.Log.WriteInfof("failed to update config %w", err)
}
return nil
}
func NewSyncer(m mesh.MeshManager, conf *conf.WgMeshConfiguration, r SyncRequester) Syncer {
func NewSyncer(m mesh.MeshManager, conf *conf.DaemonConfiguration, r SyncRequester) Syncer {
cluster, _ := conn.NewConnCluster(conf.ClusterSize)
return &SyncerImpl{
manager: m,
@ -114,5 +213,6 @@ func NewSyncer(m mesh.MeshManager, conf *conf.WgMeshConfiguration, r SyncRequest
requester: r,
infectionCount: 0,
syncCount: 0,
cluster: cluster}
cluster: cluster,
lastSync: make(map[string]int64)}
}

View File

@ -1,8 +1,9 @@
package sync
import (
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/conn"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
@ -15,38 +16,52 @@ type SyncErrorHandler interface {
// SyncErrorHandlerImpl Is an implementation of the SyncErrorHandler
type SyncErrorHandlerImpl struct {
meshManager mesh.MeshManager
connManager conn.ConnectionManager
}
func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint string) bool {
func (s *SyncErrorHandlerImpl) handleFailed(meshId string, nodeId string) bool {
mesh := s.meshManager.GetMesh(meshId)
mesh.Mark(nodeId)
node, err := mesh.GetNode(nodeId)
if mesh == nil {
return false
if err != nil {
s.connManager.RemoveConnection(node.GetHostEndpoint())
}
// self, err := s.meshManager.GetSelf(meshId)
// if err != nil {
// return false
// }
// mesh.DecrementHealth(endpoint, self.GetHostEndpoint())
return true
}
func (s *SyncErrorHandlerImpl) Handle(meshId string, endpoint string, err error) bool {
func (s *SyncErrorHandlerImpl) handleDeadlineExceeded(meshId string, nodeId string) bool {
mesh := s.meshManager.GetMesh(meshId)
if mesh == nil {
return true
}
node, err := mesh.GetNode(nodeId)
if err != nil {
return false
}
s.connManager.RemoveConnection(node.GetHostEndpoint())
return true
}
func (s *SyncErrorHandlerImpl) Handle(meshId string, nodeId string, err error) bool {
errStatus, _ := status.FromError(err)
logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message())
switch errStatus.Code() {
case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound:
return s.incrementFailedCount(meshId, endpoint)
case codes.Unavailable, codes.Unknown, codes.Internal, codes.NotFound:
return s.handleFailed(meshId, nodeId)
case codes.DeadlineExceeded:
return s.handleDeadlineExceeded(meshId, nodeId)
}
return false
}
func NewSyncErrorHandler(m mesh.MeshManager) SyncErrorHandler {
return &SyncErrorHandlerImpl{meshManager: m}
func NewSyncErrorHandler(m mesh.MeshManager, conn conn.ConnectionManager) SyncErrorHandler {
return &SyncErrorHandlerImpl{meshManager: m, connManager: conn}
}

View File

@ -6,16 +6,16 @@ import (
"io"
"time"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/rpc"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/rpc"
)
// SyncRequester: coordinates the syncing of meshes
type SyncRequester interface {
GetMesh(meshId string, ifName string, port int, endPoint string) error
SyncMesh(meshid string, endPoint string) error
SyncMesh(meshid string, meshNode mesh.MeshNode) error
}
type SyncRequesterImpl struct {
@ -56,8 +56,8 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, ifName string, port int, endP
return err
}
func (s *SyncRequesterImpl) handleErr(meshId, endpoint string, err error) error {
ok := s.errorHdlr.Handle(meshId, endpoint, err)
func (s *SyncRequesterImpl) handleErr(meshId, pubKey string, err error) error {
ok := s.errorHdlr.Handle(meshId, pubKey, err)
if ok {
return nil
@ -67,7 +67,10 @@ func (s *SyncRequesterImpl) handleErr(meshId, endpoint string, err error) error
}
// SyncMesh: Proactively send a sync request to the other mesh
func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) error {
endpoint := meshNode.GetHostEndpoint()
pubKey, _ := meshNode.GetPublicKey()
peerConnection, err := s.server.ConnectionManager.GetConnection(endpoint)
if err != nil {
@ -88,7 +91,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
c := rpc.NewSyncServiceClient(client)
syncTimeOut := s.server.Conf.SyncRate * float64(time.Second)
syncTimeOut := float64(s.server.Conf.SyncTime) * float64(time.Second)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(syncTimeOut))
defer cancel()
@ -96,11 +99,11 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
err = s.syncMesh(mesh, ctx, c)
if err != nil {
return s.handleErr(meshId, endpoint, err)
s.handleErr(meshId, pubKey.String(), err)
}
logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId)
return nil
return err
}
func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context, client rpc.SyncServiceClient) error {
@ -148,6 +151,6 @@ func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context
}
func NewSyncRequester(s *ctrlserver.MeshCtrlServer) SyncRequester {
errorHdlr := NewSyncErrorHandler(s.MeshManager)
errorHdlr := NewSyncErrorHandler(s.MeshManager, s.ConnectionManager)
return &SyncRequesterImpl{server: s, errorHdlr: errorHdlr}
}

View File

@ -1,17 +1,18 @@
package sync
import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/smegmesh/pkg/lib"
)
// Run implements SyncScheduler.
func syncFunction(syncer Syncer) lib.TimerFunc {
return func() error {
return syncer.SyncMeshes()
syncer.SyncMeshes()
return nil
}
}
func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester, syncer Syncer) *lib.Timer {
return lib.NewTimer(syncFunction(syncer), int(s.Conf.SyncRate))
return lib.NewTimer(syncFunction(syncer), s.Conf.SyncTime)
}

View File

@ -6,9 +6,9 @@ import (
"errors"
"io"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/rpc"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/rpc"
)
type SyncServiceImpl struct {

View File

@ -1,9 +1,9 @@
package timer
import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/smegmesh/pkg/log"
)
func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
@ -11,6 +11,5 @@ func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
logging.Log.WriteInfof("Updated Timestamp")
return ctrlServer.MeshManager.UpdateTimeStamp()
}
return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime)
return *lib.NewTimer(timerFunc, ctrlServer.Conf.HeartBeat)
}

View File

@ -1,15 +1,20 @@
package wg
import "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
type WgInterfaceManipulatorStub struct{}
func (i *WgInterfaceManipulatorStub) CreateInterface(port int) (string, error) {
return "", nil
// CreateInterface creates a WireGuard interface
func (w *WgInterfaceManipulatorStub) CreateInterface(port int, privateKey *wgtypes.Key) (string, error) {
return "aninterface", nil
}
func (i *WgInterfaceManipulatorStub) AddAddress(ifName string, addr string) error {
// AddAddress adds an address to the given interface name
func (w *WgInterfaceManipulatorStub) AddAddress(ifName string, addr string) error {
return nil
}
func (i *WgInterfaceManipulatorStub) RemoveInterface(ifName string) error {
// RemoveInterface removes the specified interface
func (w *WgInterfaceManipulatorStub) RemoveInterface(ifName string) error {
return nil
}

View File

@ -2,14 +2,6 @@ package wg
import "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
type WgError struct {
msg string
}
func (m *WgError) Error() string {
return m.msg
}
type WgInterfaceManipulator interface {
// CreateInterface creates a WireGuard interface
CreateInterface(port int, privateKey *wgtypes.Key) (string, error)
@ -18,3 +10,11 @@ type WgInterfaceManipulator interface {
// RemoveInterface removes the specified interface
RemoveInterface(ifName string) error
}
type WgError struct {
msg string
}
func (m *WgError) Error() string {
return m.msg
}

View File

@ -3,11 +3,10 @@ package wg
import (
"crypto"
"crypto/rand"
"encoding/base64"
"fmt"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@ -35,8 +34,7 @@ func (m *WgInterfaceManipulatorImpl) CreateInterface(port int, privKey *wgtypes.
}
md5 := crypto.MD5.New().Sum(randomBuf)
md5Str := fmt.Sprintf("wg%s", base64.StdEncoding.EncodeToString(md5)[:hashLength])
md5Str := fmt.Sprintf("wg%x", md5)[:hashLength]
err = rtnl.CreateLink(md5Str)