1
0
forked from extern/smegmesh

Compare commits

...

59 Commits

Author SHA1 Message Date
9a30f4d5cb Submitting 2024-01-05 18:22:05 +00:00
f647c1b806 81-seperate-processes
Prep for submission
2024-01-05 16:59:02 +00:00
a55dadf088 81-seperate-synchronisation-into-independent-procs
- Neaten code
2024-01-05 12:59:13 +00:00
0ec5156e59 81-procs
- fixed issue where route not deleting if mesh only one
2024-01-05 00:14:25 +00:00
2b73d241b6 81-serparate-procs
- nil dereference again
2024-01-04 22:29:30 +00:00
69b1790bb6 81-processes
- issue with client client traversal
2024-01-04 22:08:14 +00:00
4a92743880 81-seperate-sync
- build error
2024-01-04 21:48:54 +00:00
038393052c 81-seperate-synchronisation-into-independent-proc
- build error
2024-01-04 21:47:29 +00:00
5efff2314b 81-separate-synchronisation-into-independent-process
- nil dereference when no joins
2024-01-04 21:45:28 +00:00
1f8d229076 81-seperate-synchronisation-into-independent-process
- nil dereference due to concurrency issues (the method shouldn't be
  concurrent)
2024-01-04 21:16:33 +00:00
a0e7a4a644 81-seperateprocesses-into-independent-processes
- Fixed errors
2024-01-04 13:15:29 +00:00
f9b8b85ec3 81-seperate-synchronisation
- Removed authentication.proto
2024-01-04 13:12:33 +00:00
59d8ae4334 81-seperate-synchronisation
- More code comments
2024-01-04 13:12:07 +00:00
02dfd73e08 81-seperate-synchronisation-into-independent
- Separated synchronisation calls into independent processes
- Commented code for submission
2024-01-04 13:10:08 +00:00
9818645299 Merge pull request #82 from tim-beatham/bugfix-node-not-leving
bugfix-node-not-leaving
2024-01-04 00:24:58 +00:00
1f0914e2df bugfix-node-not-leaving
- Add lock when perform synchronisation on concurrent access
2024-01-04 00:23:20 +00:00
efb40d65de Merge pull request #80 from tim-beatham/bugfix-node-not-leving
main
2024-01-02 20:32:09 +00:00
27e00196cd main
- Not waiting in the waitgroup
2024-01-02 20:31:24 +00:00
4543205703 Merge pull request #79 from tim-beatham/bugfix-node-not-leving
main
2024-01-02 20:21:27 +00:00
dea6f1a22d main
- error in code invalid check for nil
2024-01-02 20:19:34 +00:00
4d19da6727 Merge pull request #78 from tim-beatham/bugfix-node-not-leving
main
2024-01-02 20:12:10 +00:00
913de57568 main
- Fixed bug
2024-01-02 20:11:11 +00:00
8a5673e303 Merge pull request #77 from tim-beatham/bugfix-node-not-leving
bugfix node not leaving
2024-01-02 19:43:04 +00:00
ce829114b1 bugfix
- on synchornisation node is not leaving mesh
2024-01-02 19:41:20 +00:00
05cc287e31 Merge pull request #76 from tim-beatham/74-perform-dad
- Fixing DNS error
2024-01-02 00:16:45 +00:00
cd844ff46e - Fixing DNS error 2024-01-02 00:15:23 +00:00
4b9406a920 Merge pull request #75 from tim-beatham/74-perform-dad
74-perform-dad
2024-01-02 00:14:37 +00:00
d0b1913796 74-perform-dad
- Fixing nil pointer dereference
2024-01-02 00:13:04 +00:00
90cfe820d2 - Fixing errors with stale paths 2024-01-02 00:09:31 +00:00
8a49809855 74-perform-dad
- Adding go.sum to fix errors
2024-01-01 23:59:04 +00:00
dbc18bddc6 74-perform-dad
- Performing DAD to check if IPv6 address present before adding
  outselves to mesh
- Changing name from wgmesh to smegmesh
2024-01-01 23:55:50 +00:00
14f335af74 Merge pull request #73 from tim-beatham/72-pull-rate-in-configuration
72 pull rate in configuration
2023-12-31 14:26:34 +00:00
36e82dba47 72-pull-rate-in-configuration
- Refactored pull rate into the configuration
- code freeze so no more code changes
2023-12-31 14:25:06 +00:00
3cc87bc252 72-pull-rate-in-configuration
- Updated examples
2023-12-31 12:47:45 +00:00
a9ed7c0a20 72-pull-rate-in-configuration
- Removing libp2p reference
2023-12-31 12:47:45 +00:00
fd29af73e3 72-pull-rate-in-configuration
- Added pull rate to configuration (finally) so this can
be modified by an administrator.
2023-12-31 12:47:45 +00:00
9e1058e0f2 72-pull-rate-in-configuration
- Added the pull rate to the configuration file
2023-12-31 12:47:45 +00:00
c29eb197f3 Merge pull request #71 from tim-beatham/66-ipv6-address-not-conforming-to-spec
66 ipv6 address not conforming to spec
2023-12-30 22:26:53 +00:00
1a9d9d61ad 66-ipv6-address-not-conforming-to-spec
- Missing commit
2023-12-30 22:26:08 +00:00
6954608c32 66-ipv6-address-not-confirming-to-spec
- UUID is not random just a name generator needs changing to shortuuid
- When in multiple meshes there is no wait group
2023-12-30 22:24:43 +00:00
2e6aed6f93 main
- Fixing issue with nil pointer de-reference due to bad design of mesh
  manager.
- Going forward all references to GetSelf should be depracated. It
  introduces a race condition when leaving a mesh network
2023-12-30 00:44:57 +00:00
b0893a0b8e Merge pull request #69 from tim-beatham/60-unit-test-crdt-data-store
60-unit-test-crdt-data-store
2023-12-29 22:06:20 +00:00
e7d6055fa3 60-unit-test-crdt-data-store
Provided unit tests for datastore.go
And fixed unit tets failing by different way of providing CA
2023-12-29 22:05:05 +00:00
e0f3f116b9 main
- Stale serverConfig entry causing certificate authorities
to not become authorised
2023-12-29 19:54:08 +00:00
352648b7cb main
- Fixed problem where connection not removed on error
2023-12-29 11:12:40 +00:00
2d5df25b1d main
- If deadline exceeded error remove connection from
connection manager
2023-12-29 01:29:11 +00:00
cabe173831 main
Adding retry parameter
2023-12-29 01:10:26 +00:00
d2c8a52ec6 main
- Adding retry policy for mobility
2023-12-29 00:58:43 +00:00
bf53108384 main
- Bugfix, fix consistent hash problem where
if failure happens then causes panic
2023-12-28 23:24:38 +00:00
77aac5534b main
- Bugfix in client where "-" was attempted to be parsed as a UDP addr
2023-12-28 17:46:04 +00:00
58439fcd56 main
- Bugfix when keepalivewg is not set causes segmentation fault
- give keepalive a default value of 0 if not set
2023-12-28 17:32:54 +00:00
311a15363a Merge pull request #67 from tim-beatham/66-improve-graph-dot-tool
66 improve graph dot tool
2023-12-25 01:26:15 +00:00
255d3c8b39 66-improve-graph-dot-tool
- Showing services a node provides
- Showing all meshes not just one
- Showing the default route
2023-12-25 01:25:20 +00:00
41899c5831 66-improve-graph-dot-tool
Improving the graph dot tool so that it shows all
meshes
2023-12-25 01:10:11 +00:00
fe4ca66ff6 Merge pull request #65 from tim-beatham/64-2p-set-unit-test
64 2p set unit test
2023-12-22 23:58:59 +00:00
0b91ba744a 61-improve-unit-test-coverage
- Provided unit tests for g_map and 2p_map
2023-12-22 23:57:10 +00:00
67483c2a90 64-unit-test-two-phase-set
Provide unit tests for two phase set to make it more
transparent what exactly they are doing.
2023-12-22 23:57:10 +00:00
af26e81bd3 Merge pull request #63 from tim-beatham/61-improve-unit-testing-coverage
61-improve-unit-testing-coverage
2023-12-22 21:52:46 +00:00
186acbe915 Merge pull request #62 from tim-beatham/61-improve-unit-testing-coverage
61-improve-unit-testing-coverage
2023-12-22 21:49:06 +00:00
80 changed files with 3564 additions and 1698 deletions

View File

@ -9,4 +9,4 @@ RUN apt-get update && apt-get install -y \
vim
WORKDIR /wgmesh
RUN go mod tidy
RUN go build -o /usr/local/bin ./...
RUN go build -o /usr/local/bin ./...

View File

@ -3,7 +3,7 @@ package main
import (
"log"
"github.com/tim-beatham/wgmesh/pkg/api"
"github.com/tim-beatham/smegmesh/pkg/api"
)
func main() {

View File

@ -3,7 +3,7 @@ package main
import (
"log"
smegdns "github.com/tim-beatham/wgmesh/pkg/dns"
smegdns "github.com/tim-beatham/smegmesh/pkg/dns"
)
func main() {

View File

@ -6,8 +6,10 @@ import (
"os"
"github.com/akamensky/argparse"
"github.com/tim-beatham/wgmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
graph "github.com/tim-beatham/smegmesh/pkg/dot"
"github.com/tim-beatham/smegmesh/pkg/ipc"
logging "github.com/tim-beatham/smegmesh/pkg/log"
)
const SockAddr = "/tmp/wgmesh_ipc.sock"
@ -20,25 +22,22 @@ type CreateMeshParams struct {
AdvertiseDefault bool
}
func createMesh(params *CreateMeshParams) string {
func createMesh(client *ipc.SmegmeshIpc, args *ipc.NewMeshArgs) {
var reply string
newMeshParams := ipc.NewMeshArgs{
WgArgs: params.WgArgs,
}
err := params.Client.Call("IpcHandler.CreateMesh", &newMeshParams, &reply)
err := client.CreateMesh(args, &reply)
if err != nil {
return err.Error()
fmt.Println(err.Error())
return
}
return reply
fmt.Println(reply)
}
func listMeshes(client *ipcRpc.Client) {
func listMeshes(client *ipc.SmegmeshIpc) {
reply := new(ipc.ListMeshReply)
err := client.Call("IpcHandler.ListMeshes", "", &reply)
err := client.ListMeshes(reply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
@ -50,38 +49,22 @@ func listMeshes(client *ipcRpc.Client) {
}
}
type JoinMeshParams struct {
Client *ipcRpc.Client
MeshId string
IpAddress string
Endpoint string
WgArgs ipc.WireGuardArgs
AdvertiseRoutes bool
AdvertiseDefault bool
}
func joinMesh(params *JoinMeshParams) string {
func joinMesh(client *ipc.SmegmeshIpc, args ipc.JoinMeshArgs) {
var reply string
args := ipc.JoinMeshArgs{
MeshId: params.MeshId,
IpAdress: params.IpAddress,
WgArgs: params.WgArgs,
}
err := params.Client.Call("IpcHandler.JoinMesh", &args, &reply)
err := client.JoinMesh(args, &reply)
if err != nil {
return err.Error()
fmt.Println(err.Error())
}
return reply
fmt.Println(reply)
}
func leaveMesh(client *ipcRpc.Client, meshId string) {
func leaveMesh(client *ipc.SmegmeshIpc, meshId string) {
var reply string
err := client.Call("IpcHandler.LeaveMesh", &meshId, &reply)
err := client.LeaveMesh(meshId, &reply)
if err != nil {
fmt.Println(err.Error())
@ -91,10 +74,51 @@ func leaveMesh(client *ipcRpc.Client, meshId string) {
fmt.Println(reply)
}
func getGraph(client *ipcRpc.Client, meshId string) {
func getGraph(client *ipc.SmegmeshIpc) {
listMeshesReply := new(ipc.ListMeshReply)
err := client.ListMeshes(listMeshesReply)
if err != nil {
fmt.Println(err.Error())
return
}
meshes := make(map[string][]ctrlserver.MeshNode)
for _, meshId := range listMeshesReply.Meshes {
var meshReply ipc.GetMeshReply
err := client.GetMesh(meshId, &meshReply)
if err != nil {
fmt.Println(err.Error())
return
}
meshes[meshId] = meshReply.Nodes
}
dotGenerator := graph.NewMeshGraphConverter(meshes)
dot, err := dotGenerator.Generate()
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(dot)
}
func queryMesh(client *ipc.SmegmeshIpc, meshId, query string) {
var reply string
err := client.Call("IpcHandler.GetDOT", &meshId, &reply)
args := ipc.QueryMesh{
MeshId: meshId,
Query: query,
}
err := client.Query(args, &reply)
if err != nil {
fmt.Println(err.Error())
@ -104,24 +128,13 @@ func getGraph(client *ipcRpc.Client, meshId string) {
fmt.Println(reply)
}
func queryMesh(client *ipcRpc.Client, meshId, query string) {
func putDescription(client *ipc.SmegmeshIpc, meshId, description string) {
var reply string
err := client.Call("IpcHandler.Query", &ipc.QueryMesh{MeshId: meshId, Query: query}, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
// putDescription: puts updates the description about the node to the meshes
func putDescription(client *ipcRpc.Client, description string) {
var reply string
err := client.Call("IpcHandler.PutDescription", &description, &reply)
err := client.PutDescription(ipc.PutDescriptionArgs{
MeshId: meshId,
Description: description,
}, &reply)
if err != nil {
fmt.Println(err.Error())
@ -132,10 +145,13 @@ func putDescription(client *ipcRpc.Client, description string) {
}
// putAlias: puts an alias for the node
func putAlias(client *ipcRpc.Client, alias string) {
func putAlias(client *ipc.SmegmeshIpc, meshid, alias string) {
var reply string
err := client.Call("IpcHandler.PutAlias", &alias, &reply)
err := client.PutAlias(ipc.PutAliasArgs{
MeshId: meshid,
Alias: alias,
}, &reply)
if err != nil {
fmt.Println(err.Error())
@ -145,15 +161,14 @@ func putAlias(client *ipcRpc.Client, alias string) {
fmt.Println(reply)
}
func setService(client *ipcRpc.Client, service, value string) {
func setService(client *ipc.SmegmeshIpc, meshId, service, value string) {
var reply string
serviceArgs := &ipc.PutServiceArgs{
err := client.PutService(ipc.PutServiceArgs{
MeshId: meshId,
Service: service,
Value: value,
}
err := client.Call("IpcHandler.PutService", serviceArgs, &reply)
}, &reply)
if err != nil {
fmt.Println(err.Error())
@ -163,10 +178,13 @@ func setService(client *ipcRpc.Client, service, value string) {
fmt.Println(reply)
}
func deleteService(client *ipcRpc.Client, service string) {
func deleteService(client *ipc.SmegmeshIpc, meshId, service string) {
var reply string
err := client.Call("IpcHandler.PutService", &service, &reply)
err := client.DeleteService(ipc.DeleteServiceArgs{
MeshId: meshId,
Service: service,
}, &reply)
if err != nil {
fmt.Println(err.Error())
@ -177,8 +195,8 @@ func deleteService(client *ipcRpc.Client, service string) {
}
func main() {
parser := argparse.NewParser("wg-mesh",
"wg-mesh Manipulate WireGuard mesh networks")
parser := argparse.NewParser("smgctl",
"smegctl Manipulate WireGuard mesh networks")
newMeshCmd := parser.NewCommand("new-mesh", "Create a new mesh")
listMeshCmd := parser.NewCommand("list-meshes", "List meshes the node is connected to")
@ -201,7 +219,6 @@ func main() {
})
var newMeshRole *string = newMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{
Default: "peer",
Help: "Role in the mesh network. A value of peer means that the node is publicly routeable and thus considered" +
" in the gossip protocol. Client means that the node is not publicly routeable and is not a candidate in the gossip" +
" protocol",
@ -234,7 +251,6 @@ func main() {
})
var joinMeshRole *string = joinMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{
Default: "peer",
Help: "Role in the mesh network. A value of peer means that the node is publicly routeable and thus considered" +
" in the gossip protocol. Client means that the node is not publicly routeable and is not a candidate in the gossip" +
" protocol",
@ -258,11 +274,6 @@ func main() {
Help: "Advertise ::/0 into the mesh network",
})
var getGraphMeshId *string = getGraphCmd.String("m", "mesh", &argparse.Options{
Required: true,
Help: "MeshID of the graph to get",
})
var leaveMeshMeshId *string = leaveMeshCmd.String("m", "mesh", &argparse.Options{
Required: true,
Help: "MeshID of the mesh to leave",
@ -282,6 +293,16 @@ func main() {
Help: "Description of the node in the mesh",
})
var descriptionMeshId *string = putDescriptionCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
var aliasMeshId *string = putAliasCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
var alias *string = putAliasCmd.String("a", "alias", &argparse.Options{
Required: true,
Help: "Alias of the node to set can be used in DNS to lookup an IP address",
@ -296,11 +317,21 @@ func main() {
Help: "Value of the service to advertise in the mesh network",
})
var serviceMeshId *string = setServiceCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{
Required: true,
Help: "Key of the service to remove",
})
var deleteServiceMeshid *string = deleteServiceCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
err := parser.Parse(os.Args)
if err != nil {
@ -308,16 +339,13 @@ func main() {
return
}
client, err := ipcRpc.DialHTTP("unix", SockAddr)
client, err := ipc.NewClientIpc()
if err != nil {
fmt.Println(err.Error())
return
panic(err)
}
if newMeshCmd.Happened() {
fmt.Println(createMesh(&CreateMeshParams{
Client: client,
Endpoint: *newMeshEndpoint,
args := &ipc.NewMeshArgs{
WgArgs: ipc.WireGuardArgs{
Endpoint: *newMeshEndpoint,
Role: *newMeshRole,
@ -326,7 +354,9 @@ func main() {
AdvertiseDefaultRoute: *newMeshAdvertiseDefaults,
AdvertiseRoutes: *newMeshAdvertiseRoutes,
},
}))
}
createMesh(client, args)
}
if listMeshCmd.Happened() {
@ -334,11 +364,9 @@ func main() {
}
if joinMeshCmd.Happened() {
fmt.Println(joinMesh(&JoinMeshParams{
Client: client,
args := ipc.JoinMeshArgs{
IpAddress: *joinMeshIpAddress,
MeshId: *joinMeshId,
Endpoint: *joinMeshEndpoint,
WgArgs: ipc.WireGuardArgs{
Endpoint: *joinMeshEndpoint,
Role: *joinMeshRole,
@ -347,11 +375,12 @@ func main() {
AdvertiseDefaultRoute: *joinMeshAdvertiseDefaults,
AdvertiseRoutes: *joinMeshAdvertiseRoutes,
},
}))
}
joinMesh(client, args)
}
if getGraphCmd.Happened() {
getGraph(client, *getGraphMeshId)
getGraph(client)
}
if leaveMeshCmd.Happened() {
@ -363,18 +392,18 @@ func main() {
}
if putDescriptionCmd.Happened() {
putDescription(client, *description)
putDescription(client, *descriptionMeshId, *description)
}
if putAliasCmd.Happened() {
putAlias(client, *alias)
putAlias(client, *aliasMeshId, *alias)
}
if setServiceCmd.Happened() {
setService(client, *serviceKey, *serviceValue)
setService(client, *serviceMeshId, *serviceKey, *serviceValue)
}
if deleteServiceCmd.Happened() {
deleteService(client, *deleteServiceKey)
deleteService(client, *deleteServiceMeshid, *deleteServiceKey)
}
}

View File

@ -6,29 +6,30 @@ import (
"os"
"os/signal"
"github.com/tim-beatham/wgmesh/pkg/conf"
ctrlserver "github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/robin"
"github.com/tim-beatham/wgmesh/pkg/sync"
timer "github.com/tim-beatham/wgmesh/pkg/timers"
"github.com/tim-beatham/smegmesh/pkg/conf"
ctrlserver "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/smegmesh/pkg/ipc"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/robin"
"github.com/tim-beatham/smegmesh/pkg/sync"
"golang.zx2c4.com/wireguard/wgctrl"
)
func main() {
if len(os.Args) != 2 {
logging.Log.WriteErrorf("Did not provide configuration")
return
}
conf, err := conf.ParseDaemonConfiguration(os.Args[1])
configuration, err := conf.ParseDaemonConfiguration(os.Args[1])
if err != nil {
logging.Log.WriteErrorf("Could not parse configuration: %s", err.Error())
return
}
logging.SetLogger(logging.NewLogrusLogger(configuration.LogLevel))
client, err := wgctrl.New()
if err != nil {
@ -36,7 +37,7 @@ func main() {
return
}
if conf.Profile {
if configuration.Profile {
go func() {
http.ListenAndServe("localhost:6060", nil)
}()
@ -45,25 +46,21 @@ func main() {
var robinRpc robin.WgRpc
var robinIpc robin.IpcHandler
var syncProvider sync.SyncServiceImpl
var syncRequester sync.SyncRequester
var syncer sync.Syncer
ctrlServerParams := ctrlserver.NewCtrlServerParams{
Conf: conf,
Conf: configuration,
CtrlProvider: &robinRpc,
SyncProvider: &syncProvider,
Client: client,
OnDelete: func(mp mesh.MeshProvider) {
syncer.SyncMeshes()
},
}
ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams)
syncProvider.Server = ctrlServer
syncRequester = sync.NewSyncRequester(ctrlServer)
syncer = sync.NewSyncer(ctrlServer.MeshManager, conf, syncRequester)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, syncer)
keepAlive := timer.NewTimestampScheduler(ctrlServer)
if err != nil {
panic(err)
}
syncProvider.MeshManager = ctrlServer.MeshManager
robinIpcParams := robin.RobinIpcParams{
CtrlServer: ctrlServer,
@ -77,16 +74,11 @@ func main() {
return
}
logging.Log.WriteInfof("Running IPC Handler")
logging.Log.WriteInfof("running ipc handler")
go ipc.RunIpcHandler(&robinIpc)
go syncScheduler.Run()
go keepAlive.Run()
closeResources := func() {
logging.Log.WriteInfof("Closing resources")
syncScheduler.Stop()
keepAlive.Stop()
logging.Log.WriteInfof("closing resources")
ctrlServer.Close()
client.Close()
}

View File

@ -10,5 +10,5 @@ syncRate: 1
interClusterChance: 0.15
branchRate: 3
infectionCount: 3
keepAliveTime: 10
heartBeatTime: 10
pruneTime: 20

View File

@ -1,95 +0,0 @@
version: '3'
networks:
net-1:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.0.0/17
net-2:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.155.0/17
services:
wg-1:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-2:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-3:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-4:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
sysctls:
- net.ipv6.conf.all.forwarding=1
networks:
- net-1
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-5:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-6:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-7:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"

View File

@ -1,14 +0,0 @@
certificatePath: "/wgmesh/cert/cert.pem"
privateKeyPath: "/wgmesh/cert/priv.pem"
caCertificatePath: "/wgmesh/cert/cacert.pem"
skipCertVerification: true
timeout: 5
gRPCPort: "21906"
advertiseRoutes: true
clusterSize: 32
syncRate: 1
interClusterChance: 0.15
branchRate: 3
infectionCount: 3
keepAliveTime: 10
pruneTime: 20

View File

@ -1,14 +1,9 @@
version: '3'
networks:
net-1:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.0.0/17
services:
wg-1:
image: wg-mesh-base:latest
image: localhost/smegmesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
@ -17,9 +12,11 @@ services:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
command: "smegd /shared/configuration.yaml"
sysctls:
- net.ipv6.conf.all.forwarding=1
wg-2:
image: wg-mesh-base:latest
image: localhost/smegmesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
@ -28,9 +25,11 @@ services:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
command: "smegd /shared/configuration.yaml"
sysctls:
- net.ipv6.conf.all.forwarding=1
wg-3:
image: wg-mesh-base:latest
image: localhost/smegmesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
@ -39,4 +38,6 @@ services:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
command: "smegd /shared/configuration.yaml"
sysctls:
- net.ipv6.conf.all.forwarding=1

View File

@ -1,14 +1,34 @@
certificatePath: "/wgmesh/cert/cert.pem"
privateKeyPath: "/wgmesh/cert/priv.pem"
caCertificatePath: "/wgmesh/cert/cacert.pem"
# Paths to the certificates modify
# if not running from Smegmesh
certificatePath: "./cert/cert.pem"
privateKeyPath: "./cert/priv.pem"
caCertificatePath: "./cert/cacert.pem"
skipCertVerification: true
# timeout is the configured grpc timeout
timeout: 5
gRPCPort: "21906"
advertiseRoutes: true
clusterSize: 32
syncRate: 1
interClusterChance: 0.15
branchRate: 3
# gRPC port to run the solution
gRPCPort: 4000
# whether or not to run go profiler
profile: false
# stubWg: whether to install WireGuard configurations
# if true just tests the control plane
stubWg: false
heartbeatInterval: 60
branch: 3
pullInterval: 20
infectionCount: 3
keepAliveTime: 10
pruneTime: 20
interClusterChance: 0.15
syncInterval: 2
clusterSize: 64
logLevel: "info"
baseConfiguration:
# ipDiscovery: specifies how to find your IP address
ipDiscovery: "outgoing"
# alternative to ipDiscovery specify an actual endpoint yourself with publicEndpoint: "xxxx"
# role is the role that you are playing (peer | client)
# peers can only bootstrap meshes
role: "peer"
# advertise meshes to other meshes
advertiseRoute: true
# advertise default routes
advertiseDefaults: true

4
go.mod
View File

@ -1,4 +1,4 @@
module github.com/tim-beatham/wgmesh
module github.com/tim-beatham/smegmesh
go 1.21.3
@ -11,11 +11,11 @@ require (
github.com/google/uuid v1.3.0
github.com/jmespath/go-jmespath v0.4.0
github.com/jsimonetti/rtnetlink v1.3.5
github.com/lithammer/shortuuid v3.0.0+incompatible
github.com/miekg/dns v1.1.57
github.com/sirupsen/logrus v1.9.3
golang.org/x/sys v0.14.0
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
gonum.org/v1/gonum v0.14.0
google.golang.org/grpc v1.58.1
google.golang.org/protobuf v1.31.0
gopkg.in/yaml.v3 v3.0.1

6
go.sum
View File

@ -27,8 +27,6 @@ github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/o
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE=
github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
@ -57,6 +55,8 @@ github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZX
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/lithammer/shortuuid v3.0.0+incompatible h1:NcD0xWW/MZYXEHa6ITy6kaXN5nwm/V115vj2YXfhS0w=
github.com/lithammer/shortuuid v3.0.0+incompatible/go.mod h1:FR74pbAuElzOUuenUHTK2Tciko1/vKuIKS9dSkDrA4w=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
@ -123,8 +123,6 @@ golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 h1:EY138uSo1JYlDq+
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0=
gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 h1:bVf09lpb+OJbByTj913DRJioFFAjf/ZGxEz7MajTp2U=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98/go.mod h1:TUfxEVdsvPg18p6AslUXFoLdpED4oBnGwyqk3dV1XzM=
google.golang.org/grpc v1.58.1 h1:OL+Vz23DTtrrldqHK49FUOPHyY75rvFqJfXC84NYW58=

View File

@ -4,28 +4,14 @@ import (
"fmt"
"net/http"
ipcRpc "net/rpc"
"github.com/gin-gonic/gin"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/what8words"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/smegmesh/pkg/ipc"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/what8words"
)
const SockAddr = "/tmp/wgmesh_ipc.sock"
type ApiServer interface {
GetMeshes(c *gin.Context)
Run(addr string) error
}
type SmegServer struct {
router *gin.Engine
client *ipcRpc.Client
words *what8words.What8Words
}
// routesToApiRoute: convert the returned type to a JSON object
func (s *SmegServer) routeToApiRoute(meshNode ctrlserver.MeshNode) []Route {
routes := make([]Route, len(meshNode.Routes))
@ -44,6 +30,7 @@ func (s *SmegServer) routeToApiRoute(meshNode ctrlserver.MeshNode) []Route {
return routes
}
// meshNodeToAPImeshNode: convert daemon node to a JSON node
func (s *SmegServer) meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNode {
if meshNode.Routes == nil {
meshNode.Routes = make([]ctrlserver.MeshRoute, 0)
@ -74,6 +61,7 @@ func (s *SmegServer) meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNo
}
}
// meshToAPIMesh: Convert daemon mesh network to a JSON mesh network
func (s *SmegServer) meshToAPIMesh(meshId string, nodes []ctrlserver.MeshNode) SmegMesh {
var smegMesh SmegMesh
smegMesh.MeshId = meshId
@ -86,6 +74,25 @@ func (s *SmegServer) meshToAPIMesh(meshId string, nodes []ctrlserver.MeshNode) S
return smegMesh
}
// putAlias: place an alias in the mesh
func (s *SmegServer) putAlias(meshId, alias string) error {
var reply string
return s.client.PutAlias(ipc.PutAliasArgs{
Alias: alias,
MeshId: meshId,
}, &reply)
}
func (s *SmegServer) putDescription(meshId, description string) error {
var reply string
return s.client.PutDescription(ipc.PutDescriptionArgs{
Description: description,
MeshId: meshId,
}, &reply)
}
// CreateMesh: creates a mesh network
func (s *SmegServer) CreateMesh(c *gin.Context) {
var createMesh CreateMeshRequest
@ -98,15 +105,21 @@ func (s *SmegServer) CreateMesh(c *gin.Context) {
return
}
fmt.Printf("%+v\n", createMesh)
ipcRequest := ipc.NewMeshArgs{
WgArgs: ipc.WireGuardArgs{
WgPort: createMesh.WgPort,
WgPort: createMesh.WgPort,
Role: createMesh.Role,
Endpoint: createMesh.PublicEndpoint,
AdvertiseRoutes: createMesh.AdvertiseRoutes,
AdvertiseDefaultRoute: createMesh.AdvertiseDefaults,
},
}
var reply string
err := s.client.Call("IpcHandler.CreateMesh", &ipcRequest, &reply)
err := s.client.CreateMesh(&ipcRequest, &reply)
if err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
@ -115,6 +128,14 @@ func (s *SmegServer) CreateMesh(c *gin.Context) {
return
}
if createMesh.Alias != "" {
s.putAlias(reply, createMesh.Alias)
}
if createMesh.Description != "" {
s.putDescription(reply, createMesh.Description)
}
c.JSON(http.StatusOK, &gin.H{
"meshid": reply,
})
@ -132,16 +153,20 @@ func (s *SmegServer) JoinMesh(c *gin.Context) {
}
ipcRequest := ipc.JoinMeshArgs{
MeshId: joinMesh.MeshId,
IpAdress: joinMesh.Bootstrap,
MeshId: joinMesh.MeshId,
IpAddress: joinMesh.Bootstrap,
WgArgs: ipc.WireGuardArgs{
WgPort: joinMesh.WgPort,
WgPort: joinMesh.WgPort,
Endpoint: joinMesh.PublicEndpoint,
Role: joinMesh.Role,
AdvertiseRoutes: joinMesh.AdvertiseRoutes,
AdvertiseDefaultRoute: joinMesh.AdvertiseDefaults,
},
}
var reply string
err := s.client.Call("IpcHandler.JoinMesh", &ipcRequest, &reply)
err := s.client.JoinMesh(ipcRequest, &reply)
if err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
@ -150,6 +175,14 @@ func (s *SmegServer) JoinMesh(c *gin.Context) {
return
}
if joinMesh.Alias != "" {
s.putAlias(reply, joinMesh.Alias)
}
if joinMesh.Description != "" {
s.putDescription(reply, joinMesh.Description)
}
c.JSON(http.StatusOK, &gin.H{
"status": "success",
})
@ -164,7 +197,7 @@ func (s *SmegServer) GetMesh(c *gin.Context) {
getMeshReply := new(ipc.GetMeshReply)
err := s.client.Call("IpcHandler.GetMesh", &meshid, &getMeshReply)
err := s.client.GetMesh(meshid, getMeshReply)
if err != nil {
c.JSON(http.StatusNotFound,
@ -179,10 +212,12 @@ func (s *SmegServer) GetMesh(c *gin.Context) {
c.JSON(http.StatusOK, mesh)
}
// GetMeshes: return all the mesh networks that the
// user is a part of
func (s *SmegServer) GetMeshes(c *gin.Context) {
listMeshesReply := new(ipc.ListMeshReply)
err := s.client.Call("IpcHandler.ListMeshes", "", &listMeshesReply)
err := s.client.ListMeshes(listMeshesReply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
@ -195,7 +230,7 @@ func (s *SmegServer) GetMeshes(c *gin.Context) {
for _, mesh := range listMeshesReply.Meshes {
getMeshReply := new(ipc.GetMeshReply)
err := s.client.Call("IpcHandler.GetMesh", &mesh, &getMeshReply)
err := s.client.GetMesh(mesh, getMeshReply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
@ -209,13 +244,16 @@ func (s *SmegServer) GetMeshes(c *gin.Context) {
c.JSON(http.StatusOK, meshes)
}
// Run: run the API server
func (s *SmegServer) Run(addr string) error {
logging.Log.WriteInfof("Running API server")
return s.router.Run(addr)
}
// NewSmegServer: creates an instance of a new API server
// returns an error if something went wrong
func NewSmegServer(conf ApiServerConf) (ApiServer, error) {
client, err := ipcRpc.DialHTTP("unix", SockAddr)
client, err := ipc.NewClientIpc()
if err != nil {
return nil, err
@ -239,9 +277,19 @@ func NewSmegServer(conf ApiServerConf) (ApiServer, error) {
words: words,
}
router.GET("/meshes", smegServer.GetMeshes)
router.GET("/mesh/:meshid", smegServer.GetMesh)
router.POST("/mesh/create", smegServer.CreateMesh)
router.POST("/mesh/join", smegServer.JoinMesh)
v1 := router.Group("/api/v1")
{
meshes := v1.Group("/meshes")
{
meshes.GET("/", smegServer.GetMeshes)
}
mesh := v1.Group("/mesh")
{
mesh.GET("/:meshid", smegServer.GetMesh)
mesh.POST("/create", smegServer.CreateMesh)
mesh.POST("/join", smegServer.JoinMesh)
}
}
return smegServer, nil
}

View File

@ -1,47 +1,129 @@
package api
import "time"
import (
"time"
"github.com/gin-gonic/gin"
"github.com/tim-beatham/smegmesh/pkg/ipc"
"github.com/tim-beatham/smegmesh/pkg/what8words"
)
// Route is an advertised route in the data store
type Route struct {
Prefix string `json:"prefix"`
Path []string `json:"path"`
// Prefix is the advertised route prefix
Prefix string `json:"prefix"`
// Path is the hops the destination
Path []string `json:"path"`
}
// SmegStats is the WireGuard stats that the underlying host
// has sent to the peer
type SmegStats struct {
TotalTransmit int64 `json:"totalTransmit"`
TotalReceived int64 `json:"totalReceived"`
// TotalTransmit number of bytes sent to the peer
TotalTransmit int64 `json:"totalTransmit"`
// TotalReceived number of bytes received from the peer
TotalReceived int64 `json:"totalReceived"`
// KeepAliveInterval WireGuard keepalive interval that is sent to the host
KeepAliveInterval time.Duration `json:"keepaliveInterval"`
AllowedIps []string `json:"allowedIps"`
// AllowsIps is the allowed path to the destination
AllowedIps []string `json:"allowedIps"`
}
// SmegNode is a node in the mesh network
type SmegNode struct {
Alias string `json:"alias"`
WgHost string `json:"wgHost"`
WgEndpoint string `json:"wgEndpoint"`
Endpoint string `json:"endpoint"`
Timestamp int `json:"timestamp"`
Description string `json:"description"`
PublicKey string `json:"publicKey"`
Routes []Route `json:"routes"`
Services map[string]string `json:"services"`
Stats SmegStats `json:"stats"`
// Alias is the human readable name that the node is assocaited with
Alias string `json:"alias"`
// WgHost is the WireGuard IP address of the node. This is an IPv6
// address
WgHost string `json:"wgHost"`
// WgEndpoint is the physical endpoint of the host that packets
// are forwarded to
WgEndpoint string `json:"wgEndpoint"`
// Endpoint is the control plane endpoint of the host which
// grpc connections are to be sent along
Endpoint string `json:"endpoint"`
// Timestamp is the last time the signified it was alive.
// if the node is the leader this is evert heartBeatInterval
// otherwise this is the time the node joined the network
Timestamp int `json:"timestamp"`
// Description is the human readable description of the node
Description string `json:"description"`
// PublicKey is the WireGuard public key of the node
PublicKey string `json:"publicKey"`
// Routes is the routes that the node is advertising
Routes []Route `json:"routes"`
// Services is information about services that the node offers
Services map[string]string `json:"services"`
// Stats is the WireGuard stats of the node (if any)
Stats SmegStats `json:"stats"`
}
// SmegMesh encapsulates a single mesh in the API
type SmegMesh struct {
MeshId string `json:"meshid"`
Nodes map[string]SmegNode `json:"nodes"`
// MeshId is the mesh id of the network
MeshId string `json:"meshid"`
// Nodes is the nodes in the network keyed by their public
// key
Nodes map[string]SmegNode `json:"nodes"`
}
// CreateMeshRequest encapsulates a request to create a mesh network
type CreateMeshRequest struct {
// WgPort is the WireGuard to create the mesh in
WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"`
// Role is the role to take on in the mesh
Role string `json:"role" binding:"required,eq=client|eq=peer"`
// AdvertiseRoutes: advertise thi mesh to other meshes
AdvertiseRoutes bool `json:"advertiseRoutes"`
// AdvertiseDefaults: advertise an exit point
AdvertiseDefaults bool `json:"advertiseDefaults"`
// Alias: alias of the node in the mesh
Alias string `json:"alias"`
// Description: description of the node in the mesh
Description string `json:"description"`
// PublicEndpoint: an alternative public endpoint to advertise
PublicEndpoint string `json:"publicEndpoint"`
}
// JoinMeshRequests encapsulates a request to create a mesh network
type JoinMeshRequest struct {
WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"`
// WgPort is the WireGuard port to run the service on
WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"`
// Bootstrap is a bootstrap node to use to join the network
Bootstrap string `json:"bootstrap" binding:"required"`
MeshId string `json:"meshid" binding:"required"`
// MeshId is the ID of the mesh to join
MeshId string `json:"meshid" binding:"required"`
// Role is the role to take on in the mesh
Role string `json:"role" binding:"required,eq=client|eq=peer"`
// AdvertiseRoutes: advertise thi mesh to other meshes
AdvertiseRoutes bool `json:"advertiseRoutes"`
// AdvertiseDefaults: advertise an exit point
AdvertiseDefaults bool `json:"advertiseDefaults"`
// Alias: alias of the node in the mesh
Alias string `json:"alias"`
// Description: description of the node in the mesh
Description string `json:"description"`
// PublicEndpoint: an alternative public endpoint to advertise
PublicEndpoint string `json:"publicEndpoint"`
}
// ApiServerConf configuration to instantiate the API server
type ApiServerConf struct {
// WordsFile to use to map IP to words
WordsFile string
}
// SmegSever is the GIN api server that runs the service
type SmegServer struct {
// gin router to use
router *gin.Engine
// client to invoke operations
client *ipc.SmegmeshIpc
// what8words to use to convert IP to an alias
words *what8words.What8Words
}
// ApiSever absrtacts the API server
type ApiServer interface {
Run(addr string) error
}

View File

@ -1,3 +1,5 @@
// automerge: package is depracated and unused. Please refer to crdt
// for crdt operations in the mesh
package automerge
import (
@ -9,26 +11,36 @@ import (
"time"
"github.com/automerge/automerge-go"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// CrdtMeshManager manages nodes in the crdt mesh
// CrdtMeshManager manage the CRDT datastore
type CrdtMeshManager struct {
MeshId string
IfName string
Client *wgctrl.Client
doc *automerge.Doc
LastHash automerge.ChangeHash
conf *conf.WgConfiguration
cache *MeshCrdt
// MeshID of the mesh the datastore represents
MeshId string
// IfName: corresponding ifName
IfName string
// Client: corresponding wireguard control client
Client *wgctrl.Client
// doc: autommerge document
doc *automerge.Doc
// LastHash: last hash that the changes were made to
LastHash automerge.ChangeHash
// conf: WireGuard configuration
conf *conf.WgConfiguration
// cache: stored cache of the list automerge document
// so that the store does not have to be repopulated each time
cache *MeshCrdt
// lastCachehash: hash of when the document was last changed
lastCacheHash automerge.ChangeHash
}
// AddNode as a node to the datastore
func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
crdt, ok := node.(*MeshNodeCrdt)
@ -47,6 +59,7 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
}
}
// isPeer: returns true if the given node has type peer
func (c *CrdtMeshManager) isPeer(nodeId string) bool {
node, err := c.doc.Path("nodes").Map().Get(nodeId)
@ -64,7 +77,8 @@ func (c *CrdtMeshManager) isPeer(nodeId string) bool {
}
// isAlive: checks that the node's configuration has been updated
// since the rquired keep alive time
// since the rquired keep alive time. Depracated no longer works
// due to changes in approach
func (c *CrdtMeshManager) isAlive(nodeId string) bool {
node, err := c.doc.Path("nodes").Map().Get(nodeId)
@ -78,10 +92,11 @@ func (c *CrdtMeshManager) isAlive(nodeId string) bool {
return false
}
return true
// return (time.Now().Unix() - keepAliveTime) < int64(c.conf.DeadTime)
return true
}
// GetPeers: get all the peers in the mesh
func (c *CrdtMeshManager) GetPeers() []string {
keys, _ := c.doc.Path("nodes").Map().Keys()
@ -92,7 +107,7 @@ func (c *CrdtMeshManager) GetPeers() []string {
return keys
}
// GetMesh(): Converts the document into a struct
// GetMesh: Converts the document into a mesh network
func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
changes, err := c.doc.Changes(c.lastCacheHash)
@ -114,7 +129,7 @@ func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
return c.cache, nil
}
// GetMeshId returns the meshid of the mesh
// GetMeshId: returns the meshid of the mesh
func (c *CrdtMeshManager) GetMeshId() string {
return c.MeshId
}
@ -135,6 +150,8 @@ func (c *CrdtMeshManager) Load(bytes []byte) error {
return nil
}
// NewCrdtNodeManagerParams: params to instantiate a new automerge
// datastore
type NewCrdtNodeMangerParams struct {
MeshId string
DevName string
@ -143,7 +160,7 @@ type NewCrdtNodeMangerParams struct {
Client *wgctrl.Client
}
// NewCrdtNodeManager: Create a new crdt node manager
// NewCrdtNodeManager: Create a new automerge crdt data store
func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, error) {
var manager CrdtMeshManager
manager.MeshId = params.MeshId
@ -155,12 +172,13 @@ func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, erro
return &manager, nil
}
// NodeExists: returns true if the node exists. Returns false
// NodeExists: returns true if the node exists other returns false
func (m *CrdtMeshManager) NodeExists(key string) bool {
node, err := m.doc.Path("nodes").Map().Get(key)
return node.Kind() == automerge.KindMap && err == nil
}
// GetNode: gets a node from the mesh network.
func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
node, err := m.doc.Path("nodes").Map().Get(endpoint)
@ -181,10 +199,12 @@ func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
return meshNode, nil
}
// Length: returns the number of nodes in the store
func (m *CrdtMeshManager) Length() int {
return m.doc.Path("nodes").Map().Len()
}
// GetDevice: get the underlying WireGuard device
func (m *CrdtMeshManager) GetDevice() (*wgtypes.Device, error) {
dev, err := m.Client.Device(m.IfName)
@ -195,7 +215,7 @@ func (m *CrdtMeshManager) GetDevice() (*wgtypes.Device, error) {
return dev, nil
}
// HasChanges returns true if we have changes since the last time we synced
// HasChanges: returns true if there are changes since last time synchronised
func (m *CrdtMeshManager) HasChanges() bool {
changes, err := m.doc.Changes(m.LastHash)
@ -209,6 +229,7 @@ func (m *CrdtMeshManager) HasChanges() bool {
return len(changes) > 0
}
// SaveChanges: save changes to the datastore
func (m *CrdtMeshManager) SaveChanges() {
hashes := m.doc.Heads()
hash := hashes[len(hashes)-1]
@ -217,6 +238,7 @@ func (m *CrdtMeshManager) SaveChanges() {
m.LastHash = hash
}
// UpdateTimeStamp: updates the timestamp of the document
func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -237,6 +259,7 @@ func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error {
return err
}
// SetDescription: set the description of the given node
func (m *CrdtMeshManager) SetDescription(nodeId string, description string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -257,6 +280,7 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro
return err
}
// SetAlias: set the alias of the given node
func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -277,6 +301,7 @@ func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
return err
}
// AddService: add a service to the given node
func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -298,6 +323,7 @@ func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
return err
}
// RemoveService: remove a service from a node
func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -378,6 +404,7 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
return nil
}
// getRoutes: get the routes that the given node is directly advertising
func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -404,6 +431,8 @@ func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) {
return lib.MapValues(routes), err
}
// GetRoutes: get all the routes that the node can see. The routes that the node
// can say may not be direct but cann also be indirect
func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, error) {
node, err := m.GetNode(targetNode)
@ -447,12 +476,13 @@ func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, e
return routes, nil
}
// RemoveNode: removes a node from the datastore
func (m *CrdtMeshManager) RemoveNode(nodeId string) error {
err := m.doc.Path("nodes").Map().Delete(nodeId)
return err
}
// DeleteRoutes deletes the specified routes
// RemoveRoutes: withdraw all the routes the nodeID is advertising
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -486,30 +516,37 @@ func (m *CrdtMeshManager) GetConfiguration() *conf.WgConfiguration {
func (m *CrdtMeshManager) Mark(nodeId string) {
}
// GetSyncer: get the bi-directionally syncer to synchronise the document
func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer {
return NewAutomergeSync(m)
}
// Prune: prune all dead nodes
func (m *CrdtMeshManager) Prune() error {
return nil
}
// Compare: compare two mesh node for equality
func (m1 *MeshNodeCrdt) Compare(m2 *MeshNodeCrdt) int {
return strings.Compare(m1.PublicKey, m2.PublicKey)
}
// GetHostEndpoint: get the ctrl endpoint of the host
func (m *MeshNodeCrdt) GetHostEndpoint() string {
return m.HostEndpoint
}
// GetPublicKey: get the public key of the node
func (m *MeshNodeCrdt) GetPublicKey() (wgtypes.Key, error) {
return wgtypes.ParseKey(m.PublicKey)
}
// GetWgEndpoint: get the outer WireGuard endpoint
func (m *MeshNodeCrdt) GetWgEndpoint() string {
return m.WgEndpoint
}
// GetWgHost: get the WireGuard IP address of the host
func (m *MeshNodeCrdt) GetWgHost() *net.IPNet {
_, ipnet, err := net.ParseCIDR(m.WgHost)
@ -520,10 +557,12 @@ func (m *MeshNodeCrdt) GetWgHost() *net.IPNet {
return ipnet
}
// GetTimeStamp: get timestamp if when the node was last updated
func (m *MeshNodeCrdt) GetTimeStamp() int64 {
return m.Timestamp
}
// GetRoutes: get all the routes advertised by the node
func (m *MeshNodeCrdt) GetRoutes() []mesh.Route {
return lib.Map(lib.MapValues(m.Routes), func(r Route) mesh.Route {
return &Route{
@ -533,10 +572,12 @@ func (m *MeshNodeCrdt) GetRoutes() []mesh.Route {
})
}
// GetDescription: get the description of the node
func (m *MeshNodeCrdt) GetDescription() string {
return m.Description
}
// GetIdentifier: get the iderntifier section of the ipv6 address
func (m *MeshNodeCrdt) GetIdentifier() string {
ipv6 := m.WgHost[:len(m.WgHost)-4]
@ -545,10 +586,12 @@ func (m *MeshNodeCrdt) GetIdentifier() string {
return strings.Join(constituents, ":")
}
// GetAlias: get the alias of the node
func (m *MeshNodeCrdt) GetAlias() string {
return m.Alias
}
// GetServices: get all the services the node is advertising
func (m *MeshNodeCrdt) GetServices() map[string]string {
services := make(map[string]string)
@ -565,6 +608,7 @@ func (n *MeshNodeCrdt) GetType() conf.NodeType {
return conf.NodeType(n.Type)
}
// GetNodes: get all the nodes in the network
func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
nodes := make(map[string]mesh.MeshNode)
@ -586,15 +630,18 @@ func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
return nodes
}
// GetDestination: get destination of the route
func (r *Route) GetDestination() *net.IPNet {
_, ipnet, _ := net.ParseCIDR(r.Destination)
return ipnet
}
// GetHopCount: get the number of hops to the destination
func (r *Route) GetHopCount() int {
return len(r.Path)
}
// GetPath: get the total path which includes the number of hops
func (r *Route) GetPath() []string {
return r.Path
}

View File

@ -1,15 +1,24 @@
// automerge: automerge is a CRDT library. Defines a CRDT
// datastore and methods to resolve conflicts
package automerge
import (
"github.com/automerge/automerge-go"
logging "github.com/tim-beatham/wgmesh/pkg/log"
logging "github.com/tim-beatham/smegmesh/pkg/log"
)
// AutomergeSync: defines a synchroniser to bi-directionally synchronise the
// two states
type AutomergeSync struct {
state *automerge.SyncState
// state: the automerge sync state to use
state *automerge.SyncState
// manager: the corresponding data store that we are merging
manager *CrdtMeshManager
}
// GenerateMessage: geenrate a new automerge message to synchronise
// returns a byte of the message and a boolean of whether or not there
// are more messages in the sequence
func (a *AutomergeSync) GenerateMessage() ([]byte, bool) {
msg, valid := a.state.GenerateMessage()
@ -20,6 +29,8 @@ func (a *AutomergeSync) GenerateMessage() ([]byte, bool) {
return msg.Bytes(), true
}
// RecvMessage: receive an automerge message to merge in the datastore
// returns an error if unsuccessful
func (a *AutomergeSync) RecvMessage(msg []byte) error {
_, err := a.state.ReceiveMessage(msg)
@ -30,11 +41,13 @@ func (a *AutomergeSync) RecvMessage(msg []byte) error {
return nil
}
// Complete: complete the synchronisation process
func (a *AutomergeSync) Complete() {
logging.Log.WriteInfof("Sync Completed")
logging.Log.WriteInfof("sync completed")
a.manager.SaveChanges()
}
// NewAutomergeSync: instantiates a new automerge syncer
func NewAutomergeSync(manager *CrdtMeshManager) *AutomergeSync {
return &AutomergeSync{
state: automerge.NewSyncState(manager.doc),

View File

@ -6,9 +6,9 @@ import (
"testing"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@ -83,7 +83,6 @@ func TestAddNodeAddRoute(t *testing.T) {
testParams.manager.AddNode(testNode)
testParams.manager.AddRoutes(pubKey.String(), &mesh.RouteStub{
Destination: destination,
HopCount: 0,
Path: make([]string, 0),
})
updatedNode, err := testParams.manager.GetNode(pubKey.String())
@ -297,7 +296,6 @@ func TestAddRoutesNodeDoesNotExist(t *testing.T) {
err := testParams.manager.AddRoutes("AAAAA", &mesh.RouteStub{
Destination: destination,
HopCount: 0,
Path: make([]string, 0),
})

View File

@ -3,13 +3,16 @@ package automerge
import (
"fmt"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/mesh"
)
// CrdtProviderFactory: abstracts the instantiation of an automerge
// datastore
type CrdtProviderFactory struct{}
// CreateMesh: create a new mesh datastore
func (f *CrdtProviderFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) {
return NewCrdtNodeManager(&NewCrdtNodeMangerParams{
MeshId: params.MeshId,
@ -19,11 +22,12 @@ func (f *CrdtProviderFactory) CreateMesh(params *mesh.MeshProviderFactoryParams)
})
}
// MeshNodeFactory: abstracts the instnatiation of a node
type MeshNodeFactory struct {
Config conf.DaemonConfiguration
}
// Build builds the mesh node that represents the host machine to add
// Build: builds the mesh node that represents the host machine to add
// to the mesh
func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode {
hostName := f.getAddress(params)
@ -48,7 +52,7 @@ func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNod
}
}
// getAddress returns the routable address of the machine.
// getAddress: returns the routable address of the machine.
func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string {
var hostName string = ""
@ -59,7 +63,7 @@ func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string
} else {
ipFunc := lib.GetPublicIP
if *params.MeshConfig.IPDiscovery == conf.DNS_IP_DISCOVERY {
if *params.MeshConfig.IPDiscovery == conf.OUTGOING_IP_DISCOVERY {
ipFunc = lib.GetOutboundIP
}

View File

@ -6,10 +6,12 @@ import (
"strings"
)
// CmdRunner: run cmd commands when instantiating a network
type CmdRunner interface {
RunCommands(commands ...string) error
}
// UnixCmdRunner: Run UNIX commands
type UnixCmdRunner struct{}
// RunCommand: runs the unix command. It splits the command into fields
@ -20,6 +22,7 @@ func RunCommand(cmd string) error {
return c.Run()
}
// RunCommands: run a series of commands
func (l *UnixCmdRunner) RunCommands(commands ...string) error {
for _, cmd := range commands {
err := RunCommand(cmd)

View File

@ -8,14 +8,7 @@ import (
"gopkg.in/yaml.v3"
)
type WgMeshConfigurationError struct {
msg string
}
func (m *WgMeshConfigurationError) Error() string {
return m.msg
}
// NodeType types of the node either peer or client
type NodeType string
const (
@ -23,11 +16,23 @@ const (
CLIENT_ROLE NodeType = "client"
)
// IPDiscovery: what IPDiscovery service to use
type IPDiscovery string
const (
// Public IP use an IP service to discover your IP
PUBLIC_IP_DISCOVERY IPDiscovery = "public"
DNS_IP_DISCOVERY IPDiscovery = "dns"
// Outgonig: Use your labelled packet IP
OUTGOING_IP_DISCOVERY IPDiscovery = "outgoing"
)
// Loglevel: what log level to use either error info or warning
type LogLevel string
const (
ERROR LogLevel = "error"
WARNING LogLevel = "warning"
INFO LogLevel = "info"
)
// WgConfiguration contains per-mesh WireGuard configuration. Contains poitner types only so we can
@ -35,7 +40,7 @@ const (
type WgConfiguration struct {
// IPDIscovery: how to discover your IP if not specified. Use your outgoing IP or use a public
// service for IPDiscoverability
IPDiscovery *IPDiscovery `yaml:"ipDiscovery" validate:"required,eq=public|eq=dns"`
IPDiscovery *IPDiscovery `yaml:"ipDiscovery" validate:"required,eq=public|eq=outgoing"`
// AdvertiseRoutes: specifies whether the node can act as a router routing packets between meshes
AdvertiseRoutes *bool `yaml:"advertiseRoute" validate:"required"`
// AdvertiseDefaultRoute: specifies whether or not this route should advertise a default route
@ -47,7 +52,6 @@ type WgConfiguration struct {
// If the user is globaly accessible they specify themselves as a client.
Role *NodeType `yaml:"role" validate:"required,eq=client|eq=peer"`
// KeepAliveWg configures the implementation so that we send keep alive packets to peers.
// KeepAlive can only be set if role is type client
KeepAliveWg *int `yaml:"keepAliveWg" validate:"omitempty,gte=0"`
// PreUp are WireGuard commands to run before adding the WG interface
PreUp []string `yaml:"preUp"`
@ -77,22 +81,26 @@ type DaemonConfiguration struct {
Profile bool `yaml:"profile"`
// StubWg whether or not to stub the WireGuard types
StubWg bool `yaml:"stubWg"`
// SyncRate specifies how long the minimum time should be between synchronisation
SyncRate int `yaml:"syncRate" validate:"required,gte=1"`
// KeepAliveTime: number of seconds before the leader of the mesh sends an update to
// SyncInterval specifies how long the minimum time should be between synchronisation
SyncInterval int `yaml:"syncInterval" validate:"required,gte=1"`
// PullInterval specifies the interval between checking for configuration changes
PullInterval int `yaml:"pullInterval" validate:"gte=0"`
// Heartbeat: number of seconds before the leader of the mesh sends an update to
// send to every member in the mesh
KeepAliveTime int `yaml:"keepAliveTime" validate:"required,gte=1"`
Heartbeat int `yaml:"heartbeatInterval" validate:"required,gte=1"`
// ClusterSize specifies how many neighbours you should synchronise with per round
ClusterSize int `yaml:"clusterSize" validate:"gte=1"`
// InterClusterChance specifies the probabilityof inter-cluster communication in a sync round
InterClusterChance float64 `yaml:"interClusterChance" validate:"gt=0"`
// BranchRate specifies the number of nodes to synchronise with when a node has
// Branch specifies the number of nodes to synchronise with when a node has
// new changes to send to the mesh
BranchRate int `yaml:"branchRate" validate:"required,gte=1"`
Branch int `yaml:"branch" validate:"required,gte=1"`
// InfectionCount: number of time to sync before an update can no longer be 'caught'
InfectionCount int `yaml:"infectionCount" validate:"required,gte=1"`
// BaseConfiguration base WireGuard configuration to use, this is used when none is provided
BaseConfiguration WgConfiguration `yaml:"baseConfiguration" validate:"required"`
// LogLevel specifies the log level to output, defaults is warning
LogLevel LogLevel `yaml:"logLevel" validate:"eq=info|eq=warning|eq=error"`
}
// ValdiateMeshConfiguration: validates the mesh configuration
@ -120,32 +128,21 @@ func ValidateMeshConfiguration(conf *WgConfiguration) error {
}
// ValidateDaemonConfiguration: validates the dameon configuration that is used.
func ValidateDaemonConfiguration(c *DaemonConfiguration) error {
func ValidateDaemonConfiguration(conf *DaemonConfiguration) error {
if conf.BaseConfiguration.KeepAliveWg == nil {
var keepAlive int = 0
conf.BaseConfiguration.KeepAliveWg = &keepAlive
}
if conf.LogLevel == "" {
conf.LogLevel = WARNING
}
validate := validator.New(validator.WithRequiredStructEnabled())
err := validate.Struct(c)
err := validate.Struct(conf)
return err
}
// ParseMeshConfiguration: parses the mesh network configuration. Parses parameters such as
// keepalive time, role and so forth.
func ParseMeshConfiguration(filePath string) (*WgConfiguration, error) {
var conf WgConfiguration
yamlBytes, err := os.ReadFile(filePath)
if err != nil {
return nil, err
}
err = yaml.Unmarshal(yamlBytes, &conf)
if err != nil {
return nil, err
}
return &conf, ValidateMeshConfiguration(&conf)
}
// ParseDaemonConfiguration parses the mesh configuration and validates the configuration
func ParseDaemonConfiguration(filePath string) (*DaemonConfiguration, error) {
var conf DaemonConfiguration

View File

@ -21,11 +21,12 @@ func getExampleConfiguration() *DaemonConfiguration {
Timeout: 5,
Profile: false,
StubWg: false,
SyncRate: 2,
KeepAliveTime: 2,
SyncInterval: 2,
Heartbeat: 2,
ClusterSize: 64,
InterClusterChance: 0.15,
BranchRate: 3,
Branch: 3,
PullInterval: 0,
InfectionCount: 2,
BaseConfiguration: WgConfiguration{
IPDiscovery: &discovery,
@ -153,7 +154,7 @@ func TestRoleTypeNotSpecified(t *testing.T) {
func TestBranchRateZero(t *testing.T) {
conf := getExampleConfiguration()
conf.BranchRate = 0
conf.Branch = 0
err := ValidateDaemonConfiguration(conf)
@ -162,9 +163,9 @@ func TestBranchRateZero(t *testing.T) {
}
}
func TestSyncRateZero(t *testing.T) {
func TestsyncTimeZero(t *testing.T) {
conf := getExampleConfiguration()
conf.SyncRate = 0
conf.SyncInterval = 0
err := ValidateDaemonConfiguration(conf)
@ -175,7 +176,7 @@ func TestSyncRateZero(t *testing.T) {
func TestKeepAliveTimeZero(t *testing.T) {
conf := getExampleConfiguration()
conf.KeepAliveTime = 0
conf.Heartbeat = 0
err := ValidateDaemonConfiguration(conf)
if err == nil {
@ -215,6 +216,17 @@ func TestInfectionCountOne(t *testing.T) {
}
}
func TestPullTimeNegative(t *testing.T) {
conf := getExampleConfiguration()
conf.PullInterval = -1
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestValidConfiguration(t *testing.T) {
conf := getExampleConfiguration()
err := ValidateDaemonConfiguration(conf)

View File

@ -7,25 +7,30 @@ import (
"slices"
)
// ConnCluster splits nodes into clusters where nodes in a cluster communicate
// ConnCluster: splits nodes into clusters where nodes in a cluster communicate
// frequently and nodes outside of a cluster communicate infrequently
type ConnCluster interface {
// Getneighbours: get neighbours of the cluster the node is in
GetNeighbours(global []string, selfId string) []string
// GetInterCluster: get the cluster to communicate with
GetInterCluster(global []string, selfId string) string
}
// ConnnClusterImpl: implementation of the connection cluster
type ConnClusterImpl struct {
clusterSize int
}
// perform binary search to attain a size of a group
func binarySearch(global []string, selfId string, groupSize int) (int, int) {
slices.Sort(global)
lower := 0
higher := len(global) - 1
mid := (lower + higher) / 2
for (higher+1)-lower > groupSize {
mid := (lower + higher) / 2
if global[mid] < selfId {
lower = mid + 1
} else if global[mid] > selfId {
@ -33,14 +38,12 @@ func binarySearch(global []string, selfId string, groupSize int) (int, int) {
} else {
break
}
mid = (lower + higher) / 2
}
return lower, int(math.Min(float64(lower+groupSize), float64(len(global))))
}
// GetNeighbours return the neighbours 'nearest' to you. In this implementation the
// GetNeighbours: return the neighbours 'nearest' to you. In this implementation the
// neighbours aren't actually the ones nearest to you but just the ones nearest
// to you alphabetically. Perform binary search to get the total group
func (i *ConnClusterImpl) GetNeighbours(global []string, selfId string) []string {
@ -51,7 +54,7 @@ func (i *ConnClusterImpl) GetNeighbours(global []string, selfId string) []string
return global[lower:higher]
}
// GetInterCluster get nodes not in your cluster. Every round there is a given chance
// GetInterCluster: get nodes not in your cluster. Every round there is a given chance
// you will communicate with a random node that is not in your cluster.
func (i *ConnClusterImpl) GetInterCluster(global []string, selfId string) string {
// Doesn't matter if not in it. Get index of where the node 'should' be
@ -66,6 +69,7 @@ func (i *ConnClusterImpl) GetInterCluster(global []string, selfId string) string
return global[neighbourIndex]
}
// NewConnCluster: instantiate a new connection cluster of a given group size.
func NewConnCluster(clusterSize int) (ConnCluster, error) {
log2Cluster := math.Log2(float64(clusterSize))

View File

@ -6,7 +6,7 @@ import (
"crypto/tls"
"errors"
logging "github.com/tim-beatham/wgmesh/pkg/log"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
@ -18,6 +18,7 @@ type PeerConnection interface {
GetClient() (*grpc.ClientConn, error)
}
// PeerConenctionFactory: create a new connection to a peer
type PeerConnectionFactory = func(clientConfig *tls.Config, server string) (PeerConnection, error)
// WgCtrlConnection implements PeerConnection.

View File

@ -7,7 +7,7 @@ import (
"os"
"sync"
logging "github.com/tim-beatham/wgmesh/pkg/log"
logging "github.com/tim-beatham/smegmesh/pkg/log"
)
// ConnectionManager defines an interface for maintaining peer connections
@ -19,9 +19,11 @@ type ConnectionManager interface {
// If the endpoint does not exist then add the connection. Returns an error
// if something went wrong
GetConnection(endPoint string) (PeerConnection, error)
// HasConnections returns true if a client has already registered at the givne
// HasConnections returns true if a peer has already registered at the given
// endpoint or false otherwise.
HasConnection(endPoint string) bool
// Removes a connection if it exists
RemoveConnection(endPoint string) error
// Goes through all the connections and closes eachone
Close() error
}
@ -32,7 +34,6 @@ type ConnectionManagerImpl struct {
// clientConnections maps an endpoint to a connection
conLoc sync.RWMutex
clientConnections map[string]PeerConnection
serverConfig *tls.Config
clientConfig *tls.Config
connFactory PeerConnectionFactory
}
@ -61,37 +62,25 @@ func NewConnectionManager(params *NewConnectionManagerParams) (ConnectionManager
return nil, err
}
serverAuth := tls.RequireAndVerifyClientCert
if params.SkipCertVerification {
serverAuth = tls.RequireAnyClientCert
}
certPool := x509.NewCertPool()
if !params.SkipCertVerification {
if params.CaCert == "" {
return nil, errors.New("CA Cert is not specified")
}
caCert, err := os.ReadFile(params.CaCert)
if err != nil {
return nil, err
}
certPool.AppendCertsFromPEM(caCert)
if params.CaCert == "" {
return nil, errors.New("CA Cert is not specified")
}
serverConfig := &tls.Config{
ClientAuth: serverAuth,
Certificates: []tls.Certificate{cert},
caCert, err := os.ReadFile(params.CaCert)
if err != nil {
return nil, err
}
if ok := certPool.AppendCertsFromPEM(caCert); !ok {
return nil, errors.New("could not parse PEM")
}
clientConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: params.SkipCertVerification,
Certificates: []tls.Certificate{cert},
RootCAs: certPool,
}
@ -99,7 +88,6 @@ func NewConnectionManager(params *NewConnectionManagerParams) (ConnectionManager
connMgr := ConnectionManagerImpl{
sync.RWMutex{},
connections,
serverConfig,
clientConfig,
params.ConnFactory,
}
@ -150,6 +138,15 @@ func (m *ConnectionManagerImpl) HasConnection(endPoint string) bool {
return exists
}
// RemoveConnection removes the given connection if it exists
func (m *ConnectionManagerImpl) RemoveConnection(endPoint string) error {
m.conLoc.Lock()
err := m.clientConnections[endPoint].Close()
delete(m.clientConnections, endPoint)
m.conLoc.Unlock()
return err
}
func (m *ConnectionManagerImpl) Close() error {
for _, conn := range m.clientConnections {
if err := conn.Close(); err != nil {

View File

@ -53,13 +53,13 @@ func TestNewConnectionManagerCACertDoesNotExistAndVerify(t *testing.T) {
func TestNewConnectionManagerCACertDoesNotExistAndNotVerify(t *testing.T) {
params := getConnectionManagerParams()
params.CaCert = ""
params.CaCert = "./cert/sdjsdjsdjk.pem"
params.SkipCertVerification = true
_, err := NewConnectionManager(params)
if err != nil {
t.Fatal(`an error should not be thrown`)
if err == nil {
t.Fatalf(`an error should be thrown`)
}
}

View File

@ -2,22 +2,23 @@ package conn
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"os"
"github.com/tim-beatham/wgmesh/pkg/conf"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/rpc"
"github.com/tim-beatham/smegmesh/pkg/conf"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/rpc"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
// ConnectionServer manages gRPC server peer connections
type ConnectionServer struct {
// tlsConfiguration of the server
serverConfig *tls.Config
// server an instance of the grpc server
server *grpc.Server // the authentication service to authenticate nodes
server *grpc.Server
// the ctrl service to manage node
ctrlProvider rpc.MeshCtrlServerServer
// the sync service to synchronise nodes
@ -48,9 +49,26 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
serverAuth = tls.RequireAnyClientCert
}
certPool := x509.NewCertPool()
if params.Conf.CaCertificatePath == "" {
return nil, errors.New("CA Cert is not specified")
}
caCert, err := os.ReadFile(params.Conf.CaCertificatePath)
if err != nil {
return nil, err
}
if ok := certPool.AppendCertsFromPEM(caCert); !ok {
return nil, errors.New("could not parse PEM")
}
serverConfig := &tls.Config{
ClientAuth: serverAuth,
Certificates: []tls.Certificate{cert},
ClientCAs: certPool,
}
server := grpc.NewServer(
@ -61,7 +79,6 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
syncProvider := params.SyncProvider
connServer := ConnectionServer{
serverConfig: serverConfig,
server: server,
ctrlProvider: ctrlProvider,
syncProvider: syncProvider,
@ -74,7 +91,6 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
// Listen for incoming requests. Returns an error if something went wrong.
func (s *ConnectionServer) Listen() error {
rpc.RegisterMeshCtrlServerServer(s.server, s.ctrlProvider)
rpc.RegisterSyncServiceServer(s.server, s.syncProvider)
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", s.Conf.GrpcPort))

View File

@ -16,6 +16,11 @@ func (s *ConnectionManagerStub) AddConnection(endPoint string) (PeerConnection,
return mock, nil
}
func (s *ConnectionManagerStub) RemoveConnection(endPoint string) error {
delete(s.Endpoints, endPoint)
return nil
}
func (s *ConnectionManagerStub) GetConnection(endPoint string) (PeerConnection, error) {
endpoint, ok := s.Endpoints[endPoint]

View File

@ -9,17 +9,20 @@ import (
"strings"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// Route: represents a route within the data store
type Route struct {
// Destination the route is advertising
Destination string
Path []string
// Path to the destination
Path []string
}
// GetDestination implements mesh.Route.
@ -158,8 +161,8 @@ type TwoPhaseStoreMeshManager struct {
IfName string
Client *wgctrl.Client
LastClock uint64
conf *conf.WgConfiguration
daemonConf *conf.DaemonConfiguration
Conf *conf.WgConfiguration
DaemonConf *conf.DaemonConfiguration
store *TwoPhaseMap[string, MeshNode]
}
@ -204,7 +207,6 @@ func (m *TwoPhaseStoreMeshManager) Save() []byte {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
err := enc.Encode(*snapshot)
if err != nil {
@ -249,7 +251,8 @@ func (m *TwoPhaseStoreMeshManager) SaveChanges() {
m.LastClock = clockValue
}
// UpdateTimeStamp: update the timestamp of the given node
// UpdateTimeStamp: update the timestamp of the given node, causes a configuration refresh if the node
// is the leader causing all nodes to update their vector clocks
func (m *TwoPhaseStoreMeshManager) UpdateTimeStamp(nodeId string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
@ -265,7 +268,7 @@ func (m *TwoPhaseStoreMeshManager) UpdateTimeStamp(nodeId string) error {
peerToUpdate := peers[0]
if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.daemonConf.KeepAliveTime) {
if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.DaemonConf.Heartbeat) {
m.store.Mark(peerToUpdate)
if len(peers) < 2 {
@ -313,6 +316,8 @@ func (m *TwoPhaseStoreMeshManager) AddRoutes(nodeId string, routes ...mesh.Route
}
}
// Only add nodes on changes. Otherwise the node will advertise new
// information whenever they get new routes
if changes {
m.store.Put(nodeId, node)
}
@ -320,7 +325,7 @@ func (m *TwoPhaseStoreMeshManager) AddRoutes(nodeId string, routes ...mesh.Route
return nil
}
// DeleteRoutes: deletes the routes from the node
// RemoveRoute: deletes the routes from the given node
func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
@ -336,6 +341,7 @@ func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Ro
for _, route := range routes {
changes = true
logging.Log.WriteInfof("deleting: %s", route.GetDestination().String())
delete(node.Routes, route.GetDestination().String())
}
@ -346,12 +352,12 @@ func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Ro
return nil
}
// GetSyncer: returns the automerge syncer for sync
// GetSyncer: returns the bi-directionally synchroniser to merge documents
func (m *TwoPhaseStoreMeshManager) GetSyncer() mesh.MeshSyncer {
return NewTwoPhaseSyncer(m)
}
// GetNode get a particular not within the mesh
// GetNode: get a particular not within the mesh network
func (m *TwoPhaseStoreMeshManager) GetNode(nodeId string) (mesh.MeshNode, error) {
if !m.store.Contains(nodeId) {
return nil, fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
@ -379,7 +385,7 @@ func (m *TwoPhaseStoreMeshManager) SetDescription(nodeId string, description str
return nil
}
// SetAlias: set the alias of the nodeId
// SetAlias: set the alias of the given node
func (m *TwoPhaseStoreMeshManager) SetAlias(nodeId string, alias string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
@ -392,7 +398,7 @@ func (m *TwoPhaseStoreMeshManager) SetAlias(nodeId string, alias string) error {
return nil
}
// AddService: adds the service to the given node
// AddService: adds a service to the given node
func (m *TwoPhaseStoreMeshManager) AddService(nodeId string, key string, value string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
@ -404,19 +410,25 @@ func (m *TwoPhaseStoreMeshManager) AddService(nodeId string, key string, value s
return nil
}
// RemoveService: removes the service form the node. throws an error if the service does not exist
// RemoveService: removes the service form a node, throws an error if the service does not exist
func (m *TwoPhaseStoreMeshManager) RemoveService(nodeId string, key string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
node := m.store.Get(nodeId)
if _, ok := node.Services[key]; !ok {
return fmt.Errorf("datastore: node does not contain service %s", key)
}
delete(node.Services, key)
m.store.Put(nodeId, node)
return nil
}
// Prune: prunes all nodes that have not updated their timestamp in
// Prune: prunes all nodes that have not updated their vector clock in a given amount
// of time
func (m *TwoPhaseStoreMeshManager) Prune() error {
m.store.Prune()
return nil
@ -445,6 +457,7 @@ func (m *TwoPhaseStoreMeshManager) GetPeers() []string {
})
}
// getRoutes: get all routes the target node is advertising
func (m *TwoPhaseStoreMeshManager) getRoutes(targetNode string) (map[string]Route, error) {
if !m.store.Contains(targetNode) {
return nil, fmt.Errorf("getRoute: cannot get route %s does not exist", targetNode)
@ -454,7 +467,8 @@ func (m *TwoPhaseStoreMeshManager) getRoutes(targetNode string) (map[string]Rout
return node.Routes, nil
}
// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen
// GetRoutes: Get all unique routes the target node is advertising.
// on conflicts the route with the least hop count is chosen
func (m *TwoPhaseStoreMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, error) {
node, err := m.GetNode(targetNode)
@ -498,7 +512,7 @@ func (m *TwoPhaseStoreMeshManager) GetRoutes(targetNode string) (map[string]mesh
return routes, nil
}
// RemoveNode(): remove the node from the mesh
// RemoveNode: remove the node from the mesh
func (m *TwoPhaseStoreMeshManager) RemoveNode(nodeId string) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
@ -508,7 +522,8 @@ func (m *TwoPhaseStoreMeshManager) RemoveNode(nodeId string) error {
return nil
}
// GetConfiguration implements mesh.MeshProvider.
// GetConfiguration gets the WireGuard configuration to use for this
// network
func (m *TwoPhaseStoreMeshManager) GetConfiguration() *conf.WgConfiguration {
return m.conf
return m.Conf
}

440
pkg/crdt/datastore_test.go Normal file
View File

@ -0,0 +1,440 @@
package crdt
import (
"net"
"slices"
"testing"
"time"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type TestParams struct {
manager mesh.MeshProvider
publicKey *wgtypes.Key
}
func setUpTests() *TestParams {
advertiseRoutes := false
advertiseDefaultRoute := false
role := conf.PEER_ROLE
discovery := conf.OUTGOING_IP_DISCOVERY
factory := &TwoPhaseMapFactory{
Config: &conf.DaemonConfiguration{
CertificatePath: "/somecertificatepath",
PrivateKeyPath: "/someprivatekeypath",
CaCertificatePath: "/somecacertificatepath",
SkipCertVerification: true,
GrpcPort: 0,
Timeout: 20,
Profile: false,
SyncInterval: 2,
Heartbeat: 10,
ClusterSize: 32,
InterClusterChance: 0.15,
Branch: 3,
InfectionCount: 3,
BaseConfiguration: conf.WgConfiguration{
IPDiscovery: &discovery,
AdvertiseRoutes: &advertiseRoutes,
AdvertiseDefaultRoute: &advertiseDefaultRoute,
Role: &role,
},
},
}
key, _ := wgtypes.GeneratePrivateKey()
mesh, _ := factory.CreateMesh(&mesh.MeshProviderFactoryParams{
DevName: "bob",
MeshId: "meshid123",
Client: nil,
Conf: &factory.Config.BaseConfiguration,
DaemonConf: factory.Config,
NodeID: "bob",
})
publicKey := key.PublicKey()
return &TestParams{
manager: mesh,
publicKey: &publicKey,
}
}
func getOurNode(testParams *TestParams) *MeshNode {
return &MeshNode{
HostEndpoint: "public-endpoint:8080",
WgEndpoint: "public-endpoint:21906",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d218/128",
PublicKey: testParams.publicKey.String(),
Timestamp: time.Now().Unix(),
Description: "A node that we are adding",
Type: "peer",
}
}
func getRandomNode() *MeshNode {
key, _ := wgtypes.GeneratePrivateKey()
publicKey := key.PublicKey()
return &MeshNode{
HostEndpoint: "public-endpoint:8081",
WgEndpoint: "public-endpoint:21907",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d234/128",
PublicKey: publicKey.String(),
Timestamp: time.Now().Unix(),
Description: "A node that we are adding",
Type: "peer",
}
}
func TestAddNodeAddsTheNodesToTheStore(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
if !testParams.manager.NodeExists(testParams.publicKey.String()) {
t.Fatalf(`node %s should have been added to the mesh network`, testParams.publicKey.String())
}
}
func TestAddNodeNodeAlreadyExistsReplacesTheNode(t *testing.T) {
TestAddNodeAddsTheNodesToTheStore(t)
TestAddNodeAddsTheNodesToTheStore(t)
}
func TestSaveThenLoad(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
bytes := testParams.manager.Save()
if err := testParams.manager.Load(bytes); err != nil {
t.Fatalf(`error caused by loading datastore: %s`, err.Error())
}
}
func TestHasChangesReturnsTrueWhenThereAreChangesInTheMesh(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
if !testParams.manager.HasChanges() {
t.Fatalf(`mesh has change but HasChanges returned false`)
}
testParams.manager.SetDescription(testParams.publicKey.String(), "Bob marley")
if !testParams.manager.HasChanges() {
t.Fatalf(`mesh has change but HasChanges returned false`)
}
testParams.manager.SaveChanges()
}
func TestHasChangesWhenThereAreNoChangesInTheMeshReturnsFalse(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.SaveChanges()
if testParams.manager.HasChanges() {
t.Fatalf(`mesh has no changes but HasChanges was true`)
}
testParams.manager.SetDescription(testParams.publicKey.String(), "Bob marley")
testParams.manager.SaveChanges()
if testParams.manager.HasChanges() {
t.Fatalf(`mesh has no changes but HasChanges was true`)
}
}
func TestUpdateTimeStampUpdatesTheTimeStampOfTheGivenNodeIfItIsTheLeader(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
before, _ := testParams.manager.GetNode(testParams.publicKey.String())
time.Sleep(1 * time.Second)
testParams.manager.UpdateTimeStamp(testParams.publicKey.String())
after, _ := testParams.manager.GetNode(testParams.publicKey.String())
if before.GetTimeStamp() >= after.GetTimeStamp() {
t.Fatalf(`before should not be after after`)
}
}
func TestUpdateTimeStampUpdatesTheTimeStampOfTheGivenNodeIfItIsNotLeader(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
newNode := getRandomNode()
newNode.PublicKey = "aaaaaaaaaa"
testParams.manager.AddNode(newNode)
before, _ := testParams.manager.GetNode(testParams.publicKey.String())
time.Sleep(1 * time.Second)
after, _ := testParams.manager.GetNode(testParams.publicKey.String())
if before.GetTimeStamp() != after.GetTimeStamp() {
t.Fatalf(`before and after should be the same`)
}
}
func TestAddRoutesAddsARouteToTheGivenMesh(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
_, destination, _ := net.ParseCIDR("0353:1da7:7f33:acc0:7a3f:6e55:912b:bc1f/64")
testParams.manager.AddRoutes(testParams.publicKey.String(), &mesh.RouteStub{
Destination: destination,
Path: make([]string, 0),
})
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
containsDestination := lib.Contains(node.GetRoutes(), func(r mesh.Route) bool {
return r.GetDestination().Contains(destination.IP)
})
if !containsDestination {
t.Fatalf(`route has not been added to the node`)
}
}
func TestRemoveRoutesWithdrawsRoutesFromTheMesh(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
_, destination, _ := net.ParseCIDR("0353:1da7:7f33:acc0:7a3f:6e55:912b:bc1f/64")
route := &mesh.RouteStub{
Destination: destination,
Path: make([]string, 0),
}
testParams.manager.AddRoutes(testParams.publicKey.String(), route)
testParams.manager.RemoveRoutes(testParams.publicKey.String(), route)
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
containsDestination := lib.Contains(node.GetRoutes(), func(r mesh.Route) bool {
return r.GetDestination().Contains(destination.IP)
})
if containsDestination {
t.Fatalf(`route has not been removed from the node`)
}
}
func TestGetNodeGetsTheNodeWhenItExists(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
if node == nil {
t.Fatalf(`node not found returned nil`)
}
}
func TestGetNodeReturnsNilWhenItDoesNotExist(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.RemoveNode(testParams.publicKey.String())
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
if node != nil {
t.Fatalf(`node found but should be nil`)
}
}
func TestNodeExistsReturnsFalseWhenNotExists(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.RemoveNode(testParams.publicKey.String())
if testParams.manager.NodeExists(testParams.publicKey.String()) {
t.Fatalf(`nodeexists should be false`)
}
}
func TestSetDescriptionReturnsErrorWhenNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.SetDescription("djdjdj", "djdsjkd")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestSetDescriptionSetsTheDescription(t *testing.T) {
testParams := setUpTests()
descriptionToSet := "djdsjkd"
testParams.manager.AddNode(getOurNode(testParams))
err := testParams.manager.SetDescription(testParams.publicKey.String(), descriptionToSet)
if err != nil {
t.Fatalf(`error %s thrown`, err.Error())
}
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
description := node.GetDescription()
if description != descriptionToSet {
t.Fatalf(`description was %s should be %s`, description, descriptionToSet)
}
}
func TestAliasNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.SetAlias("djdjdj", "djdsjkd")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestSetAliasSetsAlias(t *testing.T) {
testParams := setUpTests()
aliasToSet := "djdsjkd"
testParams.manager.AddNode(getOurNode(testParams))
err := testParams.manager.SetAlias(testParams.publicKey.String(), aliasToSet)
if err != nil {
t.Fatalf(`error %s thrown`, err.Error())
}
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
alias := node.GetAlias()
if alias != aliasToSet {
t.Fatalf(`description was %s should be %s`, alias, aliasToSet)
}
}
func TestAddServiceNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.AddService("djdjdj", "djdsjkd", "sddsds")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestAddServiceNodeExists(t *testing.T) {
testParams := setUpTests()
service := "djdsjkd"
serviceValue := "dsdsds"
testParams.manager.AddNode(getOurNode(testParams))
err := testParams.manager.AddService(testParams.publicKey.String(), service, serviceValue)
if err != nil {
t.Fatalf(`error %s thrown`, err.Error())
}
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
services := node.GetServices()
if value, ok := services[service]; !ok || value != serviceValue {
t.Fatalf(`service not added to the data store`)
}
}
func TestRemoveServiceDoesNotExists(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.RemoveService("djdjdj", "dsdssd")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestRemoveServiceServiceDoesNotExist(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
if err := testParams.manager.RemoveService(testParams.publicKey.String(), "dhsdh"); err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestGetPeersReturnsAllPeersInTheMesh(t *testing.T) {
testParams := setUpTests()
peer1 := getRandomNode()
peer2 := getRandomNode()
client := getRandomNode()
client.Type = "client"
testParams.manager.AddNode(peer1)
testParams.manager.AddNode(peer2)
testParams.manager.AddNode(client)
peers := testParams.manager.GetPeers()
slices.Sort(peers)
if len(peers) != 2 {
t.Fatalf(`there should be two peers in the mesh`)
}
peer1Pub, _ := peer1.GetPublicKey()
if !slices.Contains(peers, peer1Pub.String()) {
t.Fatalf(`peer1 not in the list`)
}
peer2Pub, _ := peer2.GetPublicKey()
if !slices.Contains(peers, peer2Pub.String()) {
t.Fatalf(`peer2 not in the list`)
}
}
func TestRemoveNodeReturnsErrorIfNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.RemoveNode("dsjdssjk")
if err == nil {
t.Fatalf(`error should have returned`)
}
}

View File

@ -4,34 +4,39 @@ import (
"fmt"
"hash/fnv"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/mesh"
)
// TwoPhaseMapFactory: instantiate a new twophasemap
// datastore
type TwoPhaseMapFactory struct {
Config *conf.DaemonConfiguration
}
// CreateMesh: create a new mesh network
func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) {
return &TwoPhaseStoreMeshManager{
MeshId: params.MeshId,
IfName: params.DevName,
Client: params.Client,
conf: params.Conf,
daemonConf: params.DaemonConf,
Conf: params.Conf,
DaemonConf: params.DaemonConf,
store: NewTwoPhaseMap[string, MeshNode](params.NodeID, func(s string) uint64 {
h := fnv.New64a()
h.Write([]byte(s))
return h.Sum64()
}, uint64(3*f.Config.KeepAliveTime)),
}, uint64(3*f.Config.Heartbeat)),
}, nil
}
// MeshNodeFactory: create a new node in the mesh network
type MeshNodeFactory struct {
Config conf.DaemonConfiguration
}
// Build: build a new mesh network
func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode {
hostName := f.getAddress(params)
@ -66,7 +71,7 @@ func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string
} else {
ipFunc := lib.GetPublicIP
if *params.MeshConfig.IPDiscovery == conf.DNS_IP_DISCOVERY {
if *params.MeshConfig.IPDiscovery == conf.OUTGOING_IP_DISCOVERY {
ipFunc = lib.GetOutboundIP
}

View File

@ -1,4 +1,4 @@
// crdt is a golang implementation of a crdt
// crdt provides go implementations for crdts
package crdt
import (
@ -6,6 +6,7 @@ import (
"sync"
)
// Bucket: bucket represents a value in the grow only map
type Bucket[D any] struct {
Vector uint64
Contents D
@ -19,6 +20,7 @@ type GMap[K cmp.Ordered, D any] struct {
clock *VectorClock[K]
}
// Put: put a new entry in the grow-only-map
func (g *GMap[K, D]) Put(key K, value D) {
g.lock.Lock()
@ -32,6 +34,8 @@ func (g *GMap[K, D]) Put(key K, value D) {
g.lock.Unlock()
}
// Contains: returns whether or not the key is contained
// in the g-map
func (g *GMap[K, D]) Contains(key K) bool {
return g.contains(g.clock.hashFunc(key))
}
@ -64,11 +68,23 @@ func (g *GMap[K, D]) get(key uint64) Bucket[D] {
return bucket
}
// Get: get the value associated with the given key
func (g *GMap[K, D]) Get(key K) D {
if !g.Contains(key) {
var def D
return def
}
return g.get(g.clock.hashFunc(key)).Contents
}
// Mark: marks the node, this means the status of the node
// is an undefined state
func (g *GMap[K, D]) Mark(key K) {
if !g.Contains(key) {
return
}
g.lock.Lock()
bucket := g.contents[g.clock.hashFunc(key)]
bucket.Gravestone = true
@ -76,7 +92,7 @@ func (g *GMap[K, D]) Mark(key K) {
g.lock.Unlock()
}
// IsMarked: returns true if the node is marked
// IsMarked: returns true if the node is marked (in an undefined state)
func (g *GMap[K, D]) IsMarked(key K) bool {
marked := false
@ -89,10 +105,10 @@ func (g *GMap[K, D]) IsMarked(key K) bool {
}
g.lock.RUnlock()
return marked
}
// Keys: return all the keys in the grow-only map
func (g *GMap[K, D]) Keys() []uint64 {
g.lock.RLock()
@ -108,6 +124,7 @@ func (g *GMap[K, D]) Keys() []uint64 {
return contents
}
// Save: saves the grow only map
func (g *GMap[K, D]) Save() map[uint64]Bucket[D] {
buckets := make(map[uint64]Bucket[D])
g.lock.RLock()
@ -120,6 +137,7 @@ func (g *GMap[K, D]) Save() map[uint64]Bucket[D] {
return buckets
}
// SaveWithKeys: get all the values corresponding with the provided keys
func (g *GMap[K, D]) SaveWithKeys(keys []uint64) map[uint64]Bucket[D] {
buckets := make(map[uint64]Bucket[D])
g.lock.RLock()
@ -132,6 +150,7 @@ func (g *GMap[K, D]) SaveWithKeys(keys []uint64) map[uint64]Bucket[D] {
return buckets
}
// GetClock: get all the vector clocks in the g_map
func (g *GMap[K, D]) GetClock() map[uint64]uint64 {
clock := make(map[uint64]uint64)
g.lock.RLock()
@ -144,6 +163,7 @@ func (g *GMap[K, D]) GetClock() map[uint64]uint64 {
return clock
}
// GetHash: get the hash of the g_map representing its state
func (g *GMap[K, D]) GetHash() uint64 {
hash := uint64(0)
@ -157,6 +177,7 @@ func (g *GMap[K, D]) GetHash() uint64 {
return hash
}
// Prune: prune all stale entries
func (g *GMap[K, D]) Prune() {
stale := g.clock.getStale()
g.lock.Lock()

224
pkg/crdt/g_map_test.go Normal file
View File

@ -0,0 +1,224 @@
// crdt_test unit tests the crdt implementations
package crdt
import (
"hash/fnv"
"slices"
"testing"
"time"
"github.com/tim-beatham/smegmesh/pkg/lib"
)
func NewGmap() *GMap[string, bool] {
vectorClock := NewVectorClock("a", func(key string) uint64 {
hash := fnv.New64a()
hash.Write([]byte(key))
return hash.Sum64()
}, 1) // 1 second stale time
gMap := NewGMap[string, bool](vectorClock)
return gMap
}
func TestGMapPutInsertsItems(t *testing.T) {
gMap := NewGmap()
gMap.Put("bruh1234", true)
if !gMap.Contains("bruh1234") {
t.Fatalf(`value not added to map`)
}
}
func TestGMapPutReplacesItems(t *testing.T) {
gMap := NewGmap()
gMap.Put("bruh1234", true)
gMap.Put("bruh1234", false)
value := gMap.Get("bruh1234")
if value {
t.Fatalf(`value should ahve been replaced to false`)
}
}
func TestContainsValueNotPresent(t *testing.T) {
gMap := NewGmap()
if gMap.Contains("sdhjsdhsdj") {
t.Fatalf(`value should not be present in the map`)
}
}
func TestContainsValuePresent(t *testing.T) {
gMap := NewGmap()
key := "hehehehe"
gMap.Put(key, false)
if !gMap.Contains(key) {
t.Fatalf(`%s should not be present in the map`, key)
}
}
func TestGMapGetNotPresentReturnsError(t *testing.T) {
gMap := NewGmap()
value := gMap.Get("bruh123")
if value != false {
t.Fatalf(`value should be default type false`)
}
}
func TestGMapGetReturnsValue(t *testing.T) {
gMap := NewGmap()
gMap.Put("bobdylan", true)
value := gMap.Get("bobdylan")
if !value {
t.Fatalf("value should be true but was false")
}
}
func TestMarkMarksTheValue(t *testing.T) {
gMap := NewGmap()
gMap.Put("hello123", true)
gMap.Mark("hello123")
if !gMap.IsMarked("hello123") {
t.Fatal(`hello123 should be marked`)
}
}
func TestMarkValueNotPresent(t *testing.T) {
gMap := NewGmap()
gMap.Mark("ok123456")
}
func TestKeysMapEmpty(t *testing.T) {
gMap := NewGmap()
keys := gMap.Keys()
if len(keys) != 0 {
t.Fatal(`list of keys was not empty but should be empty`)
}
}
func TestKeysMapReturnsKeysInMap(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
keys := gMap.Keys()
if len(keys) != 3 {
t.Fatal(`key length should be 3`)
}
}
func TestSaveMapEmptyReturnsEmptyMap(t *testing.T) {
gMap := NewGmap()
saveMap := gMap.Save()
if len(saveMap) != 0 {
t.Fatal(`saves should be empty`)
}
}
func TestSaveMapReturnsMapOfBuckets(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
saveMap := gMap.Save()
if len(saveMap) != 3 {
t.Fatalf(`save length should be 3`)
}
}
func TestSaveWithKeysNoKeysReturnsEmptyBucket(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
saveMap := gMap.SaveWithKeys([]uint64{})
if len(saveMap) != 0 {
t.Fatalf(`save map should be empty`)
}
}
func TestSaveWithKeysReturnsIntersection(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
clock := lib.MapKeys(gMap.GetClock())
clock = clock[:len(clock)-1]
values := gMap.SaveWithKeys(clock)
if len(values) != len(clock) {
t.Fatalf(`intersection not returned`)
}
}
func TestGetClockMapEmptyReturnsEmptyClock(t *testing.T) {
gMap := NewGmap()
clocks := gMap.GetClock()
if len(clocks) != 0 {
t.Fatalf(`vector clock is not empty`)
}
}
func TestGetClockReturnsAllCLocks(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
clocks := lib.MapValues(gMap.GetClock())
slices.Sort(clocks)
if !slices.Equal([]uint64{0, 1, 2}, clocks) {
t.Fatalf(`clocks are invalid`)
}
}
func TestGetHashChangesHashOnValueAdded(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
prevHash := gMap.GetHash()
gMap.Put("b", true)
if prevHash == gMap.GetHash() {
t.Fatalf(`hash should be different`)
}
}
func TestPruneGarbageCollectsValuesThatHaveNotBeenUpdated(t *testing.T) {
gMap := NewGmap()
gMap.clock.Put("c", 12)
gMap.Put("c", false)
gMap.Put("a", false)
time.Sleep(4 * time.Second)
gMap.Put("a", true)
gMap.Prune()
if gMap.Contains("c") {
t.Fatalf(`a should have been pruned`)
}
}

View File

@ -3,9 +3,10 @@ package crdt
import (
"cmp"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/lib"
)
// TwoPhaseMap: comprises of two grow-only maps
type TwoPhaseMap[K cmp.Ordered, D any] struct {
addMap *GMap[K, D]
removeMap *GMap[K, bool]
@ -23,7 +24,7 @@ func (m *TwoPhaseMap[K, D]) Contains(key K) bool {
return m.contains(m.Clock.hashFunc(key))
}
// Contains checks whether the value exists in the map
// contains: checks whether the key exists in the map
func (m *TwoPhaseMap[K, D]) contains(key uint64) bool {
if !m.addMap.contains(key) {
return false
@ -40,6 +41,7 @@ func (m *TwoPhaseMap[K, D]) contains(key uint64) bool {
return addValue.Vector >= removeValue.Vector
}
// Get: get the value corresponding with the given key
func (m *TwoPhaseMap[K, D]) Get(key K) D {
var result D
@ -60,18 +62,19 @@ func (m *TwoPhaseMap[K, D]) get(key uint64) D {
return m.addMap.get(key).Contents
}
// Put places the key K in the map
// Put: places the key K in the map with the associated data D
func (m *TwoPhaseMap[K, D]) Put(key K, data D) {
msgSequence := m.Clock.IncrementClock()
m.Clock.Put(key, msgSequence)
m.addMap.Put(key, data)
}
// Mark: marks the status of the node as undetermiend
func (m *TwoPhaseMap[K, D]) Mark(key K) {
m.addMap.Mark(key)
}
// Remove removes the value from the map
// Remove: removes the value from the map
func (m *TwoPhaseMap[K, D]) Remove(key K) {
m.removeMap.Put(key, true)
}
@ -92,6 +95,7 @@ func (m *TwoPhaseMap[K, D]) keys() []uint64 {
return keys
}
// AsList: convert the map to a list
func (m *TwoPhaseMap[K, D]) AsList() []D {
theList := make([]D, 0)
@ -104,6 +108,8 @@ func (m *TwoPhaseMap[K, D]) AsList() []D {
return theList
}
// Snapshot: convert the map into an immutable snapshot.
// contains the contents of the add and remove map
func (m *TwoPhaseMap[K, D]) Snapshot() *TwoPhaseMapSnapshot[K, D] {
return &TwoPhaseMapSnapshot[K, D]{
Add: m.addMap.Save(),
@ -111,6 +117,8 @@ func (m *TwoPhaseMap[K, D]) Snapshot() *TwoPhaseMapSnapshot[K, D] {
}
}
// SnapshotFromState: create a snapshot of the intersection of values provided
// in the given state
func (m *TwoPhaseMap[K, D]) SnapShotFromState(state *TwoPhaseMapState[K]) *TwoPhaseMapSnapshot[K, D] {
addKeys := lib.MapKeys(state.AddContents)
removeKeys := lib.MapKeys(state.RemoveContents)
@ -121,12 +129,18 @@ func (m *TwoPhaseMap[K, D]) SnapShotFromState(state *TwoPhaseMapState[K]) *TwoPh
}
}
// TwoPhaseMapState: encapsulates the state of the map
// without specifying the data that is stored
type TwoPhaseMapState[K cmp.Ordered] struct {
// Vectors: the vector ID of each process
Vectors map[uint64]uint64
// AddContents: the contents of the add map
AddContents map[uint64]uint64
// RemoveContents: the contents of the remove map
RemoveContents map[uint64]uint64
}
// IsMarked: returns true if the given value is marked in an undetermined state
func (m *TwoPhaseMap[K, D]) IsMarked(key K) bool {
return m.addMap.IsMarked(key)
}
@ -151,6 +165,8 @@ func (m *TwoPhaseMap[K, D]) GenerateMessage() *TwoPhaseMapState[K] {
}
}
// Difference: compute the set difference between the two states.
// highestStale represents the highest vector clock that has been marked as stale
func (m *TwoPhaseMapState[K]) Difference(highestStale uint64, state *TwoPhaseMapState[K]) *TwoPhaseMapState[K] {
mapState := &TwoPhaseMapState[K]{
AddContents: make(map[uint64]uint64),
@ -176,6 +192,7 @@ func (m *TwoPhaseMapState[K]) Difference(highestStale uint64, state *TwoPhaseMap
return mapState
}
// Merge: merge a snapshot into the map
func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) {
for key, value := range snapshot.Add {
// Gravestone is local only to that node.
@ -190,6 +207,7 @@ func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) {
}
}
// Prune: garbage collect all stale entries in the map
func (m *TwoPhaseMap[K, D]) Prune() {
m.addMap.Prune()
m.removeMap.Prune()

View File

@ -4,7 +4,7 @@ import (
"bytes"
"encoding/gob"
logging "github.com/tim-beatham/wgmesh/pkg/log"
logging "github.com/tim-beatham/smegmesh/pkg/log"
)
type SyncState int

View File

@ -0,0 +1,214 @@
package crdt
import (
"hash/fnv"
"slices"
"testing"
)
func NewMap(processId string) *TwoPhaseMap[string, string] {
theMap := NewTwoPhaseMap[string, string](processId, func(key string) uint64 {
hash := fnv.New64a()
hash.Write([]byte(key))
return hash.Sum64()
}, 1)
return theMap
}
func TestTwoPhaseMapEmpty(t *testing.T) {
theMap := NewMap("a")
if theMap.Contains("a") {
t.Fatalf(`a should not be present in the map`)
}
}
func TestTwoPhaseMapValuePresent(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "")
if !theMap.Contains("a") {
t.Fatalf(`should be present within the map`)
}
}
func TestTwoPhaseMapValueNotPresent(t *testing.T) {
theMap := NewMap("a")
theMap.Put("b", "")
if theMap.Contains("a") {
t.Fatalf(`a should not be present in the map`)
}
}
func TestTwoPhaseMapPutThenRemove(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "")
theMap.Remove("a")
if theMap.Contains("a") {
t.Fatalf(`a should not be present within the map`)
}
}
func TestTwoPhaseMapPutThenRemoveThenPut(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "")
theMap.Remove("a")
theMap.Put("a", "")
if !theMap.Contains("a") {
t.Fatalf(`a should be present within the map`)
}
}
func TestMarkMarksTheValueIn2PMap(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "")
theMap.Mark("a")
if !theMap.IsMarked("a") {
t.Fatalf(`a should be marked`)
}
}
func TestAsListReturnsItemsInList(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "bob")
theMap.Put("b", "dylan")
keys := theMap.AsList()
slices.Sort(keys)
if !slices.Equal([]string{"bob", "dylan"}, keys) {
t.Fatalf(`values should be bob, dylan`)
}
}
func TestSnapShotRemoveMapEmpty(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "bob")
theMap.Put("b", "dylan")
snapshot := theMap.Snapshot()
if len(snapshot.Add) != 2 {
t.Fatalf(`add values length should be 2`)
}
if len(snapshot.Remove) != 0 {
t.Fatalf(`remove map length should be 0`)
}
}
func TestSnapshotMapEmpty(t *testing.T) {
theMap := NewMap("a")
snapshot := theMap.Snapshot()
if len(snapshot.Add) != 0 || len(snapshot.Remove) != 0 {
t.Fatalf(`snapshot length should be 0`)
}
}
func TestSnapShotFromStateReturnsIntersection(t *testing.T) {
map1 := NewMap("a")
map1.Put("a", "heyy")
map2 := NewMap("b")
map2.Put("b", "hmmm")
message := map2.GenerateMessage()
snapShot := map1.SnapShotFromState(message)
if len(snapShot.Add) != 1 {
t.Fatalf(`add length should be 1`)
}
if len(snapShot.Remove) != 0 {
t.Fatalf(`remove length should be 0`)
}
}
func TestGetHashDifferentOnChange(t *testing.T) {
theMap := NewMap("a")
prevHash := theMap.GetHash()
theMap.Put("b", "hmmhmhmh")
if prevHash == theMap.GetHash() {
t.Fatalf(`hashes should not be the same`)
}
}
func TestGenerateMessageReturnsClocks(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "hmm")
theMap.Put("b", "hmm")
theMap.Remove("a")
message := theMap.GenerateMessage()
if len(message.AddContents) != 2 {
t.Fatalf(`two items added add should be 2`)
}
if len(message.RemoveContents) != 1 {
t.Fatalf(`a was removed remove map should be length 1`)
}
}
func TestDifferenceReturnsDifferenceOfMaps(t *testing.T) {
map1 := NewMap("a")
map1.Put("a", "ssms")
map1.Put("b", "sdmdsmd")
map2 := NewMap("b")
map2.Put("d", "eek")
map2.Put("c", "meh")
message1 := map1.GenerateMessage()
message2 := map2.GenerateMessage()
difference := message1.Difference(0, message2)
if len(difference.AddContents) != 2 {
t.Fatalf(`d and c are not in map1 they should be in add contents`)
}
if len(difference.RemoveContents) != 0 {
t.Fatalf(`remove should be empty`)
}
}
func TestMergeMergesValuesThatAreGreaterThanCurrentClock(t *testing.T) {
map1 := NewMap("a")
map1.Put("a", "ssms")
map1.Put("b", "sdmdsmd")
map2 := NewMap("b")
map2.Put("d", "eek")
map2.Put("c", "meh")
message1 := map1.GenerateMessage()
message2 := map2.GenerateMessage()
difference := message1.Difference(0, message2)
state := map2.SnapShotFromState(difference)
map1.Merge(*state)
if !map1.Contains("d") {
t.Fatalf(`d should be in the map`)
}
if !map2.Contains("c") {
t.Fatalf(`c should be in the map`)
}
}

View File

@ -5,9 +5,12 @@ import (
"sync"
"time"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/lib"
)
// VectorBucket: represents a vector clock in the bucket
// recording both the time changes were last seen
// and when the lastUpdate epoch was recorded
type VectorBucket struct {
// clock current value of the node's clock
clock uint64
@ -15,8 +18,9 @@ type VectorBucket struct {
lastUpdate uint64
}
// Vector clock defines an abstract data type
// for a vector clock implementation
// VectorClock: defines an abstract data type
// for a vector clock implementation. Including a mechanism to
// garbage collect stale entries
type VectorClock[K cmp.Ordered] struct {
vectors map[uint64]*VectorBucket
lock sync.RWMutex
@ -62,6 +66,7 @@ func (m *VectorClock[K]) GetHash() uint64 {
return hash
}
// Merge: merge two clocks together
func (m *VectorClock[K]) Merge(vectors map[uint64]uint64) {
for key, value := range vectors {
m.put(key, value)
@ -97,6 +102,7 @@ func (m *VectorClock[K]) GetStaleCount() uint64 {
return staleCount
}
// Prune: prunes all stale entries in the vector clock
func (m *VectorClock[K]) Prune() {
stale := m.getStale()
@ -109,6 +115,8 @@ func (m *VectorClock[K]) Prune() {
m.lock.Unlock()
}
// GetTimeStamp: get the last time the node was updated in UNIX
// epoch time
func (m *VectorClock[K]) GetTimestamp(processId K) uint64 {
m.lock.RLock()
@ -118,6 +126,8 @@ func (m *VectorClock[K]) GetTimestamp(processId K) uint64 {
return lastUpdate
}
// Put: places the key with vector clock in the clock of the given
// process
func (m *VectorClock[K]) Put(key K, value uint64) {
m.put(m.hashFunc(key), value)
}
@ -133,7 +143,8 @@ func (m *VectorClock[K]) put(key uint64, value uint64) {
}
// Make sure that entries that were garbage collected don't get
// addded back
// highestStale represents the highest vector clock that has been
// invalidated
if value > clockValue && value > m.highestStale {
newBucket := VectorBucket{
clock: value,
@ -145,6 +156,7 @@ func (m *VectorClock[K]) put(key uint64, value uint64) {
m.lock.Unlock()
}
// GetClock: serialize the vector clock into an immutable map
func (m *VectorClock[K]) GetClock() map[uint64]uint64 {
clock := make(map[uint64]uint64)

View File

@ -1,16 +1,17 @@
package ctrlserver
import (
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/crdt"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query"
"github.com/tim-beatham/wgmesh/pkg/rpc"
"github.com/tim-beatham/wgmesh/pkg/wg"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/smegmesh/pkg/crdt"
"github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/query"
"github.com/tim-beatham/smegmesh/pkg/rpc"
"github.com/tim-beatham/smegmesh/pkg/sync"
"github.com/tim-beatham/smegmesh/pkg/wg"
"golang.zx2c4.com/wireguard/wgctrl"
)
@ -21,7 +22,6 @@ type NewCtrlServerParams struct {
CtrlProvider rpc.MeshCtrlServerServer
SyncProvider rpc.SyncServiceServer
Querier query.Querier
OnDelete func(mesh.MeshProvider)
}
// Create a new instance of the MeshCtrlServer or error if the
@ -34,12 +34,16 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
nodeFactory := &crdt.MeshNodeFactory{
Config: *params.Conf,
}
idGenerator := &lib.IDNameGenerator{}
idGenerator := &lib.ShortIDGenerator{}
ipAllocator := &ip.ULABuilder{}
interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client)
ctrlServer.timers = make([]*lib.Timer, 0)
configApplyer := mesh.NewWgMeshConfigApplyer()
var syncer sync.Syncer
meshManagerParams := &mesh.NewMeshManagerParams{
Conf: *params.Conf,
Client: params.Client,
@ -49,7 +53,13 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
IPAllocator: ipAllocator,
InterfaceManipulator: interfaceManipulator,
ConfigApplyer: configApplyer,
OnDelete: params.OnDelete,
OnDelete: func(mesh mesh.MeshProvider) {
_, err := syncer.Sync(mesh)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
},
}
ctrlServer.MeshManager = mesh.NewMeshManager(meshManagerParams)
@ -83,13 +93,40 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
return nil, err
}
syncer = sync.NewSyncer(&sync.NewSyncerParams{
MeshManager: ctrlServer.MeshManager,
ConnectionManager: ctrlServer.ConnectionManager,
Configuration: params.Conf,
})
// Check any syncs every 1 second
syncTimer := lib.NewTimer(func() error {
err = syncer.SyncMeshes()
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
return nil
}, 1)
heartbeatTimer := lib.NewTimer(func() error {
logging.Log.WriteInfof("checking heartbeat")
return ctrlServer.MeshManager.UpdateTimeStamp()
}, params.Conf.Heartbeat)
ctrlServer.timers = append(ctrlServer.timers, syncTimer, heartbeatTimer)
ctrlServer.Querier = query.NewJmesQuerier(ctrlServer.MeshManager)
ctrlServer.ConnectionServer = connServer
for _, timer := range ctrlServer.timers {
go timer.Run()
}
return ctrlServer, nil
}
func (s *MeshCtrlServer) GetConfiguration() *conf.DaemonConfiguration {
return s.Conf
}
@ -124,5 +161,13 @@ func (s *MeshCtrlServer) Close() error {
logging.Log.WriteErrorf(err.Error())
}
for _, timer := range s.timers {
err := timer.Stop()
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
}
return nil
}

View File

@ -4,21 +4,23 @@ import (
"net"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/query"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// MeshRoute: represents a route in the mesh that is
// available to client applications
type MeshRoute struct {
Destination string
Path []string
}
// Represents the WireGuard configuration attached to the node
// WireGuardStats: Represents the WireGuard configuration attached to the node
type WireGuardStats struct {
AllowedIPs []string
TransmitBytes int64
@ -26,7 +28,8 @@ type WireGuardStats struct {
PersistentKeepAliveInterval time.Duration
}
// Represents a WireGuard MeshNode
// MeshNode: represents a node in the WireGuard mesh that can be
// sent to ip chandlers
type MeshNode struct {
HostEndpoint string
WgEndpoint string
@ -40,12 +43,13 @@ type MeshNode struct {
Stats WireGuardStats
}
// Represents a WireGuard Mesh
// Mesh: Represents a WireGuard Mesh network that can be sent
// along ipc to client frameworks
type Mesh struct {
SharedKey *wgtypes.Key
Nodes map[string]MeshNode
}
// CtrlServer: Encapsulates th ctrlserver
type CtrlServer interface {
GetConfiguration() *conf.DaemonConfiguration
GetClient() *wgctrl.Client
@ -55,7 +59,7 @@ type CtrlServer interface {
GetConnectionManager() conn.ConnectionManager
}
// Represents a ctrlserver to be used in WireGuard
// MeshCtrlServer: Represents a ctrlserver to be used in WireGuard
type MeshCtrlServer struct {
Client *wgctrl.Client
MeshManager mesh.MeshManager
@ -63,6 +67,7 @@ type MeshCtrlServer struct {
ConnectionServer *conn.ConnectionServer
Conf *conf.DaemonConfiguration
Querier query.Querier
timers []*lib.Timer
}
// NewCtrlNode create an instance of a ctrl node to send over an

View File

@ -1,10 +1,10 @@
package ctrlserver
import (
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/query"
"golang.zx2c4.com/wireguard/wgctrl"
)

View File

@ -1,24 +1,22 @@
// smegdns: example of how to implement dns in the mesh
package smegdns
import (
"encoding/json"
"fmt"
"net"
"net/rpc"
"github.com/miekg/dns"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/query"
"github.com/tim-beatham/smegmesh/pkg/ipc"
"github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/query"
)
const SockAddr = "/tmp/wgmesh_ipc.sock"
const MeshRegularExpression = `(?P<meshId>.+)\.(?P<alias>.+)\.smeg\.`
type DNSHandler struct {
client *rpc.Client
client *ipc.SmegmeshIpc
server *dns.Server
}
@ -27,7 +25,7 @@ type DNSHandler struct {
func (d *DNSHandler) queryMesh(meshId, alias string) net.IP {
var reply string
err := d.client.Call("IpcHandler.Query", &ipc.QueryMesh{
err := d.client.Query(ipc.QueryMesh{
MeshId: meshId,
Query: fmt.Sprintf("[?alias == '%s'] | [0]", alias),
}, &reply)
@ -48,6 +46,7 @@ func (d *DNSHandler) queryMesh(meshId, alias string) net.IP {
return ip
}
// handleQuery: handles a DNS query
func (d *DNSHandler) handleQuery(m *dns.Msg) {
for _, q := range m.Question {
switch q.Qtype {
@ -75,6 +74,7 @@ func (d *DNSHandler) handleQuery(m *dns.Msg) {
}
}
// handleDNS query: handle a DNS request
func (h *DNSHandler) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
msg := new(dns.Msg)
msg.SetReply(r)
@ -97,7 +97,7 @@ func (h *DNSHandler) Close() error {
}
func NewDns(udpPort int) (*DNSHandler, error) {
client, err := rpc.DialHTTP("unix", SockAddr)
client, err := ipc.NewClientIpc()
if err != nil {
return nil, err

249
pkg/dot/dot.go Normal file
View File

@ -0,0 +1,249 @@
// Graph allows the definition of a DOT graph in golang
package graph
import (
"fmt"
"hash/fnv"
"strings"
"github.com/tim-beatham/smegmesh/pkg/lib"
)
type GraphType string
type Shape string
const (
GRAPH GraphType = "graph"
DIGRAPH GraphType = "digraph"
)
const (
CIRCLE Shape = "circle"
STAR Shape = "star"
HEXAGON Shape = "hexagon"
PARALLELOGRAM Shape = "parallelogram"
)
type Graph interface {
Dottable
GetType() GraphType
}
// Cluster: represents a subgraph in the graphs
type Cluster struct {
Type GraphType
Name string
Label string
nodes map[string]*Node
edges map[string]Edge
}
// RootGraph: Represents the top level graph
type RootGraph struct {
Type GraphType
Label string
nodes map[string]*Node
clusters map[string]*Cluster
edges map[string]Edge
}
// Node: represents a graphviz not
type Node struct {
Name string
Label string
Shape Shape
Size int
}
// Edge: represents an edge between adjacent nodes
type Edge interface {
Dottable
}
// DirectEdge: contains a directed edge between any two nodes
type DirectedEdge struct {
Name string
Label string
From string
To string
}
// UndirectedEdge: contains an undirected edge between any two
// nodes
type UndirectedEdge struct {
Name string
Label string
From string
To string
}
// Dottable means an implementer can convert the struct to DOT representation
type Dottable interface {
GetDOT() (string, error)
}
// PutNode: puts a node in the root graph
func (g *RootGraph) PutNode(name, label string, size int, shape Shape) error {
_, exists := g.nodes[name]
if exists {
// If exists no need to add the ndoe
return nil
}
g.nodes[name] = &Node{Name: name, Label: label, Size: size, Shape: shape}
return nil
}
// PutCluster: puts a cluster in the root graph
func (g *RootGraph) PutCluster(graph *Cluster) {
g.clusters[graph.Label] = graph
}
func writeContituents[D Dottable](result *strings.Builder, elements ...D) error {
for _, node := range elements {
dot, err := node.GetDOT()
if err != nil {
return err
}
_, err = result.WriteString(dot)
if err != nil {
return err
}
}
return nil
}
// GetDOT: convert the root graph into dot format
func (g *RootGraph) GetDOT() (string, error) {
var result strings.Builder
result.WriteString(fmt.Sprintf("%s {\n", g.Type))
result.WriteString("node [colorscheme=set312];\n")
result.WriteString("layout = fdp;\n")
nodes := lib.MapValues(g.nodes)
edges := lib.MapValues(g.edges)
writeContituents(&result, nodes...)
writeContituents(&result, edges...)
for _, cluster := range g.clusters {
clusterDOT, err := cluster.GetDOT()
if err != nil {
return "", err
}
result.WriteString(clusterDOT)
}
result.WriteString("}")
return result.String(), nil
}
// GetType: get the graph type. DIRECTED|UNDIRECTED
func (r *RootGraph) GetType() GraphType {
return r.Type
}
func constructEdge(graph Graph, name, label, from, to string) Edge {
switch graph.GetType() {
case DIGRAPH:
return &DirectedEdge{Name: name, Label: label, From: from, To: to}
default:
return &UndirectedEdge{Name: name, Label: label, From: from, To: to}
}
}
// AddEdge: adds an edge between two nodes in the root graph
func (g *RootGraph) AddEdge(name string, label string, from string, to string) error {
g.edges[name] = constructEdge(g, name, label, from, to)
return nil
}
const numColours = 12
func (n *Node) hash() int {
h := fnv.New32a()
h.Write([]byte(n.Name))
return (int(h.Sum32()) % numColours) + 1
}
// GetDOT: convert the node into DOT format
func (n *Node) GetDOT() (string, error) {
return fmt.Sprintf("node[label=\"%s\",shape=%s, style=\"filled\", fillcolor=%d, width=%d, height=%d, fixedsize=true] \"%s\";\n",
n.Label, n.Shape, n.hash(), n.Size, n.Size, n.Name), nil
}
// GetDOT: Convert a directed edge into dot format
func (e *DirectedEdge) GetDOT() (string, error) {
return fmt.Sprintf("\"%s\" -> \"%s\" [label=\"%s\"];\n", e.From, e.To, e.Label), nil
}
// GetDOT: convert an undirected edge into dot format
func (e *UndirectedEdge) GetDOT() (string, error) {
return fmt.Sprintf("\"%s\" -- \"%s\" [label=\"%s\"];\n", e.From, e.To, e.Label), nil
}
// AddEdge: adds an edge between two nodes in the graph
func (g *Cluster) AddEdge(name string, label string, from string, to string) error {
g.edges[name] = constructEdge(g, name, label, from, to)
return nil
}
// PutNode: puts a node in the graph
func (g *Cluster) PutNode(name, label string, size int, shape Shape) error {
_, exists := g.nodes[name]
if exists {
// If exists no need to add the ndoe
return nil
}
g.nodes[name] = &Node{Name: name, Label: label, Shape: shape, Size: size}
return nil
}
// GetDOT: convert the cluster into dot format
func (g *Cluster) GetDOT() (string, error) {
var builder strings.Builder
builder.WriteString(fmt.Sprintf("subgraph \"cluster%s\" {\n", g.Label))
builder.WriteString(fmt.Sprintf("label = \"%s\"\n", g.Label))
nodes := lib.MapValues(g.nodes)
edges := lib.MapValues(g.edges)
writeContituents(&builder, nodes...)
writeContituents(&builder, edges...)
builder.WriteString("}\n")
return builder.String(), nil
}
// GetType: get the type of the subgraph (directed|undirected)
func (g *Cluster) GetType() GraphType {
return g.Type
}
// NewSubGraph: instantiate a new subgraph
func NewSubGraph(name string, label string, graphType GraphType) *Cluster {
return &Cluster{
Label: name,
Type: graphType,
Name: name,
nodes: make(map[string]*Node),
edges: make(map[string]Edge),
}
}
// NewGraph: create a new root graph
func NewGraph(label string, graphType GraphType) *RootGraph {
return &RootGraph{
Type: graphType,
Label: label,
clusters: map[string]*Cluster{},
nodes: make(map[string]*Node),
edges: make(map[string]Edge),
}
}

116
pkg/dot/wg.go Normal file
View File

@ -0,0 +1,116 @@
package graph
import (
"fmt"
"slices"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
)
// MeshGraphConverter converts a mesh to a graph
type MeshGraphConverter interface {
// convert the mesh to textual form
Generate() (string, error)
}
type MeshDOTConverter struct {
meshes map[string][]ctrlserver.MeshNode
destinations map[string]interface{}
}
func (c *MeshDOTConverter) Generate() (string, error) {
g := NewGraph("Smegmesh", GRAPH)
for meshId := range c.meshes {
err := c.generateMesh(g, meshId)
if err != nil {
return "", err
}
}
for mesh := range c.meshes {
g.PutNode(mesh, mesh, 1, CIRCLE)
}
for destination := range c.destinations {
g.PutNode(destination, destination, 1, HEXAGON)
}
return g.GetDOT()
}
func (c *MeshDOTConverter) generateMesh(g *RootGraph, meshId string) error {
nodes := c.meshes[meshId]
g.PutNode(meshId, meshId, 1, CIRCLE)
for _, node := range nodes {
c.graphNode(g, node, meshId)
}
for _, node := range nodes {
g.AddEdge(fmt.Sprintf("%s to %s", node.PublicKey, meshId), "", node.PublicKey, meshId)
}
return nil
}
// graphNode: graphs a node within the mesh
func (c *MeshDOTConverter) graphNode(g *RootGraph, node ctrlserver.MeshNode, meshId string) {
alias := node.Alias
if alias == "" {
alias = node.WgHost[1:len(node.WgHost)-20] + "\\n" + node.WgHost[len(node.WgHost)-20:len(node.WgHost)]
}
g.PutNode(node.PublicKey, alias, 2, CIRCLE)
for _, route := range node.Routes {
if len(route.Path) == 0 {
g.AddEdge(route.Destination, "", node.PublicKey, route.Destination)
continue
}
reversedPath := slices.Clone(route.Path)
slices.Reverse(reversedPath)
g.AddEdge(fmt.Sprintf("%s to %s", node.PublicKey, reversedPath[0]), "", node.PublicKey, reversedPath[0])
for _, mesh := range route.Path {
if _, ok := c.meshes[mesh]; !ok {
c.destinations[mesh] = struct{}{}
}
}
for index := range reversedPath[0 : len(reversedPath)-1] {
routeID := fmt.Sprintf("%s to %s", reversedPath[index], reversedPath[index+1])
g.AddEdge(routeID, "", reversedPath[index], reversedPath[index+1])
}
if route.Destination == "::/0" {
c.destinations[route.Destination] = struct{}{}
lastMesh := reversedPath[len(reversedPath)-1]
routeID := fmt.Sprintf("%s to %s", lastMesh, route.Destination)
g.AddEdge(routeID, "", lastMesh, route.Destination)
}
}
for service := range node.Services {
c.putService(g, service, meshId, node)
}
}
// putService: construct a service node and a link between the nodes
func (c *MeshDOTConverter) putService(g *RootGraph, key, meshId string, node ctrlserver.MeshNode) {
serviceID := fmt.Sprintf("%s%s%s", key, node.PublicKey, meshId)
g.PutNode(serviceID, key, 1, PARALLELOGRAM)
g.AddEdge(fmt.Sprintf("%s to %s", node.PublicKey, serviceID), "", node.PublicKey, serviceID)
}
func NewMeshGraphConverter(meshes map[string][]ctrlserver.MeshNode) MeshGraphConverter {
return &MeshDOTConverter{
meshes: meshes,
destinations: make(map[string]interface{}),
}
}

View File

@ -1,178 +0,0 @@
// Graph allows the definition of a DOT graph in golang
package graph
import (
"errors"
"fmt"
"hash/fnv"
"strings"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
type GraphType string
type Shape string
const (
GRAPH GraphType = "graph"
DIGRAPH = "digraph"
)
const (
CIRCLE Shape = "circle"
STAR Shape = "star"
HEXAGON Shape = "hexagon"
)
type Graph struct {
Type GraphType
Label string
nodes map[string]*Node
edges []Edge
}
type Node struct {
Name string
Shape Shape
}
type Edge interface {
Dottable
}
type DirectedEdge struct {
Label string
From *Node
To *Node
}
type UndirectedEdge struct {
Label string
From *Node
To *Node
}
// Dottable means an implementer can convert the struct to DOT representation
type Dottable interface {
GetDOT() (string, error)
}
func NewGraph(label string, graphType GraphType) *Graph {
return &Graph{Type: graphType, Label: label, nodes: make(map[string]*Node), edges: make([]Edge, 0)}
}
// PutNode: puts a node in the graph
func (g *Graph) PutNode(label string, shape Shape) error {
_, exists := g.nodes[label]
if exists {
// If exists no need to add the ndoe
return nil
}
g.nodes[label] = &Node{Name: label, Shape: shape}
return nil
}
func writeContituents[D Dottable](result *strings.Builder, elements ...D) error {
for _, node := range elements {
dot, err := node.GetDOT()
if err != nil {
return err
}
_, err = result.WriteString(dot)
if err != nil {
return err
}
}
return nil
}
func (g *Graph) GetDOT() (string, error) {
var result strings.Builder
_, err := result.WriteString(fmt.Sprintf("%s {\n", g.Type))
if err != nil {
return "", err
}
_, err = result.WriteString("node [colorscheme=set312];\n")
if err != nil {
return "", err
}
nodes := lib.MapValues(g.nodes)
err = writeContituents(&result, nodes...)
if err != nil {
return "", err
}
err = writeContituents(&result, g.edges...)
if err != nil {
return "", err
}
_, err = result.WriteString("}")
if err != nil {
return "", err
}
return result.String(), nil
}
func (g *Graph) constructEdge(label string, from *Node, to *Node) Edge {
switch g.Type {
case DIGRAPH:
return &DirectedEdge{Label: label, From: from, To: to}
default:
return &UndirectedEdge{Label: label, From: from, To: to}
}
}
// AddEdge: adds an edge between two nodes in the graph
func (g *Graph) AddEdge(label string, from string, to string) error {
fromNode, exists := g.nodes[from]
if !exists {
return errors.New(fmt.Sprintf("Node %s does not exist", from))
}
toNode, exists := g.nodes[to]
if !exists {
return errors.New(fmt.Sprintf("Node %s does not exist", to))
}
g.edges = append(g.edges, g.constructEdge(label, fromNode, toNode))
return nil
}
const numColours = 12
func (n *Node) hash() int {
h := fnv.New32a()
h.Write([]byte(n.Name))
return (int(h.Sum32()) % numColours) + 1
}
func (n *Node) GetDOT() (string, error) {
return fmt.Sprintf("node[shape=%s, style=\"filled\", fillcolor=%d] %s;\n",
n.Shape, n.hash(), n.Name), nil
}
func (e *DirectedEdge) GetDOT() (string, error) {
return fmt.Sprintf("%s -> %s;\n", e.From.Name, e.To.Name), nil
}
func (e *UndirectedEdge) GetDOT() (string, error) {
return fmt.Sprintf("%s -- %s;\n", e.From.Name, e.To.Name), nil
}

212
pkg/grpc/ctrlserver.pb.go Normal file
View File

@ -0,0 +1,212 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.1
// protoc v3.21.12
// source: pkg/grpc/ctrlserver.proto
package rpc
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type GetMeshRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
MeshId string `protobuf:"bytes,1,opt,name=meshId,proto3" json:"meshId,omitempty"`
}
func (x *GetMeshRequest) Reset() {
*x = GetMeshRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *GetMeshRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GetMeshRequest) ProtoMessage() {}
func (x *GetMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GetMeshRequest.ProtoReflect.Descriptor instead.
func (*GetMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_proto_rawDescGZIP(), []int{0}
}
func (x *GetMeshRequest) GetMeshId() string {
if x != nil {
return x.MeshId
}
return ""
}
type GetMeshReply struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Mesh []byte `protobuf:"bytes,1,opt,name=mesh,proto3" json:"mesh,omitempty"`
}
func (x *GetMeshReply) Reset() {
*x = GetMeshReply{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *GetMeshReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GetMeshReply) ProtoMessage() {}
func (x *GetMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GetMeshReply.ProtoReflect.Descriptor instead.
func (*GetMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_proto_rawDescGZIP(), []int{1}
}
func (x *GetMeshReply) GetMesh() []byte {
if x != nil {
return x.Mesh
}
return nil
}
var File_pkg_grpc_ctrlserver_proto protoreflect.FileDescriptor
var file_pkg_grpc_ctrlserver_proto_rawDesc = []byte{
0x0a, 0x19, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73,
0x65, 0x72, 0x76, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x22, 0x28, 0x0a, 0x0e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68,
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49,
0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x22,
0x22, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12,
0x12, 0x0a, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x6d,
0x65, 0x73, 0x68, 0x32, 0x4f, 0x0a, 0x0e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x74, 0x72, 0x6c, 0x53,
0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x3d, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68,
0x12, 0x18, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d,
0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70,
0x6c, 0x79, 0x22, 0x00, 0x42, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62,
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_pkg_grpc_ctrlserver_proto_rawDescOnce sync.Once
file_pkg_grpc_ctrlserver_proto_rawDescData = file_pkg_grpc_ctrlserver_proto_rawDesc
)
func file_pkg_grpc_ctrlserver_proto_rawDescGZIP() []byte {
file_pkg_grpc_ctrlserver_proto_rawDescOnce.Do(func() {
file_pkg_grpc_ctrlserver_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_grpc_ctrlserver_proto_rawDescData)
})
return file_pkg_grpc_ctrlserver_proto_rawDescData
}
var file_pkg_grpc_ctrlserver_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_pkg_grpc_ctrlserver_proto_goTypes = []interface{}{
(*GetMeshRequest)(nil), // 0: rpctypes.GetMeshRequest
(*GetMeshReply)(nil), // 1: rpctypes.GetMeshReply
}
var file_pkg_grpc_ctrlserver_proto_depIdxs = []int32{
0, // 0: rpctypes.MeshCtrlServer.GetMesh:input_type -> rpctypes.GetMeshRequest
1, // 1: rpctypes.MeshCtrlServer.GetMesh:output_type -> rpctypes.GetMeshReply
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_pkg_grpc_ctrlserver_proto_init() }
func file_pkg_grpc_ctrlserver_proto_init() {
if File_pkg_grpc_ctrlserver_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_pkg_grpc_ctrlserver_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetMeshRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_ctrlserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetMeshReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_pkg_grpc_ctrlserver_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_pkg_grpc_ctrlserver_proto_goTypes,
DependencyIndexes: file_pkg_grpc_ctrlserver_proto_depIdxs,
MessageInfos: file_pkg_grpc_ctrlserver_proto_msgTypes,
}.Build()
File_pkg_grpc_ctrlserver_proto = out.File
file_pkg_grpc_ctrlserver_proto_rawDesc = nil
file_pkg_grpc_ctrlserver_proto_goTypes = nil
file_pkg_grpc_ctrlserver_proto_depIdxs = nil
}

View File

@ -1,18 +0,0 @@
syntax = "proto3";
package rpctypes;
option go_package = "pkg/rpc";
service Authentication {
rpc JoinMesh(JoinAuthMeshRequest) returns (JoinAuthMeshReply) {}
}
message JoinAuthMeshRequest {
string meshId = 1;
string alias = 2;
}
message JoinAuthMeshReply {
bool success = 1;
optional string token = 2;
}

View File

@ -0,0 +1,105 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.2.0
// - protoc v3.21.12
// source: pkg/grpc/ctrlserver.proto
package rpc
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// MeshCtrlServerClient is the client API for MeshCtrlServer service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type MeshCtrlServerClient interface {
GetMesh(ctx context.Context, in *GetMeshRequest, opts ...grpc.CallOption) (*GetMeshReply, error)
}
type meshCtrlServerClient struct {
cc grpc.ClientConnInterface
}
func NewMeshCtrlServerClient(cc grpc.ClientConnInterface) MeshCtrlServerClient {
return &meshCtrlServerClient{cc}
}
func (c *meshCtrlServerClient) GetMesh(ctx context.Context, in *GetMeshRequest, opts ...grpc.CallOption) (*GetMeshReply, error) {
out := new(GetMeshReply)
err := c.cc.Invoke(ctx, "/rpctypes.MeshCtrlServer/GetMesh", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// MeshCtrlServerServer is the server API for MeshCtrlServer service.
// All implementations must embed UnimplementedMeshCtrlServerServer
// for forward compatibility
type MeshCtrlServerServer interface {
GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error)
mustEmbedUnimplementedMeshCtrlServerServer()
}
// UnimplementedMeshCtrlServerServer must be embedded to have forward compatible implementations.
type UnimplementedMeshCtrlServerServer struct {
}
func (UnimplementedMeshCtrlServerServer) GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetMesh not implemented")
}
func (UnimplementedMeshCtrlServerServer) mustEmbedUnimplementedMeshCtrlServerServer() {}
// UnsafeMeshCtrlServerServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to MeshCtrlServerServer will
// result in compilation errors.
type UnsafeMeshCtrlServerServer interface {
mustEmbedUnimplementedMeshCtrlServerServer()
}
func RegisterMeshCtrlServerServer(s grpc.ServiceRegistrar, srv MeshCtrlServerServer) {
s.RegisterService(&MeshCtrlServer_ServiceDesc, srv)
}
func _MeshCtrlServer_GetMesh_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(GetMeshRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(MeshCtrlServerServer).GetMesh(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/rpctypes.MeshCtrlServer/GetMesh",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(MeshCtrlServerServer).GetMesh(ctx, req.(*GetMeshRequest))
}
return interceptor(ctx, in, info, handler)
}
// MeshCtrlServer_ServiceDesc is the grpc.ServiceDesc for MeshCtrlServer service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var MeshCtrlServer_ServiceDesc = grpc.ServiceDesc{
ServiceName: "rpctypes.MeshCtrlServer",
HandlerType: (*MeshCtrlServerServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "GetMesh",
Handler: _MeshCtrlServer_GetMesh_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "pkg/grpc/ctrlserver.proto",
}

233
pkg/grpc/syncservice.pb.go Normal file
View File

@ -0,0 +1,233 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.1
// protoc v3.21.12
// source: pkg/grpc/syncservice.proto
package rpc
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type SyncMeshRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
MeshId string `protobuf:"bytes,1,opt,name=meshId,proto3" json:"meshId,omitempty"`
Changes []byte `protobuf:"bytes,2,opt,name=changes,proto3" json:"changes,omitempty"`
}
func (x *SyncMeshRequest) Reset() {
*x = SyncMeshRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_syncservice_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SyncMeshRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SyncMeshRequest) ProtoMessage() {}
func (x *SyncMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_syncservice_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SyncMeshRequest.ProtoReflect.Descriptor instead.
func (*SyncMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_syncservice_proto_rawDescGZIP(), []int{0}
}
func (x *SyncMeshRequest) GetMeshId() string {
if x != nil {
return x.MeshId
}
return ""
}
func (x *SyncMeshRequest) GetChanges() []byte {
if x != nil {
return x.Changes
}
return nil
}
type SyncMeshReply struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"`
Changes []byte `protobuf:"bytes,2,opt,name=changes,proto3" json:"changes,omitempty"`
}
func (x *SyncMeshReply) Reset() {
*x = SyncMeshReply{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_syncservice_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SyncMeshReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SyncMeshReply) ProtoMessage() {}
func (x *SyncMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_syncservice_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SyncMeshReply.ProtoReflect.Descriptor instead.
func (*SyncMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_syncservice_proto_rawDescGZIP(), []int{1}
}
func (x *SyncMeshReply) GetSuccess() bool {
if x != nil {
return x.Success
}
return false
}
func (x *SyncMeshReply) GetChanges() []byte {
if x != nil {
return x.Changes
}
return nil
}
var File_pkg_grpc_syncservice_proto protoreflect.FileDescriptor
var file_pkg_grpc_syncservice_proto_rawDesc = []byte{
0x0a, 0x1a, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x73, 0x79, 0x6e, 0x63, 0x73,
0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0b, 0x73, 0x79,
0x6e, 0x63, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x22, 0x43, 0x0a, 0x0f, 0x53, 0x79, 0x6e,
0x63, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06,
0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65,
0x73, 0x68, 0x49, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18,
0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x22, 0x43,
0x0a, 0x0d, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12,
0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08,
0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x68, 0x61,
0x6e, 0x67, 0x65, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x63, 0x68, 0x61, 0x6e,
0x67, 0x65, 0x73, 0x32, 0x59, 0x0a, 0x0b, 0x53, 0x79, 0x6e, 0x63, 0x53, 0x65, 0x72, 0x76, 0x69,
0x63, 0x65, 0x12, 0x4a, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x1c,
0x2e, 0x73, 0x79, 0x6e, 0x63, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x53, 0x79, 0x6e,
0x63, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x73,
0x79, 0x6e, 0x63, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x4d,
0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x09,
0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x33,
}
var (
file_pkg_grpc_syncservice_proto_rawDescOnce sync.Once
file_pkg_grpc_syncservice_proto_rawDescData = file_pkg_grpc_syncservice_proto_rawDesc
)
func file_pkg_grpc_syncservice_proto_rawDescGZIP() []byte {
file_pkg_grpc_syncservice_proto_rawDescOnce.Do(func() {
file_pkg_grpc_syncservice_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_grpc_syncservice_proto_rawDescData)
})
return file_pkg_grpc_syncservice_proto_rawDescData
}
var file_pkg_grpc_syncservice_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_pkg_grpc_syncservice_proto_goTypes = []interface{}{
(*SyncMeshRequest)(nil), // 0: syncservice.SyncMeshRequest
(*SyncMeshReply)(nil), // 1: syncservice.SyncMeshReply
}
var file_pkg_grpc_syncservice_proto_depIdxs = []int32{
0, // 0: syncservice.SyncService.SyncMesh:input_type -> syncservice.SyncMeshRequest
1, // 1: syncservice.SyncService.SyncMesh:output_type -> syncservice.SyncMeshReply
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_pkg_grpc_syncservice_proto_init() }
func file_pkg_grpc_syncservice_proto_init() {
if File_pkg_grpc_syncservice_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_pkg_grpc_syncservice_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SyncMeshRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_syncservice_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SyncMeshReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_pkg_grpc_syncservice_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_pkg_grpc_syncservice_proto_goTypes,
DependencyIndexes: file_pkg_grpc_syncservice_proto_depIdxs,
MessageInfos: file_pkg_grpc_syncservice_proto_msgTypes,
}.Build()
File_pkg_grpc_syncservice_proto = out.File
file_pkg_grpc_syncservice_proto_rawDesc = nil
file_pkg_grpc_syncservice_proto_goTypes = nil
file_pkg_grpc_syncservice_proto_depIdxs = nil
}

View File

@ -4,18 +4,9 @@ package syncservice;
option go_package = "pkg/rpc";
service SyncService {
rpc GetConf(GetConfRequest) returns (GetConfReply) {}
rpc SyncMesh(stream SyncMeshRequest) returns (stream SyncMeshReply) {}
}
message GetConfRequest {
string meshId = 1;
}
message GetConfReply {
bytes mesh = 1;
}
message SyncMeshRequest {
string meshId = 1;
bytes changes = 2;

View File

@ -0,0 +1,137 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.2.0
// - protoc v3.21.12
// source: pkg/grpc/syncservice.proto
package rpc
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// SyncServiceClient is the client API for SyncService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type SyncServiceClient interface {
SyncMesh(ctx context.Context, opts ...grpc.CallOption) (SyncService_SyncMeshClient, error)
}
type syncServiceClient struct {
cc grpc.ClientConnInterface
}
func NewSyncServiceClient(cc grpc.ClientConnInterface) SyncServiceClient {
return &syncServiceClient{cc}
}
func (c *syncServiceClient) SyncMesh(ctx context.Context, opts ...grpc.CallOption) (SyncService_SyncMeshClient, error) {
stream, err := c.cc.NewStream(ctx, &SyncService_ServiceDesc.Streams[0], "/syncservice.SyncService/SyncMesh", opts...)
if err != nil {
return nil, err
}
x := &syncServiceSyncMeshClient{stream}
return x, nil
}
type SyncService_SyncMeshClient interface {
Send(*SyncMeshRequest) error
Recv() (*SyncMeshReply, error)
grpc.ClientStream
}
type syncServiceSyncMeshClient struct {
grpc.ClientStream
}
func (x *syncServiceSyncMeshClient) Send(m *SyncMeshRequest) error {
return x.ClientStream.SendMsg(m)
}
func (x *syncServiceSyncMeshClient) Recv() (*SyncMeshReply, error) {
m := new(SyncMeshReply)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// SyncServiceServer is the server API for SyncService service.
// All implementations must embed UnimplementedSyncServiceServer
// for forward compatibility
type SyncServiceServer interface {
SyncMesh(SyncService_SyncMeshServer) error
mustEmbedUnimplementedSyncServiceServer()
}
// UnimplementedSyncServiceServer must be embedded to have forward compatible implementations.
type UnimplementedSyncServiceServer struct {
}
func (UnimplementedSyncServiceServer) SyncMesh(SyncService_SyncMeshServer) error {
return status.Errorf(codes.Unimplemented, "method SyncMesh not implemented")
}
func (UnimplementedSyncServiceServer) mustEmbedUnimplementedSyncServiceServer() {}
// UnsafeSyncServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to SyncServiceServer will
// result in compilation errors.
type UnsafeSyncServiceServer interface {
mustEmbedUnimplementedSyncServiceServer()
}
func RegisterSyncServiceServer(s grpc.ServiceRegistrar, srv SyncServiceServer) {
s.RegisterService(&SyncService_ServiceDesc, srv)
}
func _SyncService_SyncMesh_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(SyncServiceServer).SyncMesh(&syncServiceSyncMeshServer{stream})
}
type SyncService_SyncMeshServer interface {
Send(*SyncMeshReply) error
Recv() (*SyncMeshRequest, error)
grpc.ServerStream
}
type syncServiceSyncMeshServer struct {
grpc.ServerStream
}
func (x *syncServiceSyncMeshServer) Send(m *SyncMeshReply) error {
return x.ServerStream.SendMsg(m)
}
func (x *syncServiceSyncMeshServer) Recv() (*SyncMeshRequest, error) {
m := new(SyncMeshRequest)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// SyncService_ServiceDesc is the grpc.ServiceDesc for SyncService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var SyncService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "syncservice.SyncService",
HandlerType: (*SyncServiceServer)(nil),
Methods: []grpc.MethodDesc{},
Streams: []grpc.StreamDesc{
{
StreamName: "SyncMesh",
Handler: _SyncService_SyncMesh_Handler,
ServerStreams: true,
ClientStreams: true,
},
},
Metadata: "pkg/grpc/syncservice.proto",
}

View File

@ -1,8 +1,7 @@
package ip
/*
* Use a WireGuard public key to generate a unique interface ID
*/
// Generates a CGA see RFC 3972
// https://datatracker.ietf.org/doc/html/rfc3972
import (
"crypto/rand"
@ -22,19 +21,23 @@ const (
InterfaceIdLen = 8
)
/*
* Cga parameters used to generate an IPV6 interface ID
*/
// CGAParameters: parameters used to create a new cryotpgraphically generated
// address
type CgaParameters struct {
Modifier [ModifierLength]byte
// SubnetPrefix: prefix of the subnetwork
SubnetPrefix [2 * InterfaceIdLen]byte
// CollisionCount: total number of times we have atempted to generate a porefix
CollisionCount uint8
// PublicKey: WireGuard public key of our interface
PublicKey wgtypes.Key
// interfaceId: the generated interfaceId
interfaceId [2 * InterfaceIdLen]byte
// flag: represents whether or not an IP address has been generated
flag byte
}
func NewCga(key wgtypes.Key, subnetPrefix [2 * InterfaceIdLen]byte) (*CgaParameters, error) {
func NewCga(key wgtypes.Key, collisionCount uint8, subnetPrefix [2 * InterfaceIdLen]byte) (*CgaParameters, error) {
var params CgaParameters
_, err := rand.Read(params.Modifier[:])
@ -45,25 +48,10 @@ func NewCga(key wgtypes.Key, subnetPrefix [2 * InterfaceIdLen]byte) (*CgaParamet
params.PublicKey = key
params.SubnetPrefix = subnetPrefix
params.CollisionCount = collisionCount
return &params, nil
}
func (c *CgaParameters) generateHash2() []byte {
var byteVal [hash2Length]byte
for i := 0; i < ModifierLength; i++ {
byteVal[i] = c.Modifier[i]
}
for i := 0; i < wgtypes.KeyLen; i++ {
byteVal[ModifierLength+ZeroLength+i] = c.PublicKey[i]
}
hash := sha1.Sum(byteVal[:])
return hash[:Hash2Prefix]
}
func (c *CgaParameters) generateHash1() []byte {
var byteVal [hash1Length]byte
@ -78,7 +66,6 @@ func (c *CgaParameters) generateHash1() []byte {
byteVal[hash1Length-1] = c.CollisionCount
hash := sha1.Sum(byteVal[:])
return hash[:Hash1Prefix]
}
@ -90,9 +77,6 @@ func clearBit(num, pos int) byte {
}
func (c *CgaParameters) generateInterface() []byte {
// TODO: On duplicate address detection increment collision.
// Also incorporate SEC
hash1 := c.generateHash1()
var interfaceId []byte = make([]byte, InterfaceIdLen)
@ -101,7 +85,6 @@ func (c *CgaParameters) generateInterface() []byte {
interfaceId[0] = clearBit(int(interfaceId[0]), 6)
interfaceId[0] = clearBit(int(interfaceId[1]), 7)
return interfaceId
}

View File

@ -6,6 +6,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// IPAllocator: abstracts the process of creating an IP address
type IPAllocator interface {
GetIP(key wgtypes.Key, meshId string) (net.IP, error)
GetIP(key wgtypes.Key, meshId string, collisionCount uint8) (net.IP, error)
}

View File

@ -8,6 +8,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// ULABuilder: Create a new ULA in WireGuard
type ULABuilder struct{}
func getMeshPrefix(meshId string) [16]byte {
@ -39,10 +40,10 @@ func (u *ULABuilder) GetIPNet(meshId string) (*net.IPNet, error) {
return net, nil
}
func (u *ULABuilder) GetIP(key wgtypes.Key, meshId string) (net.IP, error) {
func (u *ULABuilder) GetIP(key wgtypes.Key, meshId string, collisionCount uint8) (net.IP, error) {
ulaPrefix := getMeshPrefix(meshId)
c, err := NewCga(key, ulaPrefix)
c, err := NewCga(key, collisionCount, ulaPrefix)
if err != nil {
return nil, err

View File

@ -5,11 +5,27 @@ import (
"net"
"net/http"
"net/rpc"
ipcRPC "net/rpc"
"os"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
)
const SockAddr = "/tmp/smeg.sock"
type MeshIpc interface {
CreateMesh(args *NewMeshArgs, reply *string) error
ListMeshes(name string, reply *ListMeshReply) error
JoinMesh(args *JoinMeshArgs, reply *string) error
LeaveMesh(meshId string, reply *string) error
GetMesh(meshId string, reply *GetMeshReply) error
Query(query QueryMesh, reply *string) error
PutDescription(args PutDescriptionArgs, reply *string) error
PutAlias(args PutAliasArgs, reply *string) error
PutService(args PutServiceArgs, reply *string) error
DeleteService(args DeleteServiceArgs, reply *string) error
}
// WireGuardArgs are provided args specific to WireGuard
type WireGuardArgs struct {
// WgPort is the WireGuard port to expose
@ -39,44 +55,141 @@ type JoinMeshArgs struct {
// MeshId is the ID of the mesh to join
MeshId string
// IpAddress is a routable IP in another mesh
IpAdress string
IpAddress string
// WgArgs is the WireGuard parameters to use.
WgArgs WireGuardArgs
}
// PutServiceArgs: args to place a service into the data store
type PutServiceArgs struct {
Service string
Value string
MeshId string
}
// DeleteServiceArgs: args to remove a service from the data store
type DeleteServiceArgs struct {
Service string
MeshId string
}
// PutAliasArgs: args to assign an alias to a node
type PutAliasArgs struct {
// Alias: represents the alias of the node
Alias string
// MeshId: represents the meshID of the node
MeshId string
}
// PutDescriptionArgs: args to assign a description to a node
type PutDescriptionArgs struct {
// Description: descriptio to add to the network
Description string
// MeshID to add to the mesh network
MeshId string
}
// GetMeshReply: ipc reply to get the mesh network
type GetMeshReply struct {
Nodes []ctrlserver.MeshNode
}
// ListMeshReply: ipc reply of the networks the node is part of
type ListMeshReply struct {
Meshes []string
}
// Querymesh: ipc args to query a mesh network
type QueryMesh struct {
// MeshId: id of the mesh to query
MeshId string
Query string
// JMESPath: query string to query
Query string
}
type MeshIpc interface {
// ClientIpc: Framework to invoke ipc calls to the daemon
type ClientIpc interface {
// CreateMesh: create a mesh network, return an error if the operation failed
CreateMesh(args *NewMeshArgs, reply *string) error
ListMeshes(name string, reply *ListMeshReply) error
// ListMesh: list mesh network the node is a part of, return an error if the operation failed
ListMeshes(args *ListMeshReply, reply *string) error
// JoinMesh: join a mesh network return an error if the operation failed
JoinMesh(args JoinMeshArgs, reply *string) error
// LeaveMesh: leave a mesh network, return an error if the operation failed
LeaveMesh(meshId string, reply *string) error
// GetMesh: get the given mesh network, return an error if the operation failed
GetMesh(meshId string, reply *GetMeshReply) error
GetDOT(meshId string, reply *string) error
// Query: query the given mesh network
Query(query QueryMesh, reply *string) error
PutDescription(description string, reply *string) error
PutAlias(alias string, reply *string) error
// PutDescription: assign a description to yourself
PutDescription(args PutDescriptionArgs, reply *string) error
// PutAlias: assign an alias to yourself
PutAlias(args PutAliasArgs, reply *string) error
// PutService: assign a service to yourself
PutService(args PutServiceArgs, reply *string) error
DeleteService(service string, reply *string) error
// DeleteService: retract a service
DeleteService(args DeleteServiceArgs, reply *string) error
}
const SockAddr = "/tmp/wgmesh_ipc.sock"
type SmegmeshIpc struct {
client *ipcRPC.Client
}
func NewClientIpc() (*SmegmeshIpc, error) {
client, err := ipcRPC.DialHTTP("unix", SockAddr)
if err != nil {
return nil, err
}
return &SmegmeshIpc{
client: client,
}, nil
}
func (c *SmegmeshIpc) CreateMesh(args *NewMeshArgs, reply *string) error {
return c.client.Call("IpcHandler.CreateMesh", args, reply)
}
func (c *SmegmeshIpc) ListMeshes(reply *ListMeshReply) error {
return c.client.Call("IpcHandler.ListMeshes", "", reply)
}
func (c *SmegmeshIpc) JoinMesh(args JoinMeshArgs, reply *string) error {
return c.client.Call("IpcHandler.JoinMesh", &args, reply)
}
func (c *SmegmeshIpc) LeaveMesh(meshId string, reply *string) error {
return c.client.Call("IpcHandler.LeaveMesh", &meshId, reply)
}
func (c *SmegmeshIpc) GetMesh(meshId string, reply *GetMeshReply) error {
return c.client.Call("IpcHandler.GetMesh", &meshId, reply)
}
func (c *SmegmeshIpc) Query(query QueryMesh, reply *string) error {
return c.client.Call("IpcHandler.Query", &query, reply)
}
func (c *SmegmeshIpc) PutDescription(args PutDescriptionArgs, reply *string) error {
return c.client.Call("IpcHandler.PutDescription", &args, reply)
}
func (c *SmegmeshIpc) PutAlias(args PutAliasArgs, reply *string) error {
return c.client.Call("IpcHandler.PutAlias", &args, reply)
}
func (c *SmegmeshIpc) PutService(args PutServiceArgs, reply *string) error {
return c.client.Call("IpcHandler.PutService", &args, reply)
}
func (c *SmegmeshIpc) DeleteService(args DeleteServiceArgs, reply *string) error {
return c.client.Call("IpcHandler.DeleteService", &args, reply)
}
func (c *SmegmeshIpc) Close() error {
return c.client.Close()
}
func RunIpcHandler(server MeshIpc) error {
if err := os.RemoveAll(SockAddr); err != nil {

View File

@ -36,11 +36,13 @@ func ConsistentHash[V any, K any](values []V, client K, bucketFunc func(V) int,
ourKey := keyFunc(client)
for _, record := range vs {
if ourKey < record.value {
return record.record
}
idx := sort.Search(len(vs), func(i int) bool {
return vs[i].value >= ourKey
})
if idx == len(vs) {
return vs[0].record
}
return vs[0].record
return vs[idx].record
}

View File

@ -3,6 +3,7 @@ package lib
import (
"github.com/anandvarma/namegen"
"github.com/google/uuid"
"github.com/lithammer/shortuuid"
)
// IdGenerator generates unique ids
@ -19,6 +20,14 @@ func (g *UUIDGenerator) GetId() (string, error) {
return id.String(), nil
}
type ShortIDGenerator struct {
}
func (g *ShortIDGenerator) GetId() (string, error) {
id := shortuuid.New()
return id, nil
}
type IDNameGenerator struct {
}

View File

@ -6,27 +6,21 @@ import (
"net"
"github.com/jsimonetti/rtnetlink"
logging "github.com/tim-beatham/wgmesh/pkg/log"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"golang.org/x/sys/unix"
)
// Maximum MTU to assin to WireGuard
// This isn't configurable
const WIREGUARD_MTU = 1420
// RtNetlinkConfig: represents an rtnetlkink configuration instance
type RtNetlinkConfig struct {
// conn: connection to the rtnetlink API
conn *rtnetlink.Conn
}
func NewRtNetlinkConfig() (*RtNetlinkConfig, error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return nil, err
}
return &RtNetlinkConfig{conn: conn}, nil
}
const WIREGUARD_MTU = 1420
// Create a netlink interface if it does not exist. ifName is the name of the netlink interface
// CreateLink: Create a netlink interface if it does not exist. ifName is the name of the netlink interface
func (c *RtNetlinkConfig) CreateLink(ifName string) error {
_, err := net.InterfaceByName(ifName)
@ -51,7 +45,7 @@ func (c *RtNetlinkConfig) CreateLink(ifName string) error {
return nil
}
// Delete link delete the specified interface
// DeleteLink: delete the specified interface
func (c *RtNetlinkConfig) DeleteLink(ifName string) error {
iface, err := net.InterfaceByName(ifName)
@ -68,7 +62,7 @@ func (c *RtNetlinkConfig) DeleteLink(ifName string) error {
return nil
}
// AddAddress adds an address to the given interface.
// AddAddress: adds an address to the given interface.
func (c *RtNetlinkConfig) AddAddress(ifName string, address string) error {
iface, err := net.InterfaceByName(ifName)
@ -177,7 +171,7 @@ func (c *RtNetlinkConfig) AddRoute(ifName string, route Route) error {
return nil
}
// DeleteRoute deletes routes with the gateway and destination
// DeleteRoute: deletes routes with the gateway and destination
func (c *RtNetlinkConfig) DeleteRoute(ifName string, route Route) error {
iface, err := net.InterfaceByName(ifName)
@ -219,6 +213,7 @@ func (c *RtNetlinkConfig) DeleteRoute(ifName string, route Route) error {
return nil
}
// route: represents a rout to add to the RIB
type Route struct {
Gateway net.IP
Destination net.IPNet
@ -232,7 +227,7 @@ func (r1 Route) equal(r2 Route) bool {
(mask1Ones == 0 && mask2Ones == 0 || r1.Destination.IP.Equal(r2.Destination.IP))
}
// DeleteRoutes deletes all routes not in exclude
// DeleteRoutes: deletes all routes not in exclude on the given interface
func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...Route) error {
routes, err := c.listRoutes(ifName, family)
@ -282,7 +277,7 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
return nil
}
// listRoutes lists all routes on the interface
// listRoutes: lists all routes on the interface
func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8) ([]rtnetlink.RouteMessage, error) {
iface, err := net.InterfaceByName(ifName)
@ -304,6 +299,18 @@ func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8) ([]rtnetlink.R
return routes, nil
}
// Close: close the Rtnetlink API
func (c *RtNetlinkConfig) Close() error {
return c.conn.Close()
}
// newRtNetlinkConfig: connect to the RtnetlinkAPI
func NewRtNetlinkConfig() (*RtNetlinkConfig, error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return nil, err
}
return &RtNetlinkConfig{conn: conn}, nil
}

View File

@ -6,6 +6,7 @@ import (
"os"
"github.com/sirupsen/logrus"
"github.com/tim-beatham/smegmesh/pkg/conf"
)
var (
@ -39,17 +40,29 @@ func (l *LogrusLogger) Writer() io.Writer {
return l.logger.Writer()
}
func NewLogrusLogger() *LogrusLogger {
func NewLogrusLogger(confLevel conf.LogLevel) *LogrusLogger {
var level logrus.Level
switch confLevel {
case conf.ERROR:
level = logrus.ErrorLevel
case conf.WARNING:
level = logrus.WarnLevel
case conf.INFO:
level = logrus.InfoLevel
}
logger := logrus.New()
logger.SetFormatter(&logrus.TextFormatter{FullTimestamp: true})
logger.SetOutput(os.Stdout)
logger.SetLevel(logrus.InfoLevel)
logger.SetLevel(level)
return &LogrusLogger{logger: logger}
}
func init() {
SetLogger(NewLogrusLogger())
SetLogger(NewLogrusLogger(conf.INFO))
}
func SetLogger(l Logger) {

View File

@ -7,21 +7,22 @@ import (
"strings"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/route"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/route"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// MeshConfigApplyer abstracts applying the mesh configuration
type MeshConfigApplyer interface {
// ApplyConfig: apply the configurtation
ApplyConfig() error
RemovePeers(meshId string) error
// SetMeshManager: sets the associated manager
SetMeshManager(manager MeshManager)
}
// WgMeshConfigApplyer applies WireGuard configuration
// WgMeshConfigApplyer: applies WireGuard configuration
type WgMeshConfigApplyer struct {
meshManager MeshManager
routeInstaller route.RouteInstaller
@ -91,7 +92,11 @@ func (m *WgMeshConfigApplyer) convertMeshNode(params convertMeshNodeParams) (*wg
return p.PublicKey.String() == pubKey.String()
})
endpoint, err := net.ResolveUDPAddr("udp", params.node.GetWgEndpoint())
var endpoint *net.UDPAddr = nil
if params.node.GetType() == conf.PEER_ROLE {
endpoint, err = net.ResolveUDPAddr("udp", params.node.GetWgEndpoint())
}
if err != nil {
return nil, err
@ -115,8 +120,13 @@ func (m *WgMeshConfigApplyer) convertMeshNode(params convertMeshNodeParams) (*wg
// getRoutes: finds the routes with the least hop distance. If more than one route exists
// consistently hash to evenly spread the distribution of traffic
func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]routeNode {
mesh, _ := meshProvider.GetMesh()
func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) (map[string][]routeNode, error) {
mesh, err := meshProvider.GetMesh()
if err != nil {
return nil, err
}
routes := make(map[string][]routeNode)
peers := lib.Filter(lib.MapValues(mesh.GetNodes()), func(p MeshNode) bool {
@ -154,17 +164,18 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
// Client's only acessible by another peer
if node.GetType() == conf.CLIENT_ROLE {
peer := m.getCorrespondingPeer(peers, node)
self, _ := m.meshManager.GetSelf(meshProvider.GetMeshId())
self, err := meshProvider.GetNode(m.meshManager.GetPublicKey().String())
if err != nil {
return nil, err
}
// If the node isn't the self use that peer as the gateway
if !NodeEquals(peer, self) {
peerPub, _ := peer.GetPublicKey()
rn.gateway = peerPub.String()
rn.route = &RouteStub{
Destination: rn.route.GetDestination(),
HopCount: rn.route.GetHopCount() + 1,
// Append the path to this peer
Path: append(rn.route.GetPath(), peer.GetWgHost().IP.String()),
Path: append(rn.route.GetPath(), peer.GetWgHost().IP.String()),
}
}
}
@ -181,7 +192,7 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
}
}
return routes
return routes, nil
}
// getCorrespondignPeer: gets the peer corresponding to the client
@ -190,6 +201,7 @@ func (m *WgMeshConfigApplyer) getCorrespondingPeer(peers []MeshNode, client Mesh
return peer
}
// getPeerCfgsToRemove: remove peer configurations that are no longer in the mesh
func (m *WgMeshConfigApplyer) getPeerCfgsToRemove(dev *wgtypes.Device, newPeers []wgtypes.PeerConfig) []wgtypes.PeerConfig {
peers := dev.Peers
peers = lib.Filter(peers, func(p1 wgtypes.Peer) bool {
@ -214,27 +226,37 @@ type GetConfigParams struct {
routes map[string][]routeNode
}
// getClientConfig: if the node is a client get their configuration
func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes.Config, error) {
self, err := m.meshManager.GetSelf(params.mesh.GetMeshId())
ula := &ip.ULABuilder{}
meshNet, _ := ula.GetIPNet(params.mesh.GetMeshId())
routesForMesh := lib.Map(lib.MapValues(params.routes), func(rns []routeNode) []routeNode {
return lib.Filter(rns, func(rn routeNode) bool {
ip, _, _ := net.ParseCIDR(rn.gateway)
return meshNet.Contains(ip)
node, err := params.mesh.GetNode(rn.gateway)
return node != nil && err == nil
})
})
routesForMesh = lib.Filter(routesForMesh, func(rns []routeNode) bool {
return len(rns) != 0
})
routes := lib.Map(routesForMesh, func(rs []routeNode) net.IPNet {
return *rs[0].route.GetDestination()
})
routes = append(routes, *meshNet)
self, err := params.mesh.GetNode(m.meshManager.GetPublicKey().String())
if err != nil {
return nil, err
}
if len(params.peers) == 0 {
return nil, fmt.Errorf("no peers in the mesh")
}
peer := m.getCorrespondingPeer(params.peers, self)
pubKey, _ := peer.GetPublicKey()
@ -260,10 +282,14 @@ func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes
installedRoutes := make([]lib.Route, 0)
for _, route := range peerCfgs[0].AllowedIPs {
installedRoutes = append(installedRoutes, lib.Route{
Gateway: peer.GetWgHost().IP,
Destination: route,
})
// Don't install routes that we are directly apart
// Dont install default route wgctrl handles this for us
if !meshNet.Contains(route.IP) {
installedRoutes = append(installedRoutes, lib.Route{
Gateway: peer.GetWgHost().IP,
Destination: route,
})
}
}
cfg := wgtypes.Config{
@ -274,6 +300,8 @@ func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes
return &cfg, err
}
// getRoutesToInstall: work out if the given node is advertising routes that should be installed into the
// RIB
func (m *WgMeshConfigApplyer) getRoutesToInstall(wgNode *wgtypes.PeerConfig, mesh MeshProvider, node MeshNode) []lib.Route {
routes := make([]lib.Route, 0)
@ -281,9 +309,8 @@ func (m *WgMeshConfigApplyer) getRoutesToInstall(wgNode *wgtypes.PeerConfig, mes
ula := &ip.ULABuilder{}
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
_, defaultRoute, _ := net.ParseCIDR("::/0")
if !ipNet.Contains(route.IP) && !ipNet.IP.Equal(defaultRoute.IP) {
// Check there is no overlap in network and its not the default route
if !ipNet.Contains(route.IP) {
routes = append(routes, lib.Route{
Gateway: node.GetWgHost().IP,
Destination: route,
@ -294,11 +321,12 @@ func (m *WgMeshConfigApplyer) getRoutesToInstall(wgNode *wgtypes.PeerConfig, mes
return routes
}
// getPeerConfig: creates the WireGuard configuration for a peer
func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.Config, error) {
peerToClients := make(map[string][]net.IPNet)
installedRoutes := make([]lib.Route, 0)
peerConfigs := make([]wgtypes.PeerConfig, 0)
self, err := m.meshManager.GetSelf(params.mesh.GetMeshId())
self, err := params.mesh.GetNode(m.meshManager.GetPublicKey().String())
if err != nil {
return nil, err
@ -367,6 +395,7 @@ func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.C
return &cfg, err
}
// updateWgConf: update the WireGuard configuration
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string][]routeNode) error {
snap, err := mesh.GetMesh()
@ -389,7 +418,7 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string]
return mn.GetType() == conf.CLIENT_ROLE
})
self, err := m.meshManager.GetSelf(mesh.GetMeshId())
self, err := mesh.GetNode(m.meshManager.GetPublicKey().String())
if err != nil {
return err
@ -428,11 +457,17 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string]
return nil
}
func (m *WgMeshConfigApplyer) getAllRoutes() map[string][]routeNode {
// getAllRoutes: works out all the routes to install out of all the routes in the
// set of networks the node is a part of
func (m *WgMeshConfigApplyer) getAllRoutes() (map[string][]routeNode, error) {
allRoutes := make(map[string][]routeNode)
for _, mesh := range m.meshManager.GetMeshes() {
routes := m.getRoutes(mesh)
routes, err := m.getRoutes(mesh)
if err != nil {
return nil, err
}
for destination, route := range routes {
_, ok := allRoutes[destination]
@ -450,11 +485,16 @@ func (m *WgMeshConfigApplyer) getAllRoutes() map[string][]routeNode {
}
}
return allRoutes
return allRoutes, nil
}
// ApplyConfig: apply the WireGuard configuration
func (m *WgMeshConfigApplyer) ApplyConfig() error {
allRoutes := m.getAllRoutes()
allRoutes, err := m.getAllRoutes()
if err != nil {
return err
}
for _, mesh := range m.meshManager.GetMeshes() {
err := m.updateWgConf(mesh, allRoutes)
@ -467,27 +507,6 @@ func (m *WgMeshConfigApplyer) ApplyConfig() error {
return nil
}
func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error {
mesh := m.meshManager.GetMesh(meshId)
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
dev, err := mesh.GetDevice()
if err != nil {
return err
}
m.meshManager.GetClient().ConfigureDevice(dev.Name, wgtypes.Config{
Peers: make([]wgtypes.PeerConfig, 0),
ReplacePeers: true,
})
return nil
}
func (m *WgMeshConfigApplyer) SetMeshManager(manager MeshManager) {
m.meshManager = manager
}

View File

@ -1,77 +0,0 @@
package mesh
import (
"errors"
"fmt"
"github.com/tim-beatham/wgmesh/pkg/graph"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
// MeshGraphConverter converts a mesh to a graph
type MeshGraphConverter interface {
// convert the mesh to textual form
Generate(meshId string) (string, error)
}
type MeshDOTConverter struct {
manager MeshManager
}
func (c *MeshDOTConverter) Generate(meshId string) (string, error) {
mesh := c.manager.GetMesh(meshId)
if mesh == nil {
return "", errors.New("mesh does not exist")
}
g := graph.NewGraph(meshId, graph.GRAPH)
snapshot, err := mesh.GetMesh()
if err != nil {
return "", err
}
for _, node := range snapshot.GetNodes() {
c.graphNode(g, node, meshId)
}
nodes := lib.MapValues(snapshot.GetNodes())
for i, node1 := range nodes[:len(nodes)-1] {
for _, node2 := range nodes[i+1:] {
if node1.GetWgEndpoint() == node2.GetWgEndpoint() {
continue
}
node1Id := fmt.Sprintf("\"%s\"", node1.GetIdentifier())
node2Id := fmt.Sprintf("\"%s\"", node2.GetIdentifier())
g.AddEdge(fmt.Sprintf("%s to %s", node1Id, node2Id), node1Id, node2Id)
}
}
return g.GetDOT()
}
// graphNode: graphs a node within the mesh
func (c *MeshDOTConverter) graphNode(g *graph.Graph, node MeshNode, meshId string) {
nodeId := fmt.Sprintf("\"%s\"", node.GetIdentifier())
g.PutNode(nodeId, graph.CIRCLE)
self, _ := c.manager.GetSelf(meshId)
if NodeEquals(self, node) {
return
}
for _, route := range node.GetRoutes() {
routeId := fmt.Sprintf("\"%s\"", route)
g.PutNode(routeId, graph.HEXAGON)
g.AddEdge(fmt.Sprintf("%s to %s", nodeId, routeId), nodeId, routeId)
}
}
func NewMeshDotConverter(m MeshManager) MeshGraphConverter {
return &MeshDOTConverter{manager: m}
}

View File

@ -3,17 +3,21 @@ package mesh
import (
"errors"
"fmt"
"net"
"sync"
"github.com/tim-beatham/wgmesh/pkg/cmd"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/wg"
"github.com/tim-beatham/smegmesh/pkg/cmd"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/wg"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// MeshManager: abstracts maanging meshes, including installing the WireGuard configuration
// to the device, and adding and removing nodes
type MeshManager interface {
CreateMesh(params *CreateMeshParams) (string, error)
AddMesh(params *AddMeshParams) error
@ -24,10 +28,10 @@ type MeshManager interface {
LeaveMesh(meshId string) error
GetSelf(meshId string) (MeshNode, error)
ApplyConfig() error
SetDescription(description string) error
SetAlias(alias string) error
SetService(service string, value string) error
RemoveService(service string) error
SetDescription(meshId, description string) error
SetAlias(meshId, alias string) error
SetService(meshId, service, value string) error
RemoveService(meshId, service string) error
UpdateTimeStamp() error
GetClient() *wgctrl.Client
GetMeshes() map[string]MeshProvider
@ -37,12 +41,10 @@ type MeshManager interface {
}
type MeshManagerImpl struct {
lock sync.RWMutex
Meshes map[string]MeshProvider
RouteManager RouteManager
Client *wgctrl.Client
// HostParameters contains information that uniquely locates
// the node in the mesh network.
meshLock sync.RWMutex
meshes map[string]MeshProvider
RouteManager RouteManager
Client *wgctrl.Client
HostParameters *HostParameters
conf *conf.DaemonConfiguration
meshProviderFactory MeshProviderFactory
@ -55,39 +57,43 @@ type MeshManagerImpl struct {
OnDelete func(MeshProvider)
}
// GetRouteManager implements MeshManager.
func (m *MeshManagerImpl) GetRouteManager() RouteManager {
return m.RouteManager
}
// RemoveService implements MeshManager.
func (m *MeshManagerImpl) RemoveService(service string) error {
for _, mesh := range m.Meshes {
err := mesh.RemoveService(m.HostParameters.GetPublicKey(), service)
// RemoveService: remove a service from the given mesh.
func (m *MeshManagerImpl) RemoveService(meshId, service string) error {
mesh := m.GetMesh(meshId)
if err != nil {
return err
}
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
return nil
}
// SetService implements MeshManager.
func (m *MeshManagerImpl) SetService(service string, value string) error {
for _, mesh := range m.Meshes {
err := mesh.AddService(m.HostParameters.GetPublicKey(), service, value)
if err != nil {
return err
}
if !mesh.NodeExists(m.HostParameters.GetPublicKey()) {
return fmt.Errorf("node %s does not exist in the mesh", meshId)
}
return nil
return mesh.RemoveService(m.HostParameters.GetPublicKey(), service)
}
// SetService: add a service to the given mesh
func (m *MeshManagerImpl) SetService(meshId, service, value string) error {
mesh := m.GetMesh(meshId)
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
if !mesh.NodeExists(m.HostParameters.GetPublicKey()) {
return fmt.Errorf("node %s does not exist in the mesh", meshId)
}
return mesh.AddService(m.HostParameters.GetPublicKey(), service, value)
}
// GetNode: gets the node with given id in the mesh network
func (m *MeshManagerImpl) GetNode(meshid, nodeId string) MeshNode {
mesh, ok := m.Meshes[meshid]
mesh, ok := m.meshes[meshid]
if !ok {
return nil
@ -134,6 +140,10 @@ func (m *MeshManagerImpl) CreateMesh(args *CreateMeshParams) (string, error) {
return "", err
}
if *meshConfiguration.Role == conf.CLIENT_ROLE {
return "", fmt.Errorf("cannot create mesh as a client")
}
meshId, err := m.idGenerator.GetId()
var ifName string = ""
@ -166,9 +176,9 @@ func (m *MeshManagerImpl) CreateMesh(args *CreateMeshParams) (string, error) {
return "", fmt.Errorf("error creating mesh: %w", err)
}
m.lock.Lock()
m.Meshes[meshId] = nodeManager
m.lock.Unlock()
m.meshLock.Lock()
m.meshes[meshId] = nodeManager
m.meshLock.Unlock()
m.cmdRunner.RunCommands(m.conf.BaseConfiguration.PostUp...)
@ -182,7 +192,7 @@ type AddMeshParams struct {
Conf *conf.WgConfiguration
}
// AddMesh: Add the mesh to the list of meshes
// AddMesh: Add a new mesh network to the list of addresses
func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
var ifName string
var err error
@ -225,20 +235,20 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
return err
}
m.lock.Lock()
m.Meshes[params.MeshId] = meshProvider
m.lock.Unlock()
m.meshLock.Lock()
m.meshes[params.MeshId] = meshProvider
m.meshLock.Unlock()
return nil
}
// HasChanges returns true if the mesh has changes
// HasChanges: returns true if the mesh has changes
func (m *MeshManagerImpl) HasChanges(meshId string) bool {
return m.Meshes[meshId].HasChanges()
return m.meshes[meshId].HasChanges()
}
// GetMesh returns the mesh with the given meshid
// GetMesh: returns the mesh with the given meshid
func (m *MeshManagerImpl) GetMesh(meshId string) MeshProvider {
theMesh := m.Meshes[meshId]
theMesh := m.meshes[meshId]
return theMesh
}
@ -248,6 +258,8 @@ func (s *MeshManagerImpl) GetPublicKey() *wgtypes.Key {
return &key
}
// AddSelfParams: parameters required to add yourself to a mesh
// network
type AddSelfParams struct {
// MeshId is the ID of the mesh to add this instance to
MeshId string
@ -257,7 +269,7 @@ type AddSelfParams struct {
Endpoint string
}
// AddSelf adds this host to the mesh
// AddSelf: adds this host to the mesh
func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
mesh := s.GetMesh(params.MeshId)
@ -277,10 +289,36 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
pubKey := s.HostParameters.PrivateKey.PublicKey()
nodeIP, err := s.ipAllocator.GetIP(pubKey, params.MeshId)
collisionCount := uint8(0)
if err != nil {
return err
var nodeIP net.IP
// Perform Duplicate Address Detection with the nodes
// that are already in the network
for {
generatedIP, err := s.ipAllocator.GetIP(pubKey, params.MeshId, collisionCount)
if err != nil {
return err
}
snapshot, err := mesh.GetMesh()
if err != nil {
return err
}
proposition := func(node MeshNode) bool {
ipNet := node.GetWgHost()
return ipNet.IP.Equal(nodeIP)
}
if lib.Contains(lib.MapValues(snapshot.GetNodes()), proposition) {
collisionCount++
} else {
nodeIP = generatedIP
break
}
}
node := s.nodeFactory.Build(&MeshNodeFactoryParams{
@ -305,11 +343,11 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
}
}
s.Meshes[params.MeshId].AddNode(node)
s.meshes[params.MeshId].AddNode(node)
return nil
}
// LeaveMesh leaves the mesh network
// LeaveMesh: leaves the mesh network and force a synchronsiation
func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
mesh := s.GetMesh(meshId)
@ -320,16 +358,16 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
err := mesh.RemoveNode(s.HostParameters.GetPublicKey())
if err != nil {
return err
logging.Log.WriteErrorf(err.Error())
}
if s.OnDelete != nil {
s.OnDelete(mesh)
}
s.lock.Lock()
delete(s.Meshes, meshId)
s.lock.Unlock()
s.meshLock.Lock()
delete(s.meshes, meshId)
s.meshLock.Unlock()
s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PreDown...)
@ -348,12 +386,11 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
}
s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PostDown...)
return err
}
func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
meshInstance, ok := s.Meshes[meshId]
meshInstance, ok := s.meshes[meshId]
if !ok {
return nil, fmt.Errorf("mesh %s does not exist", meshId)
@ -368,51 +405,46 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
return node, nil
}
// ApplyConfig: applies the WireGuard configuration
// adds routes to the RIB and so forth.
func (s *MeshManagerImpl) ApplyConfig() error {
if s.conf.StubWg {
return nil
}
err := s.configApplyer.ApplyConfig()
if err != nil {
return err
}
return nil
return s.configApplyer.ApplyConfig()
}
func (s *MeshManagerImpl) SetDescription(description string) error {
meshes := s.GetMeshes()
for _, mesh := range meshes {
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetDescription(s.HostParameters.GetPublicKey(), description)
func (s *MeshManagerImpl) SetDescription(meshId, description string) error {
mesh := s.GetMesh(meshId)
if err != nil {
return err
}
}
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
return nil
}
// SetAlias implements MeshManager.
func (s *MeshManagerImpl) SetAlias(alias string) error {
meshes := s.GetMeshes()
for _, mesh := range meshes {
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetAlias(s.HostParameters.GetPublicKey(), alias)
if err != nil {
return err
}
}
if !mesh.NodeExists(s.HostParameters.GetPublicKey()) {
return fmt.Errorf("node %s does not exist in the mesh", meshId)
}
return nil
return mesh.SetDescription(s.HostParameters.GetPublicKey(), description)
}
// UpdateTimeStamp updates the timestamp of this node in all meshes
// SetAlias sets the alias of the node for the given meshid
func (s *MeshManagerImpl) SetAlias(meshId, alias string) error {
mesh := s.GetMesh(meshId)
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
if !mesh.NodeExists(s.HostParameters.GetPublicKey()) {
return fmt.Errorf("node %s does not exist in the mesh", meshId)
}
return mesh.SetAlias(s.HostParameters.GetPublicKey(), alias)
}
// UpdateTimeStamp: updates the timestamp of this node in all meshes
// essentially performs heartbeat if the node is the leader
func (s *MeshManagerImpl) UpdateTimeStamp() error {
meshes := s.GetMeshes()
for _, mesh := range meshes {
@ -432,26 +464,30 @@ func (s *MeshManagerImpl) GetClient() *wgctrl.Client {
return s.Client
}
// GetMeshes: get all meshes the node is part of
func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider {
meshes := make(map[string]MeshProvider)
s.lock.RLock()
// GetMesh: copies the map of meshes to a new map
// to prevent a whole range of concurrency issues
// due to iteration and modification
s.meshLock.RLock()
for id, mesh := range s.Meshes {
for id, mesh := range s.meshes {
meshes[id] = mesh
}
s.lock.RUnlock()
s.meshLock.RUnlock()
return meshes
}
// Close the mesh manager
// Close: close the mesh manager
func (s *MeshManagerImpl) Close() error {
if s.conf.StubWg {
return nil
}
for _, mesh := range s.Meshes {
for _, mesh := range s.meshes {
dev, err := mesh.GetDevice()
if err != nil {
@ -468,7 +504,7 @@ func (s *MeshManagerImpl) Close() error {
return nil
}
// NewMeshManagerParams params required to create an instance of a mesh manager
// NewMeshManagerParams: params required to create an instance of a mesh manager
type NewMeshManagerParams struct {
Conf conf.DaemonConfiguration
Client *wgctrl.Client
@ -483,7 +519,7 @@ type NewMeshManagerParams struct {
OnDelete func(MeshProvider)
}
// Creates a new instance of a mesh manager with the given parameters
// NewMeshManager: Creates a new instance of a mesh manager with the given parameters
func NewMeshManager(params *NewMeshManagerParams) MeshManager {
privateKey, _ := wgtypes.GeneratePrivateKey()
hostParams := HostParameters{
@ -491,7 +527,7 @@ func NewMeshManager(params *NewMeshManagerParams) MeshManager {
}
m := &MeshManagerImpl{
Meshes: make(map[string]MeshProvider),
meshes: make(map[string]MeshProvider),
HostParameters: &hostParams,
meshProviderFactory: params.MeshProvider,
nodeFactory: params.NodeFactory,

View File

@ -3,10 +3,10 @@ package mesh
import (
"testing"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/wg"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/wg"
)
func getMeshConfiguration() *conf.DaemonConfiguration {
@ -24,11 +24,11 @@ func getMeshConfiguration() *conf.DaemonConfiguration {
Timeout: 5,
Profile: false,
StubWg: true,
SyncRate: 2,
KeepAliveTime: 60,
SyncInterval: 2,
Heartbeat: 60,
ClusterSize: 64,
InterClusterChance: 0.15,
BranchRate: 3,
Branch: 3,
InfectionCount: 3,
BaseConfiguration: conf.WgConfiguration{
IPDiscovery: &ipDiscovery,
@ -213,7 +213,7 @@ func TestLeaveMeshDeletesMesh(t *testing.T) {
}
}
func TestSetAlias(t *testing.T) {
func TestSetAliasUpdatesAliasOfNode(t *testing.T) {
manager := getMeshManager()
alias := "Firpo"
@ -221,14 +221,13 @@ func TestSetAlias(t *testing.T) {
Port: 5000,
Conf: &conf.WgConfiguration{},
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
err := manager.SetAlias(alias)
err := manager.SetAlias(meshId, alias)
if err != nil {
t.Fatalf(`failed to set the alias`)
@ -245,7 +244,7 @@ func TestSetAlias(t *testing.T) {
}
}
func TestSetDescription(t *testing.T) {
func TestSetDescriptionSetsTheDescriptionOfTheNode(t *testing.T) {
manager := getMeshManager()
description := "wooooo"
@ -254,23 +253,13 @@ func TestSetDescription(t *testing.T) {
Conf: &conf.WgConfiguration{},
})
meshId2, _ := manager.CreateMesh(&CreateMeshParams{
Port: 5001,
Conf: &conf.WgConfiguration{},
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId1,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId2,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
err := manager.SetDescription(description)
err := manager.SetDescription(meshId1, description)
if err != nil {
t.Fatalf(`failed to set the descriptions`)
@ -285,18 +274,7 @@ func TestSetDescription(t *testing.T) {
if description != self1.GetDescription() {
t.Fatalf(`description should be %s was %s`, description, self1.GetDescription())
}
self2, err := manager.GetSelf(meshId2)
if err != nil {
t.Fatalf(`failed to set the description`)
}
if description != self2.GetDescription() {
t.Fatalf(`description should be %s was %s`, description, self2.GetDescription())
}
}
func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) {
manager := getMeshManager()
@ -327,3 +305,68 @@ func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) {
t.Fatalf(`failed to update the timestamp`)
}
}
func TestAddServiceAddsServiceToTheMesh(t *testing.T) {
manager := getMeshManager()
meshId1, _ := manager.CreateMesh(&CreateMeshParams{
Port: 5000,
Conf: &conf.WgConfiguration{},
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId1,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
serviceName := "hello"
manager.SetService(meshId1, serviceName, "dave")
self, err := manager.GetSelf(meshId1)
if err != nil {
t.Fatalf(`error thrown %s:`, err.Error())
}
if _, ok := self.GetServices()[serviceName]; !ok {
t.Fatalf(`service not added`)
}
}
func TestRemoveServiceRemovesTheServiceFromTheMesh(t *testing.T) {
manager := getMeshManager()
meshId1, _ := manager.CreateMesh(&CreateMeshParams{
Port: 5000,
Conf: &conf.WgConfiguration{},
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId1,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
serviceName := "hello"
manager.SetService(meshId1, serviceName, "dave")
self, err := manager.GetSelf(meshId1)
if err != nil {
t.Fatalf(`error thrown %s:`, err.Error())
}
if _, ok := self.GetServices()[serviceName]; !ok {
t.Fatalf(`service not added`)
}
manager.RemoveService(meshId1, serviceName)
self, err = manager.GetSelf(meshId1)
if err != nil {
t.Fatalf(`error thrown %s:`, err.Error())
}
if _, ok := self.GetServices()[serviceName]; ok {
t.Fatalf(`service still exists`)
}
}

View File

@ -3,11 +3,13 @@ package mesh
import (
"net"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/smegmesh/pkg/lib"
)
// RouteManager: manager that leaks routes between meshes
type RouteManager interface {
// UpdateRoutes: leak all routes in each mesh
UpdateRoutes() error
}
@ -19,12 +21,17 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
meshes := r.meshManager.GetMeshes()
routes := make(map[string][]Route)
for _, mesh := range meshes {
// Make empty routes so that routes are retracted
routes[mesh.GetMeshId()] = make([]Route, 0)
}
for _, mesh1 := range meshes {
if !*mesh1.GetConfiguration().AdvertiseRoutes {
continue
}
self, err := r.meshManager.GetSelf(mesh1.GetMeshId())
self, err := mesh1.GetNode(r.meshManager.GetPublicKey().String())
if err != nil {
return err
@ -39,7 +46,6 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
defaultRoute := &RouteStub{
Destination: ipv6Default,
HopCount: 0,
Path: []string{mesh1.GetMeshId()},
}
@ -68,7 +74,6 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
routeValues = append(routeValues, &RouteStub{
Destination: mesh1IpNet,
HopCount: 0,
Path: []string{mesh1.GetMeshId()},
})
@ -90,15 +95,21 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
// Calculate the set different of each, working out routes to remove and to keep.
for meshId, meshRoutes := range routes {
mesh := r.meshManager.GetMesh(meshId)
self, _ := r.meshManager.GetSelf(meshId)
mesh := meshes[meshId]
self, err := mesh.GetNode(r.meshManager.GetPublicKey().String())
if err != nil {
return err
}
toRemove := make([]Route, 0)
prevRoutes, _ := mesh.GetRoutes(NodeID(self))
prevRoutes := self.GetRoutes()
for _, route := range prevRoutes {
if !lib.Contains(meshRoutes, func(r Route) bool {
return RouteEquals(r, route)
return RouteEqual(r, route)
}) {
toRemove = append(toRemove, route)
}

View File

@ -5,8 +5,8 @@ import (
"net"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/lib"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@ -30,8 +30,8 @@ func (*MeshNodeStub) GetType() conf.NodeType {
}
// GetServices implements MeshNode.
func (*MeshNodeStub) GetServices() map[string]string {
return make(map[string]string)
func (m *MeshNodeStub) GetServices() map[string]string {
return m.services
}
// GetAlias implements MeshNode.
@ -249,6 +249,7 @@ func (s *StubNodeFactory) Build(params *MeshNodeFactoryParams) MeshNode {
routes: make([]Route, 0),
identifier: "abc",
description: "A Mesh Node Stub",
services: make(map[string]string),
}
}
@ -271,32 +272,32 @@ type MeshManagerStub struct {
// GetRouteManager implements MeshManager.
func (*MeshManagerStub) GetRouteManager() RouteManager {
panic("unimplemented")
return nil
}
// GetNode implements MeshManager.
func (*MeshManagerStub) GetNode(string, string) MeshNode {
panic("unimplemented")
func (*MeshManagerStub) GetNode(meshId, nodeId string) MeshNode {
return nil
}
// RemoveService implements MeshManager.
func (*MeshManagerStub) RemoveService(service string) error {
panic("unimplemented")
func (*MeshManagerStub) RemoveService(meshId, service string) error {
return nil
}
// SetService implements MeshManager.
func (*MeshManagerStub) SetService(service string, value string) error {
panic("unimplemented")
func (*MeshManagerStub) SetService(meshId, service, value string) error {
return nil
}
// SetAlias implements MeshManager.
func (*MeshManagerStub) SetAlias(alias string) error {
panic("unimplemented")
func (*MeshManagerStub) SetAlias(meshId, alias string) error {
return nil
}
// Close implements MeshManager.
func (*MeshManagerStub) Close() error {
panic("unimplemented")
return nil
}
// Prune implements MeshManager.
@ -348,7 +349,7 @@ func (m *MeshManagerStub) ApplyConfig() error {
return nil
}
func (m *MeshManagerStub) SetDescription(description string) error {
func (m *MeshManagerStub) SetDescription(meshId, description string) error {
return nil
}

View File

@ -6,7 +6,7 @@ import (
"net"
"slices"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/conf"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@ -21,12 +21,6 @@ type Route interface {
}
func RouteEqual(r1 Route, r2 Route) bool {
return r1.GetDestination().IP.Equal(r2.GetDestination().IP) &&
r1.GetHopCount() == r2.GetHopCount() &&
slices.Equal(r1.GetPath(), r2.GetPath())
}
func RouteEquals(r1, r2 Route) bool {
return r1.GetDestination().String() == r2.GetDestination().String() &&
r1.GetHopCount() == r2.GetHopCount() &&
slices.Equal(r1.GetPath(), r2.GetPath())
@ -34,7 +28,6 @@ func RouteEquals(r1, r2 Route) bool {
type RouteStub struct {
Destination *net.IPNet
HopCount int
Path []string
}
@ -43,7 +36,7 @@ func (r *RouteStub) GetDestination() *net.IPNet {
}
func (r *RouteStub) GetHopCount() int {
return r.HopCount
return len(r.Path)
}
func (r *RouteStub) GetPath() []string {
@ -81,6 +74,10 @@ func NodeEquals(node1, node2 MeshNode) bool {
key1, _ := node1.GetPublicKey()
key2, _ := node2.GetPublicKey()
if node1 == nil || node2 == nil {
return false
}
return key1.String() == key2.String()
}

View File

@ -6,9 +6,9 @@ import (
"strings"
"github.com/jmespath/go-jmespath"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/mesh"
)
// Querier queries a data store for the given data
@ -17,20 +17,24 @@ type Querier interface {
Query(meshId string, queryParams string) ([]byte, error)
}
// JmesQuerier: queries the datstore in JMESPath syntax
type JmesQuerier struct {
manager mesh.MeshManager
}
// QueryError: query error if something went wrong
type QueryError struct {
msg string
}
// QuerRoute: represents a route in the query
type QueryRoute struct {
Destination string `json:"destination"`
HopCount int `json:"hopCount"`
Path string `json:"path"`
}
// QueryNode: represents a single node in the query
type QueryNode struct {
HostEndpoint string `json:"hostEndpoint"`
PublicKey string `json:"publicKey"`
@ -48,7 +52,7 @@ func (m *QueryError) Error() string {
return m.msg
}
// Query: queries the data
// Query: queries the the datastore at the given meshid
func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
mesh, ok := j.manager.GetMeshes()[meshId]
@ -74,6 +78,7 @@ func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
return bytes, err
}
// MeshNodeToQuerynode: convert the mesh node into a query abstraction
func MeshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode := new(QueryNode)
queryNode.HostEndpoint = node.GetHostEndpoint()

View File

@ -4,20 +4,22 @@ import (
"context"
"errors"
"fmt"
"strconv"
"slices"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/rpc"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/smegmesh/pkg/ipc"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/rpc"
)
// IpcHandler: represents a handler for ipc calls
type IpcHandler struct {
Server ctrlserver.CtrlServer
}
// getOverrideConfiguration: override any specific WireGuard configuration
func getOverrideConfiguration(args *ipc.WireGuardArgs) conf.WgConfiguration {
overrideConf := conf.WgConfiguration{}
@ -40,20 +42,17 @@ func getOverrideConfiguration(args *ipc.WireGuardArgs) conf.WgConfiguration {
return overrideConf
}
// CreateMesh: create a new mesh network
func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
overrideConf := getOverrideConfiguration(&args.WgArgs)
if overrideConf.Role != nil && *overrideConf.Role == conf.CLIENT_ROLE {
return fmt.Errorf("cannot create a mesh with no public endpoint")
}
meshId, err := n.Server.GetMeshManager().CreateMesh(&mesh.CreateMeshParams{
Port: args.WgArgs.WgPort,
Conf: &overrideConf,
})
if err != nil {
return err
return errors.New("could not create mesh")
}
err = n.Server.GetMeshManager().AddSelf(&mesh.AddSelfParams{
@ -63,13 +62,14 @@ func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
})
if err != nil {
return err
return errors.New("could not create mesh")
}
*reply = meshId
return err
}
// ListMeshes: list mesh networks
func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
meshNames := make([]string, len(n.Server.GetMeshManager().GetMeshes()))
@ -79,29 +79,35 @@ func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
i++
}
slices.Sort(meshNames)
*reply = ipc.ListMeshReply{Meshes: meshNames}
return nil
}
func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
// JoinMesh: join a mesh network
func (n *IpcHandler) JoinMesh(args *ipc.JoinMeshArgs, reply *string) error {
overrideConf := getOverrideConfiguration(&args.WgArgs)
peerConnection, err := n.Server.GetConnectionManager().GetConnection(args.IpAdress)
if n.Server.GetMeshManager().GetMesh(args.MeshId) != nil {
return fmt.Errorf("user is already apart of the mesh")
}
peerConnection, err := n.Server.GetConnectionManager().GetConnection(args.IpAddress)
if err != nil {
return err
return fmt.Errorf("could not join mesh %s", args.MeshId)
}
client, err := peerConnection.GetClient()
if err != nil {
return err
return fmt.Errorf("could not join mesh %s", args.MeshId)
}
c := rpc.NewMeshCtrlServerClient(client)
if err != nil {
return err
return fmt.Errorf("could not join mesh %s", args.MeshId)
}
configuration := n.Server.GetConfiguration()
@ -112,7 +118,7 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
meshReply, err := c.GetMesh(ctx, &rpc.GetMeshRequest{MeshId: args.MeshId})
if err != nil {
return err
return fmt.Errorf("could not join mesh %s", args.MeshId)
}
err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{
@ -123,7 +129,7 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
})
if err != nil {
return err
return fmt.Errorf("could not join mesh %s", args.MeshId)
}
err = n.Server.GetMeshManager().AddSelf(&mesh.AddSelfParams{
@ -133,24 +139,24 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
})
if err != nil {
return err
return fmt.Errorf("could not join mesh %s", args.MeshId)
}
*reply = strconv.FormatBool(true)
*reply = fmt.Sprintf("Successfully Joined: %s", args.MeshId)
return nil
}
// LeaveMesh leaves a mesh network
// LeaveMesh: leaves a mesh network
func (n *IpcHandler) LeaveMesh(meshId string, reply *string) error {
err := n.Server.GetMeshManager().LeaveMesh(meshId)
if err == nil {
*reply = fmt.Sprintf("Left Mesh %s", meshId)
}
return err
}
// GetMesh: get a mesh network at the given meshid
func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
theMesh := n.Server.GetMeshManager().GetMesh(meshId)
@ -182,19 +188,7 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
return nil
}
func (n *IpcHandler) GetDOT(meshId string, reply *string) error {
g := mesh.NewMeshDotConverter(n.Server.GetMeshManager())
result, err := g.Generate(meshId)
if err != nil {
return err
}
*reply = result
return nil
}
// Query: perform a jmespath query
func (n *IpcHandler) Query(params ipc.QueryMesh, reply *string) error {
queryResponse, err := n.Server.GetQuerier().Query(params.MeshId, params.Query)
@ -206,50 +200,59 @@ func (n *IpcHandler) Query(params ipc.QueryMesh, reply *string) error {
return nil
}
func (n *IpcHandler) PutDescription(description string, reply *string) error {
err := n.Server.GetMeshManager().SetDescription(description)
// PutDescription: change your description in the mesh
func (n *IpcHandler) PutDescription(args ipc.PutDescriptionArgs, reply *string) error {
err := n.Server.GetMeshManager().SetDescription(args.MeshId, args.Description)
if err != nil {
return err
}
*reply = fmt.Sprintf("Set description to %s", description)
*reply = fmt.Sprintf("set description to %s for %s", args.Description, args.MeshId)
return nil
}
func (n *IpcHandler) PutAlias(alias string, reply *string) error {
err := n.Server.GetMeshManager().SetAlias(alias)
if err != nil {
return err
// PutAlias: put your aliasin the mesh
func (n *IpcHandler) PutAlias(args ipc.PutAliasArgs, reply *string) error {
if args.Alias == "" {
return fmt.Errorf("alias not provided")
}
*reply = fmt.Sprintf("Set alias to %s", alias)
err := n.Server.GetMeshManager().SetAlias(args.MeshId, args.Alias)
if err != nil {
return fmt.Errorf("could not set alias: %s", args.Alias)
}
*reply = fmt.Sprintf("Set alias to %s", args.Alias)
return nil
}
// PutService: place a service in the mesh
func (n *IpcHandler) PutService(service ipc.PutServiceArgs, reply *string) error {
err := n.Server.GetMeshManager().SetService(service.Service, service.Value)
err := n.Server.GetMeshManager().SetService(service.MeshId, service.Service, service.Value)
if err != nil {
return err
}
*reply = "success"
*reply = fmt.Sprintf("Set service %s in %s to %s", service.Service, service.MeshId, service.Value)
return nil
}
func (n *IpcHandler) DeleteService(service string, reply *string) error {
err := n.Server.GetMeshManager().RemoveService(service)
// DeleteService: withtract a service in the mesh
func (n *IpcHandler) DeleteService(service ipc.DeleteServiceArgs, reply *string) error {
err := n.Server.GetMeshManager().RemoveService(service.MeshId, service.Service)
if err != nil {
return err
}
*reply = "success"
*reply = fmt.Sprintf("Removed service %s from %s", service.Service, service.MeshId)
return nil
}
// RobinIpcParams: parameters required to construct a new mesh network
type RobinIpcParams struct {
CtrlServer ctrlserver.CtrlServer
}

View File

@ -3,10 +3,10 @@ package robin
import (
"testing"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/smegmesh/pkg/ipc"
"github.com/tim-beatham/smegmesh/pkg/mesh"
)
func getRequester() *IpcHandler {

View File

@ -4,15 +4,17 @@ import (
"context"
"errors"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/rpc"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/smegmesh/pkg/rpc"
)
// WgRpc: represents a WireGuard rpc call
type WgRpc struct {
rpc.UnimplementedMeshCtrlServerServer
Server *ctrlserver.MeshCtrlServer
}
// GetMesh: serialise the mesh network into bytes
func (m *WgRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.GetMeshReply, error) {
mesh := m.Server.MeshManager.GetMesh(request.MeshId)

View File

@ -1 +0,0 @@
package robin

View File

@ -1,10 +1,11 @@
package route
import (
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/lib"
"golang.org/x/sys/unix"
)
// RouteInstaller: install the routes to the given interface
type RouteInstaller interface {
InstallRoutes(devName string, routes ...lib.Route) error
}
@ -19,6 +20,8 @@ func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...lib.Route)
return err
}
defer rtnl.Close()
err = rtnl.DeleteRoutes(devName, unix.AF_INET6, routes...)
if err != nil {

View File

@ -1,235 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.1
// protoc v3.21.12
// source: pkg/grpc/ctrlserver/authentication.proto
package rpc
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type JoinAuthMeshRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
MeshId string `protobuf:"bytes,1,opt,name=meshId,proto3" json:"meshId,omitempty"`
Alias string `protobuf:"bytes,2,opt,name=alias,proto3" json:"alias,omitempty"`
}
func (x *JoinAuthMeshRequest) Reset() {
*x = JoinAuthMeshRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *JoinAuthMeshRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JoinAuthMeshRequest) ProtoMessage() {}
func (x *JoinAuthMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JoinAuthMeshRequest.ProtoReflect.Descriptor instead.
func (*JoinAuthMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_authentication_proto_rawDescGZIP(), []int{0}
}
func (x *JoinAuthMeshRequest) GetMeshId() string {
if x != nil {
return x.MeshId
}
return ""
}
func (x *JoinAuthMeshRequest) GetAlias() string {
if x != nil {
return x.Alias
}
return ""
}
type JoinAuthMeshReply struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"`
Token *string `protobuf:"bytes,2,opt,name=token,proto3,oneof" json:"token,omitempty"`
}
func (x *JoinAuthMeshReply) Reset() {
*x = JoinAuthMeshReply{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *JoinAuthMeshReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JoinAuthMeshReply) ProtoMessage() {}
func (x *JoinAuthMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JoinAuthMeshReply.ProtoReflect.Descriptor instead.
func (*JoinAuthMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_authentication_proto_rawDescGZIP(), []int{1}
}
func (x *JoinAuthMeshReply) GetSuccess() bool {
if x != nil {
return x.Success
}
return false
}
func (x *JoinAuthMeshReply) GetToken() string {
if x != nil && x.Token != nil {
return *x.Token
}
return ""
}
var File_pkg_grpc_ctrlserver_authentication_proto protoreflect.FileDescriptor
var file_pkg_grpc_ctrlserver_authentication_proto_rawDesc = []byte{
0x0a, 0x28, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73,
0x65, 0x72, 0x76, 0x65, 0x72, 0x2f, 0x61, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61,
0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x72, 0x70, 0x63, 0x74,
0x79, 0x70, 0x65, 0x73, 0x22, 0x43, 0x0a, 0x13, 0x4a, 0x6f, 0x69, 0x6e, 0x41, 0x75, 0x74, 0x68,
0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x6d,
0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73,
0x68, 0x49, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x61, 0x6c, 0x69, 0x61, 0x73, 0x18, 0x02, 0x20, 0x01,
0x28, 0x09, 0x52, 0x05, 0x61, 0x6c, 0x69, 0x61, 0x73, 0x22, 0x52, 0x0a, 0x11, 0x4a, 0x6f, 0x69,
0x6e, 0x41, 0x75, 0x74, 0x68, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x18,
0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52,
0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x19, 0x0a, 0x05, 0x74, 0x6f, 0x6b, 0x65,
0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e,
0x88, 0x01, 0x01, 0x42, 0x08, 0x0a, 0x06, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x32, 0x5a, 0x0a,
0x0e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12,
0x48, 0x0a, 0x08, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x1d, 0x2e, 0x72, 0x70,
0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x41, 0x75, 0x74, 0x68, 0x4d,
0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x41, 0x75, 0x74, 0x68, 0x4d, 0x65,
0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x42, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67,
0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_pkg_grpc_ctrlserver_authentication_proto_rawDescOnce sync.Once
file_pkg_grpc_ctrlserver_authentication_proto_rawDescData = file_pkg_grpc_ctrlserver_authentication_proto_rawDesc
)
func file_pkg_grpc_ctrlserver_authentication_proto_rawDescGZIP() []byte {
file_pkg_grpc_ctrlserver_authentication_proto_rawDescOnce.Do(func() {
file_pkg_grpc_ctrlserver_authentication_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_grpc_ctrlserver_authentication_proto_rawDescData)
})
return file_pkg_grpc_ctrlserver_authentication_proto_rawDescData
}
var file_pkg_grpc_ctrlserver_authentication_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_pkg_grpc_ctrlserver_authentication_proto_goTypes = []interface{}{
(*JoinAuthMeshRequest)(nil), // 0: rpctypes.JoinAuthMeshRequest
(*JoinAuthMeshReply)(nil), // 1: rpctypes.JoinAuthMeshReply
}
var file_pkg_grpc_ctrlserver_authentication_proto_depIdxs = []int32{
0, // 0: rpctypes.Authentication.JoinMesh:input_type -> rpctypes.JoinAuthMeshRequest
1, // 1: rpctypes.Authentication.JoinMesh:output_type -> rpctypes.JoinAuthMeshReply
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_pkg_grpc_ctrlserver_authentication_proto_init() }
func file_pkg_grpc_ctrlserver_authentication_proto_init() {
if File_pkg_grpc_ctrlserver_authentication_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*JoinAuthMeshRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*JoinAuthMeshReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[1].OneofWrappers = []interface{}{}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_pkg_grpc_ctrlserver_authentication_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_pkg_grpc_ctrlserver_authentication_proto_goTypes,
DependencyIndexes: file_pkg_grpc_ctrlserver_authentication_proto_depIdxs,
MessageInfos: file_pkg_grpc_ctrlserver_authentication_proto_msgTypes,
}.Build()
File_pkg_grpc_ctrlserver_authentication_proto = out.File
file_pkg_grpc_ctrlserver_authentication_proto_rawDesc = nil
file_pkg_grpc_ctrlserver_authentication_proto_goTypes = nil
file_pkg_grpc_ctrlserver_authentication_proto_depIdxs = nil
}

View File

@ -1,105 +0,0 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.2.0
// - protoc v3.21.12
// source: pkg/grpc/ctrlserver/authentication.proto
package rpc
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// AuthenticationClient is the client API for Authentication service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type AuthenticationClient interface {
JoinMesh(ctx context.Context, in *JoinAuthMeshRequest, opts ...grpc.CallOption) (*JoinAuthMeshReply, error)
}
type authenticationClient struct {
cc grpc.ClientConnInterface
}
func NewAuthenticationClient(cc grpc.ClientConnInterface) AuthenticationClient {
return &authenticationClient{cc}
}
func (c *authenticationClient) JoinMesh(ctx context.Context, in *JoinAuthMeshRequest, opts ...grpc.CallOption) (*JoinAuthMeshReply, error) {
out := new(JoinAuthMeshReply)
err := c.cc.Invoke(ctx, "/rpctypes.Authentication/JoinMesh", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// AuthenticationServer is the server API for Authentication service.
// All implementations must embed UnimplementedAuthenticationServer
// for forward compatibility
type AuthenticationServer interface {
JoinMesh(context.Context, *JoinAuthMeshRequest) (*JoinAuthMeshReply, error)
mustEmbedUnimplementedAuthenticationServer()
}
// UnimplementedAuthenticationServer must be embedded to have forward compatible implementations.
type UnimplementedAuthenticationServer struct {
}
func (UnimplementedAuthenticationServer) JoinMesh(context.Context, *JoinAuthMeshRequest) (*JoinAuthMeshReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method JoinMesh not implemented")
}
func (UnimplementedAuthenticationServer) mustEmbedUnimplementedAuthenticationServer() {}
// UnsafeAuthenticationServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to AuthenticationServer will
// result in compilation errors.
type UnsafeAuthenticationServer interface {
mustEmbedUnimplementedAuthenticationServer()
}
func RegisterAuthenticationServer(s grpc.ServiceRegistrar, srv AuthenticationServer) {
s.RegisterService(&Authentication_ServiceDesc, srv)
}
func _Authentication_JoinMesh_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(JoinAuthMeshRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(AuthenticationServer).JoinMesh(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/rpctypes.Authentication/JoinMesh",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(AuthenticationServer).JoinMesh(ctx, req.(*JoinAuthMeshRequest))
}
return interceptor(ctx, in, info, handler)
}
// Authentication_ServiceDesc is the grpc.ServiceDesc for Authentication service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var Authentication_ServiceDesc = grpc.ServiceDesc{
ServiceName: "rpctypes.Authentication",
HandlerType: (*AuthenticationServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "JoinMesh",
Handler: _Authentication_JoinMesh_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "pkg/grpc/ctrlserver/authentication.proto",
}

View File

@ -1,201 +1,279 @@
package sync
import (
"errors"
"fmt"
"io"
"math/rand"
"sync"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/mesh"
)
// Syncer: picks random nodes from the meshs
type Syncer interface {
Sync(meshId string) error
Sync(theMesh mesh.MeshProvider) (bool, error)
SyncMeshes() error
}
// SyncerImpl: implementation of a syncer to sync meshes
type SyncerImpl struct {
manager mesh.MeshManager
meshManager mesh.MeshManager
requester SyncRequester
infectionCount int
syncCount int
cluster conn.ConnCluster
conf *conf.DaemonConfiguration
lastSync map[string]uint64
configuration *conf.DaemonConfiguration
lastSync map[string]int64
lastPoll map[string]int64
lastSyncLock sync.RWMutex
lastPollLock sync.RWMutex
}
// Sync: Sync random nodes
func (s *SyncerImpl) Sync(meshId string) error {
// Self can be nil if the node is removed
self, _ := s.manager.GetSelf(meshId)
// Sync: Sync with random nodes. Returns true if there was changes false otherwise
func (s *SyncerImpl) Sync(correspondingMesh mesh.MeshProvider) (bool, error) {
if correspondingMesh == nil {
return false, fmt.Errorf("mesh provided was nil cannot sync nil mesh")
}
correspondingMesh := s.manager.GetMesh(meshId)
// Self can be nil if the node is removed
selfID := s.meshManager.GetPublicKey()
self, _ := correspondingMesh.GetNode(selfID.String())
correspondingMesh.Prune()
if self != nil && self.GetType() == conf.PEER_ROLE && !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
logging.Log.WriteInfof("No changes for %s", meshId)
if correspondingMesh.HasChanges() {
logging.Log.WriteInfof("meshes %s has changes", correspondingMesh.GetMeshId())
}
// If not synchronised in certain pull from random neighbour
if uint64(time.Now().Unix())-s.lastSync[meshId] > 20 {
return s.Pull(meshId)
// If removed sync with other nodes to gossip the node is removed
if self != nil && self.GetType() == conf.PEER_ROLE && !correspondingMesh.HasChanges() && s.infectionCount == 0 {
logging.Log.WriteInfof("no changes for %s", correspondingMesh.GetMeshId())
// If not synchronised in certain time pull from random neighbour
if s.configuration.PullInterval != 0 && time.Now().Unix()-s.lastSync[correspondingMesh.GetMeshId()] > int64(s.configuration.PullInterval) {
return s.Pull(self, correspondingMesh)
}
return nil
return false, nil
}
before := time.Now()
s.manager.GetRouteManager().UpdateRoutes()
publicKey := s.manager.GetPublicKey()
logging.Log.WriteInfof(publicKey.String())
publicKey := s.meshManager.GetPublicKey()
nodeNames := correspondingMesh.GetPeers()
if self != nil {
nodeNames = lib.Filter(nodeNames, func(s string) bool {
return s != mesh.NodeID(self)
})
}
nodeNames = lib.Filter(nodeNames, func(s string) bool {
// Filter our only public key out so we dont sync with ourself
return s != publicKey.String()
})
var gossipNodes []string
// Clients always pings its peer for configuration
if self != nil && self.GetType() == conf.CLIENT_ROLE {
keyFunc := lib.HashString
bucketFunc := lib.HashString
if self != nil && self.GetType() == conf.CLIENT_ROLE && len(nodeNames) > 1 {
neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
neighbour := lib.ConsistentHash(nodeNames, publicKey.String(), keyFunc, bucketFunc)
gossipNodes = make([]string, 1)
gossipNodes[0] = neighbour
if len(neighbours) == 0 {
return false, nil
}
// Peer with 2 nodes so that there is redundnacy in
// the situation the node leaves pre-emptively
redundancyLength := min(len(neighbours), 2)
gossipNodes = neighbours[:redundancyLength]
} else {
neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
gossipNodes = lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
gossipNodes = lib.RandomSubsetOfLength(neighbours, s.configuration.Branch)
if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
if len(nodeNames) > s.configuration.ClusterSize && rand.Float64() < s.configuration.InterClusterChance {
gossipNodes[len(gossipNodes)-1] = s.cluster.GetInterCluster(nodeNames, publicKey.String())
}
}
var succeeded bool = false
// Do this synchronously to conserve bandwidth
for _, node := range gossipNodes {
correspondingPeer := s.manager.GetNode(meshId, node)
var wait sync.WaitGroup
if correspondingPeer == nil {
logging.Log.WriteErrorf("node %s does not exist", node)
for index, node := range gossipNodes {
wait.Add(1)
syncNode := func(i int) {
correspondingPeer, err := correspondingMesh.GetNode(node)
defer wait.Done()
if correspondingPeer == nil || err != nil {
logging.Log.WriteErrorf("node %s does not exist", node)
return
}
err = s.requester.SyncMesh(correspondingMesh, correspondingPeer)
if err == nil || err == io.EOF {
succeeded = true
}
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
}
err := s.requester.SyncMesh(meshId, correspondingPeer)
if err == nil || err == io.EOF {
succeeded = true
} else {
// If the synchronisation operation has failed them mark a gravestone
// preventing the peer from being re-contacted until it has updated
// itself
s.manager.GetMesh(meshId).Mark(node)
}
go syncNode(index)
}
s.syncCount++
logging.Log.WriteInfof("SYNC TIME: %v", time.Since(before))
logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount)
wait.Wait()
s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount)
s.syncCount++
logging.Log.WriteInfof("sync time: %v", time.Since(before))
logging.Log.WriteInfof("number of syncs: %d", s.syncCount)
s.infectionCount = ((s.configuration.InfectionCount + s.infectionCount - 1) % s.configuration.InfectionCount)
if !succeeded {
// If could not gossip with anyone then repeat.
s.infectionCount++
}
s.manager.GetMesh(meshId).SaveChanges()
s.lastSync[meshId] = uint64(time.Now().Unix())
changes := correspondingMesh.HasChanges()
correspondingMesh.SaveChanges()
logging.Log.WriteInfof("UPDATING WG CONF")
err := s.manager.ApplyConfig()
if err != nil {
logging.Log.WriteInfof("Failed to update config %w", err)
}
return nil
s.lastSyncLock.Lock()
s.lastSync[correspondingMesh.GetMeshId()] = time.Now().Unix()
s.lastSyncLock.Unlock()
return changes, nil
}
// Pull one node in the cluster, if there has not been message dissemination
// in a certain period of time pull a random node within the cluster
func (s *SyncerImpl) Pull(meshId string) error {
mesh := s.manager.GetMesh(meshId)
self, err := s.manager.GetSelf(meshId)
if err != nil {
return err
}
func (s *SyncerImpl) Pull(self mesh.MeshNode, mesh mesh.MeshProvider) (bool, error) {
peers := mesh.GetPeers()
pubKey, _ := self.GetPublicKey()
if mesh == nil {
return errors.New("mesh is nil, invalid operation")
}
peers := mesh.GetPeers()
neighbours := s.cluster.GetNeighbours(peers, pubKey.String())
neighbour := lib.RandomSubsetOfLength(neighbours, 1)
if len(neighbour) == 0 {
logging.Log.WriteInfof("no neighbours")
return nil
return false, nil
}
logging.Log.WriteInfof("PULLING from node %s", neighbour[0])
logging.Log.WriteInfof("pulling from node %s", neighbour[0])
pullNode, err := mesh.GetNode(neighbour[0])
if err != nil || pullNode == nil {
return fmt.Errorf("node %s does not exist in the mesh", neighbour[0])
return false, fmt.Errorf("node %s does not exist in the mesh", neighbour[0])
}
err = s.requester.SyncMesh(meshId, pullNode)
err = s.requester.SyncMesh(mesh, pullNode)
if err == nil || err == io.EOF {
s.lastSync[meshId] = uint64(time.Now().Unix())
s.lastSync[mesh.GetMeshId()] = time.Now().Unix()
} else {
return err
return false, err
}
s.syncCount++
return nil
changes := mesh.HasChanges()
return changes, nil
}
// SyncMeshes: Sync all meshes
func (s *SyncerImpl) SyncMeshes() error {
for meshId := range s.manager.GetMeshes() {
err := s.Sync(meshId)
var wg sync.WaitGroup
if err != nil {
logging.Log.WriteErrorf(err.Error())
meshes := s.meshManager.GetMeshes()
s.lastPollLock.Lock()
meshesToSync := lib.Filter(lib.MapValues(meshes), func(mesh mesh.MeshProvider) bool {
return time.Now().Unix()-s.lastPoll[mesh.GetMeshId()] >= int64(s.configuration.SyncInterval)
})
s.lastPollLock.Unlock()
changes := make(chan bool, len(meshesToSync))
for i := 0; i < len(meshesToSync); {
wg.Add(1)
sync := func(index int) {
defer wg.Done()
var hasChanges bool = false
mesh := meshesToSync[index]
hasChanges, err := s.Sync(mesh)
changes <- hasChanges
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
s.lastPollLock.Lock()
s.lastPoll[mesh.GetMeshId()] = time.Now().Unix()
s.lastPollLock.Unlock()
}
go sync(i)
i++
}
wg.Wait()
hasChanges := false
for i := 0; i < len(changes); i++ {
if <-changes {
hasChanges = true
}
}
return nil
var err error
if hasChanges {
logging.Log.WriteInfof("updating the WireGuard configuration")
err = s.meshManager.ApplyConfig()
if err != nil {
logging.Log.WriteErrorf("failed to update config %s", err.Error())
}
err = s.meshManager.GetRouteManager().UpdateRoutes()
if err != nil {
logging.Log.WriteErrorf("update routes failed %s", err.Error())
}
}
return err
}
func NewSyncer(m mesh.MeshManager, conf *conf.DaemonConfiguration, r SyncRequester) Syncer {
cluster, _ := conn.NewConnCluster(conf.ClusterSize)
type NewSyncerParams struct {
MeshManager mesh.MeshManager
ConnectionManager conn.ConnectionManager
Configuration *conf.DaemonConfiguration
Requester SyncRequester
}
func NewSyncer(params *NewSyncerParams) Syncer {
cluster, _ := conn.NewConnCluster(params.Configuration.ClusterSize)
syncRequester := NewSyncRequester(NewSyncRequesterParams{
MeshManager: params.MeshManager,
ConnectionManager: params.ConnectionManager,
Configuration: params.Configuration,
})
return &SyncerImpl{
manager: m,
conf: conf,
requester: r,
meshManager: params.MeshManager,
configuration: params.Configuration,
requester: syncRequester,
infectionCount: 0,
syncCount: 0,
cluster: cluster,
lastSync: make(map[string]uint64)}
lastSync: make(map[string]int64),
lastPoll: make(map[string]int64)}
}

View File

@ -1,41 +1,60 @@
package sync
import (
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/conn"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// SyncErrorHandler: Handles errors when attempting to sync
type SyncErrorHandler interface {
Handle(meshId string, endpoint string, err error) bool
Handle(mesh mesh.MeshProvider, endpoint string, err error) bool
}
// SyncErrorHandlerImpl Is an implementation of the SyncErrorHandler
type SyncErrorHandlerImpl struct {
meshManager mesh.MeshManager
connManager conn.ConnectionManager
}
func (s *SyncErrorHandlerImpl) handleFailed(meshId string, nodeId string) bool {
mesh := s.meshManager.GetMesh(meshId)
func (s *SyncErrorHandlerImpl) handleFailed(mesh mesh.MeshProvider, nodeId string) bool {
mesh.Mark(nodeId)
node, err := mesh.GetNode(nodeId)
if err != nil {
s.connManager.RemoveConnection(node.GetHostEndpoint())
}
return true
}
func (s *SyncErrorHandlerImpl) Handle(meshId string, nodeId string, err error) bool {
func (s *SyncErrorHandlerImpl) handleDeadlineExceeded(mesh mesh.MeshProvider, nodeId string) bool {
node, err := mesh.GetNode(nodeId)
if err != nil {
return false
}
s.connManager.RemoveConnection(node.GetHostEndpoint())
return true
}
func (s *SyncErrorHandlerImpl) Handle(mesh mesh.MeshProvider, nodeId string, err error) bool {
errStatus, _ := status.FromError(err)
logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message())
switch errStatus.Code() {
case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound:
return s.handleFailed(meshId, nodeId)
case codes.Unavailable, codes.Unknown, codes.Internal, codes.NotFound:
return s.handleFailed(mesh, nodeId)
case codes.DeadlineExceeded:
return s.handleDeadlineExceeded(mesh, nodeId)
}
return false
}
func NewSyncErrorHandler(m mesh.MeshManager) SyncErrorHandler {
return &SyncErrorHandlerImpl{meshManager: m}
func NewSyncErrorHandler(m mesh.MeshManager, conn conn.ConnectionManager) SyncErrorHandler {
return &SyncErrorHandlerImpl{meshManager: m, connManager: conn}
}

View File

@ -2,76 +2,44 @@ package sync
import (
"context"
"errors"
"io"
"time"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/rpc"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/conn"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/rpc"
)
// SyncRequester: coordinates the syncing of meshes
type SyncRequester interface {
GetMesh(meshId string, ifName string, port int, endPoint string) error
SyncMesh(meshid string, meshNode mesh.MeshNode) error
SyncMesh(mesh mesh.MeshProvider, meshNode mesh.MeshNode) error
}
type SyncRequesterImpl struct {
server *ctrlserver.MeshCtrlServer
errorHdlr SyncErrorHandler
manager mesh.MeshManager
connectionManager conn.ConnectionManager
configuration *conf.DaemonConfiguration
errorHdlr SyncErrorHandler
}
// GetMesh: Retrieves the local state of the mesh at the endpoint
func (s *SyncRequesterImpl) GetMesh(meshId string, ifName string, port int, endPoint string) error {
peerConnection, err := s.server.ConnectionManager.GetConnection(endPoint)
if err != nil {
return err
}
client, err := peerConnection.GetClient()
if err != nil {
return err
}
c := rpc.NewSyncServiceClient(client)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
reply, err := c.GetConf(ctx, &rpc.GetConfRequest{MeshId: meshId})
if err != nil {
return err
}
err = s.server.MeshManager.AddMesh(&mesh.AddMeshParams{
MeshId: meshId,
WgPort: port,
MeshBytes: reply.Mesh,
})
return err
}
func (s *SyncRequesterImpl) handleErr(meshId, pubKey string, err error) error {
ok := s.errorHdlr.Handle(meshId, pubKey, err)
// handleErr: handleGrpc errors
func (s *SyncRequesterImpl) handleErr(mesh mesh.MeshProvider, pubKey string, err error) error {
ok := s.errorHdlr.Handle(mesh, pubKey, err)
if ok {
return nil
}
return err
}
// SyncMesh: Proactively send a sync request to the other mesh
func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) error {
func (s *SyncRequesterImpl) SyncMesh(mesh mesh.MeshProvider, meshNode mesh.MeshNode) error {
endpoint := meshNode.GetHostEndpoint()
pubKey, _ := meshNode.GetPublicKey()
peerConnection, err := s.server.ConnectionManager.GetConnection(endpoint)
peerConnection, err := s.connectionManager.GetConnection(endpoint)
if err != nil {
return err
@ -83,15 +51,9 @@ func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) erro
return err
}
mesh := s.server.MeshManager.GetMesh(meshId)
if mesh == nil {
return errors.New("mesh does not exist")
}
c := rpc.NewSyncServiceClient(client)
syncTimeOut := float64(s.server.Conf.SyncRate) * float64(time.Second)
syncTimeOut := float64(s.configuration.SyncInterval) * float64(time.Second)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(syncTimeOut))
defer cancel()
@ -99,11 +61,11 @@ func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) erro
err = s.syncMesh(mesh, ctx, c)
if err != nil {
return s.handleErr(meshId, pubKey.String(), err)
s.handleErr(mesh, pubKey.String(), err)
}
logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId)
return nil
logging.Log.WriteInfof("synced with node: %s meshId: %s\n", endpoint, mesh.GetMeshId())
return err
}
func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context, client rpc.SyncServiceClient) error {
@ -127,7 +89,7 @@ func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context
in, err := stream.Recv()
if err != nil && err != io.EOF {
logging.Log.WriteInfof("Stream recv error: %s\n", err.Error())
logging.Log.WriteInfof("stream recv error: %s\n", err.Error())
return err
}
@ -136,7 +98,7 @@ func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context
}
if err != nil {
logging.Log.WriteInfof("Syncer recv error: %s\n", err.Error())
logging.Log.WriteInfof("syncer recv error: %s\n", err.Error())
return err
}
@ -150,7 +112,17 @@ func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context
return nil
}
func NewSyncRequester(s *ctrlserver.MeshCtrlServer) SyncRequester {
errorHdlr := NewSyncErrorHandler(s.MeshManager)
return &SyncRequesterImpl{server: s, errorHdlr: errorHdlr}
type NewSyncRequesterParams struct {
MeshManager mesh.MeshManager
ConnectionManager conn.ConnectionManager
Configuration *conf.DaemonConfiguration
}
func NewSyncRequester(params NewSyncRequesterParams) SyncRequester {
errorHdlr := NewSyncErrorHandler(params.MeshManager, params.ConnectionManager)
return &SyncRequesterImpl{manager: params.MeshManager,
connectionManager: params.ConnectionManager,
configuration: params.Configuration,
errorHdlr: errorHdlr,
}
}

View File

@ -1,18 +0,0 @@
package sync
import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
// Run implements SyncScheduler.
func syncFunction(syncer Syncer) lib.TimerFunc {
return func() error {
syncer.SyncMeshes()
return nil
}
}
func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester, syncer Syncer) *lib.Timer {
return lib.NewTimer(syncFunction(syncer), s.Conf.SyncRate)
}

View File

@ -6,19 +6,18 @@ import (
"errors"
"io"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/rpc"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/rpc"
)
type SyncServiceImpl struct {
rpc.UnimplementedSyncServiceServer
Server *ctrlserver.MeshCtrlServer
MeshManager mesh.MeshManager
}
// GetMesh: Gets a nodes local mesh configuration as a CRDT
func (s *SyncServiceImpl) GetConf(context context.Context, request *rpc.GetConfRequest) (*rpc.GetConfReply, error) {
mesh := s.Server.MeshManager.GetMesh(request.MeshId)
mesh := s.MeshManager.GetMesh(request.MeshId)
if mesh == nil {
return nil, errors.New("mesh does not exist")
@ -56,7 +55,7 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error
if len(meshId) == 0 {
meshId = in.MeshId
mesh := s.Server.MeshManager.GetMesh(meshId)
mesh := s.MeshManager.GetMesh(meshId)
if mesh == nil {
return errors.New("mesh does not exist")
@ -92,7 +91,3 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error
}
}
}
func NewSyncService(server *ctrlserver.MeshCtrlServer) *SyncServiceImpl {
return &SyncServiceImpl{Server: server}
}

View File

@ -1,15 +0,0 @@
package timer
import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
)
func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
timerFunc := func() error {
logging.Log.WriteInfof("Updated Timestamp")
return ctrlServer.MeshManager.UpdateTimeStamp()
}
return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime)
}

View File

@ -5,8 +5,8 @@ import (
"crypto/rand"
"fmt"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/smegmesh/pkg/log"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)