1
0
forked from extern/smegmesh

Compare commits

...

59 Commits

Author SHA1 Message Date
9a30f4d5cb Submitting 2024-01-05 18:22:05 +00:00
f647c1b806 81-seperate-processes
Prep for submission
2024-01-05 16:59:02 +00:00
a55dadf088 81-seperate-synchronisation-into-independent-procs
- Neaten code
2024-01-05 12:59:13 +00:00
0ec5156e59 81-procs
- fixed issue where route not deleting if mesh only one
2024-01-05 00:14:25 +00:00
2b73d241b6 81-serparate-procs
- nil dereference again
2024-01-04 22:29:30 +00:00
69b1790bb6 81-processes
- issue with client client traversal
2024-01-04 22:08:14 +00:00
4a92743880 81-seperate-sync
- build error
2024-01-04 21:48:54 +00:00
038393052c 81-seperate-synchronisation-into-independent-proc
- build error
2024-01-04 21:47:29 +00:00
5efff2314b 81-separate-synchronisation-into-independent-process
- nil dereference when no joins
2024-01-04 21:45:28 +00:00
1f8d229076 81-seperate-synchronisation-into-independent-process
- nil dereference due to concurrency issues (the method shouldn't be
  concurrent)
2024-01-04 21:16:33 +00:00
a0e7a4a644 81-seperateprocesses-into-independent-processes
- Fixed errors
2024-01-04 13:15:29 +00:00
f9b8b85ec3 81-seperate-synchronisation
- Removed authentication.proto
2024-01-04 13:12:33 +00:00
59d8ae4334 81-seperate-synchronisation
- More code comments
2024-01-04 13:12:07 +00:00
02dfd73e08 81-seperate-synchronisation-into-independent
- Separated synchronisation calls into independent processes
- Commented code for submission
2024-01-04 13:10:08 +00:00
9818645299 Merge pull request #82 from tim-beatham/bugfix-node-not-leving
bugfix-node-not-leaving
2024-01-04 00:24:58 +00:00
1f0914e2df bugfix-node-not-leaving
- Add lock when perform synchronisation on concurrent access
2024-01-04 00:23:20 +00:00
efb40d65de Merge pull request #80 from tim-beatham/bugfix-node-not-leving
main
2024-01-02 20:32:09 +00:00
27e00196cd main
- Not waiting in the waitgroup
2024-01-02 20:31:24 +00:00
4543205703 Merge pull request #79 from tim-beatham/bugfix-node-not-leving
main
2024-01-02 20:21:27 +00:00
dea6f1a22d main
- error in code invalid check for nil
2024-01-02 20:19:34 +00:00
4d19da6727 Merge pull request #78 from tim-beatham/bugfix-node-not-leving
main
2024-01-02 20:12:10 +00:00
913de57568 main
- Fixed bug
2024-01-02 20:11:11 +00:00
8a5673e303 Merge pull request #77 from tim-beatham/bugfix-node-not-leving
bugfix node not leaving
2024-01-02 19:43:04 +00:00
ce829114b1 bugfix
- on synchornisation node is not leaving mesh
2024-01-02 19:41:20 +00:00
05cc287e31 Merge pull request #76 from tim-beatham/74-perform-dad
- Fixing DNS error
2024-01-02 00:16:45 +00:00
cd844ff46e - Fixing DNS error 2024-01-02 00:15:23 +00:00
4b9406a920 Merge pull request #75 from tim-beatham/74-perform-dad
74-perform-dad
2024-01-02 00:14:37 +00:00
d0b1913796 74-perform-dad
- Fixing nil pointer dereference
2024-01-02 00:13:04 +00:00
90cfe820d2 - Fixing errors with stale paths 2024-01-02 00:09:31 +00:00
8a49809855 74-perform-dad
- Adding go.sum to fix errors
2024-01-01 23:59:04 +00:00
dbc18bddc6 74-perform-dad
- Performing DAD to check if IPv6 address present before adding
  outselves to mesh
- Changing name from wgmesh to smegmesh
2024-01-01 23:55:50 +00:00
14f335af74 Merge pull request #73 from tim-beatham/72-pull-rate-in-configuration
72 pull rate in configuration
2023-12-31 14:26:34 +00:00
36e82dba47 72-pull-rate-in-configuration
- Refactored pull rate into the configuration
- code freeze so no more code changes
2023-12-31 14:25:06 +00:00
3cc87bc252 72-pull-rate-in-configuration
- Updated examples
2023-12-31 12:47:45 +00:00
a9ed7c0a20 72-pull-rate-in-configuration
- Removing libp2p reference
2023-12-31 12:47:45 +00:00
fd29af73e3 72-pull-rate-in-configuration
- Added pull rate to configuration (finally) so this can
be modified by an administrator.
2023-12-31 12:47:45 +00:00
9e1058e0f2 72-pull-rate-in-configuration
- Added the pull rate to the configuration file
2023-12-31 12:47:45 +00:00
c29eb197f3 Merge pull request #71 from tim-beatham/66-ipv6-address-not-conforming-to-spec
66 ipv6 address not conforming to spec
2023-12-30 22:26:53 +00:00
1a9d9d61ad 66-ipv6-address-not-conforming-to-spec
- Missing commit
2023-12-30 22:26:08 +00:00
6954608c32 66-ipv6-address-not-confirming-to-spec
- UUID is not random just a name generator needs changing to shortuuid
- When in multiple meshes there is no wait group
2023-12-30 22:24:43 +00:00
2e6aed6f93 main
- Fixing issue with nil pointer de-reference due to bad design of mesh
  manager.
- Going forward all references to GetSelf should be depracated. It
  introduces a race condition when leaving a mesh network
2023-12-30 00:44:57 +00:00
b0893a0b8e Merge pull request #69 from tim-beatham/60-unit-test-crdt-data-store
60-unit-test-crdt-data-store
2023-12-29 22:06:20 +00:00
e7d6055fa3 60-unit-test-crdt-data-store
Provided unit tests for datastore.go
And fixed unit tets failing by different way of providing CA
2023-12-29 22:05:05 +00:00
e0f3f116b9 main
- Stale serverConfig entry causing certificate authorities
to not become authorised
2023-12-29 19:54:08 +00:00
352648b7cb main
- Fixed problem where connection not removed on error
2023-12-29 11:12:40 +00:00
2d5df25b1d main
- If deadline exceeded error remove connection from
connection manager
2023-12-29 01:29:11 +00:00
cabe173831 main
Adding retry parameter
2023-12-29 01:10:26 +00:00
d2c8a52ec6 main
- Adding retry policy for mobility
2023-12-29 00:58:43 +00:00
bf53108384 main
- Bugfix, fix consistent hash problem where
if failure happens then causes panic
2023-12-28 23:24:38 +00:00
77aac5534b main
- Bugfix in client where "-" was attempted to be parsed as a UDP addr
2023-12-28 17:46:04 +00:00
58439fcd56 main
- Bugfix when keepalivewg is not set causes segmentation fault
- give keepalive a default value of 0 if not set
2023-12-28 17:32:54 +00:00
311a15363a Merge pull request #67 from tim-beatham/66-improve-graph-dot-tool
66 improve graph dot tool
2023-12-25 01:26:15 +00:00
255d3c8b39 66-improve-graph-dot-tool
- Showing services a node provides
- Showing all meshes not just one
- Showing the default route
2023-12-25 01:25:20 +00:00
41899c5831 66-improve-graph-dot-tool
Improving the graph dot tool so that it shows all
meshes
2023-12-25 01:10:11 +00:00
fe4ca66ff6 Merge pull request #65 from tim-beatham/64-2p-set-unit-test
64 2p set unit test
2023-12-22 23:58:59 +00:00
0b91ba744a 61-improve-unit-test-coverage
- Provided unit tests for g_map and 2p_map
2023-12-22 23:57:10 +00:00
67483c2a90 64-unit-test-two-phase-set
Provide unit tests for two phase set to make it more
transparent what exactly they are doing.
2023-12-22 23:57:10 +00:00
af26e81bd3 Merge pull request #63 from tim-beatham/61-improve-unit-testing-coverage
61-improve-unit-testing-coverage
2023-12-22 21:52:46 +00:00
186acbe915 Merge pull request #62 from tim-beatham/61-improve-unit-testing-coverage
61-improve-unit-testing-coverage
2023-12-22 21:49:06 +00:00
80 changed files with 3564 additions and 1698 deletions

View File

@ -9,4 +9,4 @@ RUN apt-get update && apt-get install -y \
vim vim
WORKDIR /wgmesh WORKDIR /wgmesh
RUN go mod tidy RUN go mod tidy
RUN go build -o /usr/local/bin ./... RUN go build -o /usr/local/bin ./...

View File

@ -3,7 +3,7 @@ package main
import ( import (
"log" "log"
"github.com/tim-beatham/wgmesh/pkg/api" "github.com/tim-beatham/smegmesh/pkg/api"
) )
func main() { func main() {

View File

@ -3,7 +3,7 @@ package main
import ( import (
"log" "log"
smegdns "github.com/tim-beatham/wgmesh/pkg/dns" smegdns "github.com/tim-beatham/smegmesh/pkg/dns"
) )
func main() { func main() {

View File

@ -6,8 +6,10 @@ import (
"os" "os"
"github.com/akamensky/argparse" "github.com/akamensky/argparse"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
logging "github.com/tim-beatham/wgmesh/pkg/log" graph "github.com/tim-beatham/smegmesh/pkg/dot"
"github.com/tim-beatham/smegmesh/pkg/ipc"
logging "github.com/tim-beatham/smegmesh/pkg/log"
) )
const SockAddr = "/tmp/wgmesh_ipc.sock" const SockAddr = "/tmp/wgmesh_ipc.sock"
@ -20,25 +22,22 @@ type CreateMeshParams struct {
AdvertiseDefault bool AdvertiseDefault bool
} }
func createMesh(params *CreateMeshParams) string { func createMesh(client *ipc.SmegmeshIpc, args *ipc.NewMeshArgs) {
var reply string var reply string
newMeshParams := ipc.NewMeshArgs{ err := client.CreateMesh(args, &reply)
WgArgs: params.WgArgs,
}
err := params.Client.Call("IpcHandler.CreateMesh", &newMeshParams, &reply)
if err != nil { if err != nil {
return err.Error() fmt.Println(err.Error())
return
} }
return reply fmt.Println(reply)
} }
func listMeshes(client *ipcRpc.Client) { func listMeshes(client *ipc.SmegmeshIpc) {
reply := new(ipc.ListMeshReply) reply := new(ipc.ListMeshReply)
err := client.Call("IpcHandler.ListMeshes", "", &reply) err := client.ListMeshes(reply)
if err != nil { if err != nil {
logging.Log.WriteErrorf(err.Error()) logging.Log.WriteErrorf(err.Error())
@ -50,38 +49,22 @@ func listMeshes(client *ipcRpc.Client) {
} }
} }
type JoinMeshParams struct { func joinMesh(client *ipc.SmegmeshIpc, args ipc.JoinMeshArgs) {
Client *ipcRpc.Client
MeshId string
IpAddress string
Endpoint string
WgArgs ipc.WireGuardArgs
AdvertiseRoutes bool
AdvertiseDefault bool
}
func joinMesh(params *JoinMeshParams) string {
var reply string var reply string
args := ipc.JoinMeshArgs{ err := client.JoinMesh(args, &reply)
MeshId: params.MeshId,
IpAdress: params.IpAddress,
WgArgs: params.WgArgs,
}
err := params.Client.Call("IpcHandler.JoinMesh", &args, &reply)
if err != nil { if err != nil {
return err.Error() fmt.Println(err.Error())
} }
return reply fmt.Println(reply)
} }
func leaveMesh(client *ipcRpc.Client, meshId string) { func leaveMesh(client *ipc.SmegmeshIpc, meshId string) {
var reply string var reply string
err := client.Call("IpcHandler.LeaveMesh", &meshId, &reply) err := client.LeaveMesh(meshId, &reply)
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
@ -91,10 +74,51 @@ func leaveMesh(client *ipcRpc.Client, meshId string) {
fmt.Println(reply) fmt.Println(reply)
} }
func getGraph(client *ipcRpc.Client, meshId string) { func getGraph(client *ipc.SmegmeshIpc) {
listMeshesReply := new(ipc.ListMeshReply)
err := client.ListMeshes(listMeshesReply)
if err != nil {
fmt.Println(err.Error())
return
}
meshes := make(map[string][]ctrlserver.MeshNode)
for _, meshId := range listMeshesReply.Meshes {
var meshReply ipc.GetMeshReply
err := client.GetMesh(meshId, &meshReply)
if err != nil {
fmt.Println(err.Error())
return
}
meshes[meshId] = meshReply.Nodes
}
dotGenerator := graph.NewMeshGraphConverter(meshes)
dot, err := dotGenerator.Generate()
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(dot)
}
func queryMesh(client *ipc.SmegmeshIpc, meshId, query string) {
var reply string var reply string
err := client.Call("IpcHandler.GetDOT", &meshId, &reply) args := ipc.QueryMesh{
MeshId: meshId,
Query: query,
}
err := client.Query(args, &reply)
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
@ -104,24 +128,13 @@ func getGraph(client *ipcRpc.Client, meshId string) {
fmt.Println(reply) fmt.Println(reply)
} }
func queryMesh(client *ipcRpc.Client, meshId, query string) { func putDescription(client *ipc.SmegmeshIpc, meshId, description string) {
var reply string var reply string
err := client.Call("IpcHandler.Query", &ipc.QueryMesh{MeshId: meshId, Query: query}, &reply) err := client.PutDescription(ipc.PutDescriptionArgs{
MeshId: meshId,
if err != nil { Description: description,
fmt.Println(err.Error()) }, &reply)
return
}
fmt.Println(reply)
}
// putDescription: puts updates the description about the node to the meshes
func putDescription(client *ipcRpc.Client, description string) {
var reply string
err := client.Call("IpcHandler.PutDescription", &description, &reply)
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
@ -132,10 +145,13 @@ func putDescription(client *ipcRpc.Client, description string) {
} }
// putAlias: puts an alias for the node // putAlias: puts an alias for the node
func putAlias(client *ipcRpc.Client, alias string) { func putAlias(client *ipc.SmegmeshIpc, meshid, alias string) {
var reply string var reply string
err := client.Call("IpcHandler.PutAlias", &alias, &reply) err := client.PutAlias(ipc.PutAliasArgs{
MeshId: meshid,
Alias: alias,
}, &reply)
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
@ -145,15 +161,14 @@ func putAlias(client *ipcRpc.Client, alias string) {
fmt.Println(reply) fmt.Println(reply)
} }
func setService(client *ipcRpc.Client, service, value string) { func setService(client *ipc.SmegmeshIpc, meshId, service, value string) {
var reply string var reply string
serviceArgs := &ipc.PutServiceArgs{ err := client.PutService(ipc.PutServiceArgs{
MeshId: meshId,
Service: service, Service: service,
Value: value, Value: value,
} }, &reply)
err := client.Call("IpcHandler.PutService", serviceArgs, &reply)
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
@ -163,10 +178,13 @@ func setService(client *ipcRpc.Client, service, value string) {
fmt.Println(reply) fmt.Println(reply)
} }
func deleteService(client *ipcRpc.Client, service string) { func deleteService(client *ipc.SmegmeshIpc, meshId, service string) {
var reply string var reply string
err := client.Call("IpcHandler.PutService", &service, &reply) err := client.DeleteService(ipc.DeleteServiceArgs{
MeshId: meshId,
Service: service,
}, &reply)
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
@ -177,8 +195,8 @@ func deleteService(client *ipcRpc.Client, service string) {
} }
func main() { func main() {
parser := argparse.NewParser("wg-mesh", parser := argparse.NewParser("smgctl",
"wg-mesh Manipulate WireGuard mesh networks") "smegctl Manipulate WireGuard mesh networks")
newMeshCmd := parser.NewCommand("new-mesh", "Create a new mesh") newMeshCmd := parser.NewCommand("new-mesh", "Create a new mesh")
listMeshCmd := parser.NewCommand("list-meshes", "List meshes the node is connected to") listMeshCmd := parser.NewCommand("list-meshes", "List meshes the node is connected to")
@ -201,7 +219,6 @@ func main() {
}) })
var newMeshRole *string = newMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{ var newMeshRole *string = newMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{
Default: "peer",
Help: "Role in the mesh network. A value of peer means that the node is publicly routeable and thus considered" + Help: "Role in the mesh network. A value of peer means that the node is publicly routeable and thus considered" +
" in the gossip protocol. Client means that the node is not publicly routeable and is not a candidate in the gossip" + " in the gossip protocol. Client means that the node is not publicly routeable and is not a candidate in the gossip" +
" protocol", " protocol",
@ -234,7 +251,6 @@ func main() {
}) })
var joinMeshRole *string = joinMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{ var joinMeshRole *string = joinMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{
Default: "peer",
Help: "Role in the mesh network. A value of peer means that the node is publicly routeable and thus considered" + Help: "Role in the mesh network. A value of peer means that the node is publicly routeable and thus considered" +
" in the gossip protocol. Client means that the node is not publicly routeable and is not a candidate in the gossip" + " in the gossip protocol. Client means that the node is not publicly routeable and is not a candidate in the gossip" +
" protocol", " protocol",
@ -258,11 +274,6 @@ func main() {
Help: "Advertise ::/0 into the mesh network", Help: "Advertise ::/0 into the mesh network",
}) })
var getGraphMeshId *string = getGraphCmd.String("m", "mesh", &argparse.Options{
Required: true,
Help: "MeshID of the graph to get",
})
var leaveMeshMeshId *string = leaveMeshCmd.String("m", "mesh", &argparse.Options{ var leaveMeshMeshId *string = leaveMeshCmd.String("m", "mesh", &argparse.Options{
Required: true, Required: true,
Help: "MeshID of the mesh to leave", Help: "MeshID of the mesh to leave",
@ -282,6 +293,16 @@ func main() {
Help: "Description of the node in the mesh", Help: "Description of the node in the mesh",
}) })
var descriptionMeshId *string = putDescriptionCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
var aliasMeshId *string = putAliasCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
var alias *string = putAliasCmd.String("a", "alias", &argparse.Options{ var alias *string = putAliasCmd.String("a", "alias", &argparse.Options{
Required: true, Required: true,
Help: "Alias of the node to set can be used in DNS to lookup an IP address", Help: "Alias of the node to set can be used in DNS to lookup an IP address",
@ -296,11 +317,21 @@ func main() {
Help: "Value of the service to advertise in the mesh network", Help: "Value of the service to advertise in the mesh network",
}) })
var serviceMeshId *string = setServiceCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{ var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{
Required: true, Required: true,
Help: "Key of the service to remove", Help: "Key of the service to remove",
}) })
var deleteServiceMeshid *string = deleteServiceCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
err := parser.Parse(os.Args) err := parser.Parse(os.Args)
if err != nil { if err != nil {
@ -308,16 +339,13 @@ func main() {
return return
} }
client, err := ipcRpc.DialHTTP("unix", SockAddr) client, err := ipc.NewClientIpc()
if err != nil { if err != nil {
fmt.Println(err.Error()) panic(err)
return
} }
if newMeshCmd.Happened() { if newMeshCmd.Happened() {
fmt.Println(createMesh(&CreateMeshParams{ args := &ipc.NewMeshArgs{
Client: client,
Endpoint: *newMeshEndpoint,
WgArgs: ipc.WireGuardArgs{ WgArgs: ipc.WireGuardArgs{
Endpoint: *newMeshEndpoint, Endpoint: *newMeshEndpoint,
Role: *newMeshRole, Role: *newMeshRole,
@ -326,7 +354,9 @@ func main() {
AdvertiseDefaultRoute: *newMeshAdvertiseDefaults, AdvertiseDefaultRoute: *newMeshAdvertiseDefaults,
AdvertiseRoutes: *newMeshAdvertiseRoutes, AdvertiseRoutes: *newMeshAdvertiseRoutes,
}, },
})) }
createMesh(client, args)
} }
if listMeshCmd.Happened() { if listMeshCmd.Happened() {
@ -334,11 +364,9 @@ func main() {
} }
if joinMeshCmd.Happened() { if joinMeshCmd.Happened() {
fmt.Println(joinMesh(&JoinMeshParams{ args := ipc.JoinMeshArgs{
Client: client,
IpAddress: *joinMeshIpAddress, IpAddress: *joinMeshIpAddress,
MeshId: *joinMeshId, MeshId: *joinMeshId,
Endpoint: *joinMeshEndpoint,
WgArgs: ipc.WireGuardArgs{ WgArgs: ipc.WireGuardArgs{
Endpoint: *joinMeshEndpoint, Endpoint: *joinMeshEndpoint,
Role: *joinMeshRole, Role: *joinMeshRole,
@ -347,11 +375,12 @@ func main() {
AdvertiseDefaultRoute: *joinMeshAdvertiseDefaults, AdvertiseDefaultRoute: *joinMeshAdvertiseDefaults,
AdvertiseRoutes: *joinMeshAdvertiseRoutes, AdvertiseRoutes: *joinMeshAdvertiseRoutes,
}, },
})) }
joinMesh(client, args)
} }
if getGraphCmd.Happened() { if getGraphCmd.Happened() {
getGraph(client, *getGraphMeshId) getGraph(client)
} }
if leaveMeshCmd.Happened() { if leaveMeshCmd.Happened() {
@ -363,18 +392,18 @@ func main() {
} }
if putDescriptionCmd.Happened() { if putDescriptionCmd.Happened() {
putDescription(client, *description) putDescription(client, *descriptionMeshId, *description)
} }
if putAliasCmd.Happened() { if putAliasCmd.Happened() {
putAlias(client, *alias) putAlias(client, *aliasMeshId, *alias)
} }
if setServiceCmd.Happened() { if setServiceCmd.Happened() {
setService(client, *serviceKey, *serviceValue) setService(client, *serviceMeshId, *serviceKey, *serviceValue)
} }
if deleteServiceCmd.Happened() { if deleteServiceCmd.Happened() {
deleteService(client, *deleteServiceKey) deleteService(client, *deleteServiceMeshid, *deleteServiceKey)
} }
} }

View File

@ -6,29 +6,30 @@ import (
"os" "os"
"os/signal" "os/signal"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
ctrlserver "github.com/tim-beatham/wgmesh/pkg/ctrlserver" ctrlserver "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/smegmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/robin"
"github.com/tim-beatham/wgmesh/pkg/robin" "github.com/tim-beatham/smegmesh/pkg/sync"
"github.com/tim-beatham/wgmesh/pkg/sync"
timer "github.com/tim-beatham/wgmesh/pkg/timers"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
) )
func main() { func main() {
if len(os.Args) != 2 { if len(os.Args) != 2 {
logging.Log.WriteErrorf("Did not provide configuration") logging.Log.WriteErrorf("Did not provide configuration")
return return
} }
conf, err := conf.ParseDaemonConfiguration(os.Args[1]) configuration, err := conf.ParseDaemonConfiguration(os.Args[1])
if err != nil { if err != nil {
logging.Log.WriteErrorf("Could not parse configuration: %s", err.Error()) logging.Log.WriteErrorf("Could not parse configuration: %s", err.Error())
return return
} }
logging.SetLogger(logging.NewLogrusLogger(configuration.LogLevel))
client, err := wgctrl.New() client, err := wgctrl.New()
if err != nil { if err != nil {
@ -36,7 +37,7 @@ func main() {
return return
} }
if conf.Profile { if configuration.Profile {
go func() { go func() {
http.ListenAndServe("localhost:6060", nil) http.ListenAndServe("localhost:6060", nil)
}() }()
@ -45,25 +46,21 @@ func main() {
var robinRpc robin.WgRpc var robinRpc robin.WgRpc
var robinIpc robin.IpcHandler var robinIpc robin.IpcHandler
var syncProvider sync.SyncServiceImpl var syncProvider sync.SyncServiceImpl
var syncRequester sync.SyncRequester
var syncer sync.Syncer
ctrlServerParams := ctrlserver.NewCtrlServerParams{ ctrlServerParams := ctrlserver.NewCtrlServerParams{
Conf: conf, Conf: configuration,
CtrlProvider: &robinRpc, CtrlProvider: &robinRpc,
SyncProvider: &syncProvider, SyncProvider: &syncProvider,
Client: client, Client: client,
OnDelete: func(mp mesh.MeshProvider) {
syncer.SyncMeshes()
},
} }
ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams) ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams)
syncProvider.Server = ctrlServer
syncRequester = sync.NewSyncRequester(ctrlServer) if err != nil {
syncer = sync.NewSyncer(ctrlServer.MeshManager, conf, syncRequester) panic(err)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, syncer) }
keepAlive := timer.NewTimestampScheduler(ctrlServer)
syncProvider.MeshManager = ctrlServer.MeshManager
robinIpcParams := robin.RobinIpcParams{ robinIpcParams := robin.RobinIpcParams{
CtrlServer: ctrlServer, CtrlServer: ctrlServer,
@ -77,16 +74,11 @@ func main() {
return return
} }
logging.Log.WriteInfof("Running IPC Handler") logging.Log.WriteInfof("running ipc handler")
go ipc.RunIpcHandler(&robinIpc) go ipc.RunIpcHandler(&robinIpc)
go syncScheduler.Run()
go keepAlive.Run()
closeResources := func() { closeResources := func() {
logging.Log.WriteInfof("Closing resources") logging.Log.WriteInfof("closing resources")
syncScheduler.Stop()
keepAlive.Stop()
ctrlServer.Close() ctrlServer.Close()
client.Close() client.Close()
} }

View File

@ -10,5 +10,5 @@ syncRate: 1
interClusterChance: 0.15 interClusterChance: 0.15
branchRate: 3 branchRate: 3
infectionCount: 3 infectionCount: 3
keepAliveTime: 10 heartBeatTime: 10
pruneTime: 20 pruneTime: 20

View File

@ -1,95 +0,0 @@
version: '3'
networks:
net-1:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.0.0/17
net-2:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.155.0/17
services:
wg-1:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-2:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-3:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-4:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
sysctls:
- net.ipv6.conf.all.forwarding=1
networks:
- net-1
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-5:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-6:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-7:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"

View File

@ -1,14 +0,0 @@
certificatePath: "/wgmesh/cert/cert.pem"
privateKeyPath: "/wgmesh/cert/priv.pem"
caCertificatePath: "/wgmesh/cert/cacert.pem"
skipCertVerification: true
timeout: 5
gRPCPort: "21906"
advertiseRoutes: true
clusterSize: 32
syncRate: 1
interClusterChance: 0.15
branchRate: 3
infectionCount: 3
keepAliveTime: 10
pruneTime: 20

View File

@ -1,14 +1,9 @@
version: '3' version: '3'
networks: networks:
net-1: net-1:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.0.0/17
services: services:
wg-1: wg-1:
image: wg-mesh-base:latest image: localhost/smegmesh-base:latest
cap_add: cap_add:
- NET_ADMIN - NET_ADMIN
- NET_RAW - NET_RAW
@ -17,9 +12,11 @@ services:
- net-1 - net-1
volumes: volumes:
- ./shared:/shared - ./shared:/shared
command: "wgmeshd /shared/configuration.yaml" command: "smegd /shared/configuration.yaml"
sysctls:
- net.ipv6.conf.all.forwarding=1
wg-2: wg-2:
image: wg-mesh-base:latest image: localhost/smegmesh-base:latest
cap_add: cap_add:
- NET_ADMIN - NET_ADMIN
- NET_RAW - NET_RAW
@ -28,9 +25,11 @@ services:
- net-1 - net-1
volumes: volumes:
- ./shared:/shared - ./shared:/shared
command: "wgmeshd /shared/configuration.yaml" command: "smegd /shared/configuration.yaml"
sysctls:
- net.ipv6.conf.all.forwarding=1
wg-3: wg-3:
image: wg-mesh-base:latest image: localhost/smegmesh-base:latest
cap_add: cap_add:
- NET_ADMIN - NET_ADMIN
- NET_RAW - NET_RAW
@ -39,4 +38,6 @@ services:
- net-1 - net-1
volumes: volumes:
- ./shared:/shared - ./shared:/shared
command: "wgmeshd /shared/configuration.yaml" command: "smegd /shared/configuration.yaml"
sysctls:
- net.ipv6.conf.all.forwarding=1

View File

@ -1,14 +1,34 @@
certificatePath: "/wgmesh/cert/cert.pem" # Paths to the certificates modify
privateKeyPath: "/wgmesh/cert/priv.pem" # if not running from Smegmesh
caCertificatePath: "/wgmesh/cert/cacert.pem" certificatePath: "./cert/cert.pem"
privateKeyPath: "./cert/priv.pem"
caCertificatePath: "./cert/cacert.pem"
skipCertVerification: true skipCertVerification: true
# timeout is the configured grpc timeout
timeout: 5 timeout: 5
gRPCPort: "21906" # gRPC port to run the solution
advertiseRoutes: true gRPCPort: 4000
clusterSize: 32 # whether or not to run go profiler
syncRate: 1 profile: false
interClusterChance: 0.15 # stubWg: whether to install WireGuard configurations
branchRate: 3 # if true just tests the control plane
stubWg: false
heartbeatInterval: 60
branch: 3
pullInterval: 20
infectionCount: 3 infectionCount: 3
keepAliveTime: 10 interClusterChance: 0.15
pruneTime: 20 syncInterval: 2
clusterSize: 64
logLevel: "info"
baseConfiguration:
# ipDiscovery: specifies how to find your IP address
ipDiscovery: "outgoing"
# alternative to ipDiscovery specify an actual endpoint yourself with publicEndpoint: "xxxx"
# role is the role that you are playing (peer | client)
# peers can only bootstrap meshes
role: "peer"
# advertise meshes to other meshes
advertiseRoute: true
# advertise default routes
advertiseDefaults: true

4
go.mod
View File

@ -1,4 +1,4 @@
module github.com/tim-beatham/wgmesh module github.com/tim-beatham/smegmesh
go 1.21.3 go 1.21.3
@ -11,11 +11,11 @@ require (
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/jmespath/go-jmespath v0.4.0 github.com/jmespath/go-jmespath v0.4.0
github.com/jsimonetti/rtnetlink v1.3.5 github.com/jsimonetti/rtnetlink v1.3.5
github.com/lithammer/shortuuid v3.0.0+incompatible
github.com/miekg/dns v1.1.57 github.com/miekg/dns v1.1.57
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
golang.org/x/sys v0.14.0 golang.org/x/sys v0.14.0
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
gonum.org/v1/gonum v0.14.0
google.golang.org/grpc v1.58.1 google.golang.org/grpc v1.58.1
google.golang.org/protobuf v1.31.0 google.golang.org/protobuf v1.31.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1

6
go.sum
View File

@ -27,8 +27,6 @@ github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/o
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE= github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE=
github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
@ -57,6 +55,8 @@ github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZX
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/lithammer/shortuuid v3.0.0+incompatible h1:NcD0xWW/MZYXEHa6ITy6kaXN5nwm/V115vj2YXfhS0w=
github.com/lithammer/shortuuid v3.0.0+incompatible/go.mod h1:FR74pbAuElzOUuenUHTK2Tciko1/vKuIKS9dSkDrA4w=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
@ -123,8 +123,6 @@ golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 h1:EY138uSo1JYlDq+
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4= golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0=
gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 h1:bVf09lpb+OJbByTj913DRJioFFAjf/ZGxEz7MajTp2U= google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 h1:bVf09lpb+OJbByTj913DRJioFFAjf/ZGxEz7MajTp2U=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98/go.mod h1:TUfxEVdsvPg18p6AslUXFoLdpED4oBnGwyqk3dV1XzM= google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98/go.mod h1:TUfxEVdsvPg18p6AslUXFoLdpED4oBnGwyqk3dV1XzM=
google.golang.org/grpc v1.58.1 h1:OL+Vz23DTtrrldqHK49FUOPHyY75rvFqJfXC84NYW58= google.golang.org/grpc v1.58.1 h1:OL+Vz23DTtrrldqHK49FUOPHyY75rvFqJfXC84NYW58=

View File

@ -4,28 +4,14 @@ import (
"fmt" "fmt"
"net/http" "net/http"
ipcRpc "net/rpc"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/smegmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/what8words" "github.com/tim-beatham/smegmesh/pkg/what8words"
) )
const SockAddr = "/tmp/wgmesh_ipc.sock" // routesToApiRoute: convert the returned type to a JSON object
type ApiServer interface {
GetMeshes(c *gin.Context)
Run(addr string) error
}
type SmegServer struct {
router *gin.Engine
client *ipcRpc.Client
words *what8words.What8Words
}
func (s *SmegServer) routeToApiRoute(meshNode ctrlserver.MeshNode) []Route { func (s *SmegServer) routeToApiRoute(meshNode ctrlserver.MeshNode) []Route {
routes := make([]Route, len(meshNode.Routes)) routes := make([]Route, len(meshNode.Routes))
@ -44,6 +30,7 @@ func (s *SmegServer) routeToApiRoute(meshNode ctrlserver.MeshNode) []Route {
return routes return routes
} }
// meshNodeToAPImeshNode: convert daemon node to a JSON node
func (s *SmegServer) meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNode { func (s *SmegServer) meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNode {
if meshNode.Routes == nil { if meshNode.Routes == nil {
meshNode.Routes = make([]ctrlserver.MeshRoute, 0) meshNode.Routes = make([]ctrlserver.MeshRoute, 0)
@ -74,6 +61,7 @@ func (s *SmegServer) meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNo
} }
} }
// meshToAPIMesh: Convert daemon mesh network to a JSON mesh network
func (s *SmegServer) meshToAPIMesh(meshId string, nodes []ctrlserver.MeshNode) SmegMesh { func (s *SmegServer) meshToAPIMesh(meshId string, nodes []ctrlserver.MeshNode) SmegMesh {
var smegMesh SmegMesh var smegMesh SmegMesh
smegMesh.MeshId = meshId smegMesh.MeshId = meshId
@ -86,6 +74,25 @@ func (s *SmegServer) meshToAPIMesh(meshId string, nodes []ctrlserver.MeshNode) S
return smegMesh return smegMesh
} }
// putAlias: place an alias in the mesh
func (s *SmegServer) putAlias(meshId, alias string) error {
var reply string
return s.client.PutAlias(ipc.PutAliasArgs{
Alias: alias,
MeshId: meshId,
}, &reply)
}
func (s *SmegServer) putDescription(meshId, description string) error {
var reply string
return s.client.PutDescription(ipc.PutDescriptionArgs{
Description: description,
MeshId: meshId,
}, &reply)
}
// CreateMesh: creates a mesh network // CreateMesh: creates a mesh network
func (s *SmegServer) CreateMesh(c *gin.Context) { func (s *SmegServer) CreateMesh(c *gin.Context) {
var createMesh CreateMeshRequest var createMesh CreateMeshRequest
@ -98,15 +105,21 @@ func (s *SmegServer) CreateMesh(c *gin.Context) {
return return
} }
fmt.Printf("%+v\n", createMesh)
ipcRequest := ipc.NewMeshArgs{ ipcRequest := ipc.NewMeshArgs{
WgArgs: ipc.WireGuardArgs{ WgArgs: ipc.WireGuardArgs{
WgPort: createMesh.WgPort, WgPort: createMesh.WgPort,
Role: createMesh.Role,
Endpoint: createMesh.PublicEndpoint,
AdvertiseRoutes: createMesh.AdvertiseRoutes,
AdvertiseDefaultRoute: createMesh.AdvertiseDefaults,
}, },
} }
var reply string var reply string
err := s.client.Call("IpcHandler.CreateMesh", &ipcRequest, &reply) err := s.client.CreateMesh(&ipcRequest, &reply)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, &gin.H{ c.JSON(http.StatusBadRequest, &gin.H{
@ -115,6 +128,14 @@ func (s *SmegServer) CreateMesh(c *gin.Context) {
return return
} }
if createMesh.Alias != "" {
s.putAlias(reply, createMesh.Alias)
}
if createMesh.Description != "" {
s.putDescription(reply, createMesh.Description)
}
c.JSON(http.StatusOK, &gin.H{ c.JSON(http.StatusOK, &gin.H{
"meshid": reply, "meshid": reply,
}) })
@ -132,16 +153,20 @@ func (s *SmegServer) JoinMesh(c *gin.Context) {
} }
ipcRequest := ipc.JoinMeshArgs{ ipcRequest := ipc.JoinMeshArgs{
MeshId: joinMesh.MeshId, MeshId: joinMesh.MeshId,
IpAdress: joinMesh.Bootstrap, IpAddress: joinMesh.Bootstrap,
WgArgs: ipc.WireGuardArgs{ WgArgs: ipc.WireGuardArgs{
WgPort: joinMesh.WgPort, WgPort: joinMesh.WgPort,
Endpoint: joinMesh.PublicEndpoint,
Role: joinMesh.Role,
AdvertiseRoutes: joinMesh.AdvertiseRoutes,
AdvertiseDefaultRoute: joinMesh.AdvertiseDefaults,
}, },
} }
var reply string var reply string
err := s.client.Call("IpcHandler.JoinMesh", &ipcRequest, &reply) err := s.client.JoinMesh(ipcRequest, &reply)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, &gin.H{ c.JSON(http.StatusBadRequest, &gin.H{
@ -150,6 +175,14 @@ func (s *SmegServer) JoinMesh(c *gin.Context) {
return return
} }
if joinMesh.Alias != "" {
s.putAlias(reply, joinMesh.Alias)
}
if joinMesh.Description != "" {
s.putDescription(reply, joinMesh.Description)
}
c.JSON(http.StatusOK, &gin.H{ c.JSON(http.StatusOK, &gin.H{
"status": "success", "status": "success",
}) })
@ -164,7 +197,7 @@ func (s *SmegServer) GetMesh(c *gin.Context) {
getMeshReply := new(ipc.GetMeshReply) getMeshReply := new(ipc.GetMeshReply)
err := s.client.Call("IpcHandler.GetMesh", &meshid, &getMeshReply) err := s.client.GetMesh(meshid, getMeshReply)
if err != nil { if err != nil {
c.JSON(http.StatusNotFound, c.JSON(http.StatusNotFound,
@ -179,10 +212,12 @@ func (s *SmegServer) GetMesh(c *gin.Context) {
c.JSON(http.StatusOK, mesh) c.JSON(http.StatusOK, mesh)
} }
// GetMeshes: return all the mesh networks that the
// user is a part of
func (s *SmegServer) GetMeshes(c *gin.Context) { func (s *SmegServer) GetMeshes(c *gin.Context) {
listMeshesReply := new(ipc.ListMeshReply) listMeshesReply := new(ipc.ListMeshReply)
err := s.client.Call("IpcHandler.ListMeshes", "", &listMeshesReply) err := s.client.ListMeshes(listMeshesReply)
if err != nil { if err != nil {
logging.Log.WriteErrorf(err.Error()) logging.Log.WriteErrorf(err.Error())
@ -195,7 +230,7 @@ func (s *SmegServer) GetMeshes(c *gin.Context) {
for _, mesh := range listMeshesReply.Meshes { for _, mesh := range listMeshesReply.Meshes {
getMeshReply := new(ipc.GetMeshReply) getMeshReply := new(ipc.GetMeshReply)
err := s.client.Call("IpcHandler.GetMesh", &mesh, &getMeshReply) err := s.client.GetMesh(mesh, getMeshReply)
if err != nil { if err != nil {
logging.Log.WriteErrorf(err.Error()) logging.Log.WriteErrorf(err.Error())
@ -209,13 +244,16 @@ func (s *SmegServer) GetMeshes(c *gin.Context) {
c.JSON(http.StatusOK, meshes) c.JSON(http.StatusOK, meshes)
} }
// Run: run the API server
func (s *SmegServer) Run(addr string) error { func (s *SmegServer) Run(addr string) error {
logging.Log.WriteInfof("Running API server") logging.Log.WriteInfof("Running API server")
return s.router.Run(addr) return s.router.Run(addr)
} }
// NewSmegServer: creates an instance of a new API server
// returns an error if something went wrong
func NewSmegServer(conf ApiServerConf) (ApiServer, error) { func NewSmegServer(conf ApiServerConf) (ApiServer, error) {
client, err := ipcRpc.DialHTTP("unix", SockAddr) client, err := ipc.NewClientIpc()
if err != nil { if err != nil {
return nil, err return nil, err
@ -239,9 +277,19 @@ func NewSmegServer(conf ApiServerConf) (ApiServer, error) {
words: words, words: words,
} }
router.GET("/meshes", smegServer.GetMeshes) v1 := router.Group("/api/v1")
router.GET("/mesh/:meshid", smegServer.GetMesh) {
router.POST("/mesh/create", smegServer.CreateMesh) meshes := v1.Group("/meshes")
router.POST("/mesh/join", smegServer.JoinMesh) {
meshes.GET("/", smegServer.GetMeshes)
}
mesh := v1.Group("/mesh")
{
mesh.GET("/:meshid", smegServer.GetMesh)
mesh.POST("/create", smegServer.CreateMesh)
mesh.POST("/join", smegServer.JoinMesh)
}
}
return smegServer, nil return smegServer, nil
} }

View File

@ -1,47 +1,129 @@
package api package api
import "time" import (
"time"
"github.com/gin-gonic/gin"
"github.com/tim-beatham/smegmesh/pkg/ipc"
"github.com/tim-beatham/smegmesh/pkg/what8words"
)
// Route is an advertised route in the data store
type Route struct { type Route struct {
Prefix string `json:"prefix"` // Prefix is the advertised route prefix
Path []string `json:"path"` Prefix string `json:"prefix"`
// Path is the hops the destination
Path []string `json:"path"`
} }
// SmegStats is the WireGuard stats that the underlying host
// has sent to the peer
type SmegStats struct { type SmegStats struct {
TotalTransmit int64 `json:"totalTransmit"` // TotalTransmit number of bytes sent to the peer
TotalReceived int64 `json:"totalReceived"` TotalTransmit int64 `json:"totalTransmit"`
// TotalReceived number of bytes received from the peer
TotalReceived int64 `json:"totalReceived"`
// KeepAliveInterval WireGuard keepalive interval that is sent to the host
KeepAliveInterval time.Duration `json:"keepaliveInterval"` KeepAliveInterval time.Duration `json:"keepaliveInterval"`
AllowedIps []string `json:"allowedIps"` // AllowsIps is the allowed path to the destination
AllowedIps []string `json:"allowedIps"`
} }
// SmegNode is a node in the mesh network
type SmegNode struct { type SmegNode struct {
Alias string `json:"alias"` // Alias is the human readable name that the node is assocaited with
WgHost string `json:"wgHost"` Alias string `json:"alias"`
WgEndpoint string `json:"wgEndpoint"` // WgHost is the WireGuard IP address of the node. This is an IPv6
Endpoint string `json:"endpoint"` // address
Timestamp int `json:"timestamp"` WgHost string `json:"wgHost"`
Description string `json:"description"` // WgEndpoint is the physical endpoint of the host that packets
PublicKey string `json:"publicKey"` // are forwarded to
Routes []Route `json:"routes"` WgEndpoint string `json:"wgEndpoint"`
Services map[string]string `json:"services"` // Endpoint is the control plane endpoint of the host which
Stats SmegStats `json:"stats"` // grpc connections are to be sent along
Endpoint string `json:"endpoint"`
// Timestamp is the last time the signified it was alive.
// if the node is the leader this is evert heartBeatInterval
// otherwise this is the time the node joined the network
Timestamp int `json:"timestamp"`
// Description is the human readable description of the node
Description string `json:"description"`
// PublicKey is the WireGuard public key of the node
PublicKey string `json:"publicKey"`
// Routes is the routes that the node is advertising
Routes []Route `json:"routes"`
// Services is information about services that the node offers
Services map[string]string `json:"services"`
// Stats is the WireGuard stats of the node (if any)
Stats SmegStats `json:"stats"`
} }
// SmegMesh encapsulates a single mesh in the API
type SmegMesh struct { type SmegMesh struct {
MeshId string `json:"meshid"` // MeshId is the mesh id of the network
Nodes map[string]SmegNode `json:"nodes"` MeshId string `json:"meshid"`
// Nodes is the nodes in the network keyed by their public
// key
Nodes map[string]SmegNode `json:"nodes"`
} }
// CreateMeshRequest encapsulates a request to create a mesh network
type CreateMeshRequest struct { type CreateMeshRequest struct {
// WgPort is the WireGuard to create the mesh in
WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"` WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"`
// Role is the role to take on in the mesh
Role string `json:"role" binding:"required,eq=client|eq=peer"`
// AdvertiseRoutes: advertise thi mesh to other meshes
AdvertiseRoutes bool `json:"advertiseRoutes"`
// AdvertiseDefaults: advertise an exit point
AdvertiseDefaults bool `json:"advertiseDefaults"`
// Alias: alias of the node in the mesh
Alias string `json:"alias"`
// Description: description of the node in the mesh
Description string `json:"description"`
// PublicEndpoint: an alternative public endpoint to advertise
PublicEndpoint string `json:"publicEndpoint"`
} }
// JoinMeshRequests encapsulates a request to create a mesh network
type JoinMeshRequest struct { type JoinMeshRequest struct {
WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"` // WgPort is the WireGuard port to run the service on
WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"`
// Bootstrap is a bootstrap node to use to join the network
Bootstrap string `json:"bootstrap" binding:"required"` Bootstrap string `json:"bootstrap" binding:"required"`
MeshId string `json:"meshid" binding:"required"` // MeshId is the ID of the mesh to join
MeshId string `json:"meshid" binding:"required"`
// Role is the role to take on in the mesh
Role string `json:"role" binding:"required,eq=client|eq=peer"`
// AdvertiseRoutes: advertise thi mesh to other meshes
AdvertiseRoutes bool `json:"advertiseRoutes"`
// AdvertiseDefaults: advertise an exit point
AdvertiseDefaults bool `json:"advertiseDefaults"`
// Alias: alias of the node in the mesh
Alias string `json:"alias"`
// Description: description of the node in the mesh
Description string `json:"description"`
// PublicEndpoint: an alternative public endpoint to advertise
PublicEndpoint string `json:"publicEndpoint"`
} }
// ApiServerConf configuration to instantiate the API server
type ApiServerConf struct { type ApiServerConf struct {
// WordsFile to use to map IP to words
WordsFile string WordsFile string
} }
// SmegSever is the GIN api server that runs the service
type SmegServer struct {
// gin router to use
router *gin.Engine
// client to invoke operations
client *ipc.SmegmeshIpc
// what8words to use to convert IP to an alias
words *what8words.What8Words
}
// ApiSever absrtacts the API server
type ApiServer interface {
Run(addr string) error
}

View File

@ -1,3 +1,5 @@
// automerge: package is depracated and unused. Please refer to crdt
// for crdt operations in the mesh
package automerge package automerge
import ( import (
@ -9,26 +11,36 @@ import (
"time" "time"
"github.com/automerge/automerge-go" "github.com/automerge/automerge-go"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// CrdtMeshManager manages nodes in the crdt mesh // CrdtMeshManager manage the CRDT datastore
type CrdtMeshManager struct { type CrdtMeshManager struct {
MeshId string // MeshID of the mesh the datastore represents
IfName string MeshId string
Client *wgctrl.Client // IfName: corresponding ifName
doc *automerge.Doc IfName string
LastHash automerge.ChangeHash // Client: corresponding wireguard control client
conf *conf.WgConfiguration Client *wgctrl.Client
cache *MeshCrdt // doc: autommerge document
doc *automerge.Doc
// LastHash: last hash that the changes were made to
LastHash automerge.ChangeHash
// conf: WireGuard configuration
conf *conf.WgConfiguration
// cache: stored cache of the list automerge document
// so that the store does not have to be repopulated each time
cache *MeshCrdt
// lastCachehash: hash of when the document was last changed
lastCacheHash automerge.ChangeHash lastCacheHash automerge.ChangeHash
} }
// AddNode as a node to the datastore
func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) { func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
crdt, ok := node.(*MeshNodeCrdt) crdt, ok := node.(*MeshNodeCrdt)
@ -47,6 +59,7 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
} }
} }
// isPeer: returns true if the given node has type peer
func (c *CrdtMeshManager) isPeer(nodeId string) bool { func (c *CrdtMeshManager) isPeer(nodeId string) bool {
node, err := c.doc.Path("nodes").Map().Get(nodeId) node, err := c.doc.Path("nodes").Map().Get(nodeId)
@ -64,7 +77,8 @@ func (c *CrdtMeshManager) isPeer(nodeId string) bool {
} }
// isAlive: checks that the node's configuration has been updated // isAlive: checks that the node's configuration has been updated
// since the rquired keep alive time // since the rquired keep alive time. Depracated no longer works
// due to changes in approach
func (c *CrdtMeshManager) isAlive(nodeId string) bool { func (c *CrdtMeshManager) isAlive(nodeId string) bool {
node, err := c.doc.Path("nodes").Map().Get(nodeId) node, err := c.doc.Path("nodes").Map().Get(nodeId)
@ -78,10 +92,11 @@ func (c *CrdtMeshManager) isAlive(nodeId string) bool {
return false return false
} }
return true
// return (time.Now().Unix() - keepAliveTime) < int64(c.conf.DeadTime) // return (time.Now().Unix() - keepAliveTime) < int64(c.conf.DeadTime)
return true
} }
// GetPeers: get all the peers in the mesh
func (c *CrdtMeshManager) GetPeers() []string { func (c *CrdtMeshManager) GetPeers() []string {
keys, _ := c.doc.Path("nodes").Map().Keys() keys, _ := c.doc.Path("nodes").Map().Keys()
@ -92,7 +107,7 @@ func (c *CrdtMeshManager) GetPeers() []string {
return keys return keys
} }
// GetMesh(): Converts the document into a struct // GetMesh: Converts the document into a mesh network
func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) { func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
changes, err := c.doc.Changes(c.lastCacheHash) changes, err := c.doc.Changes(c.lastCacheHash)
@ -114,7 +129,7 @@ func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
return c.cache, nil return c.cache, nil
} }
// GetMeshId returns the meshid of the mesh // GetMeshId: returns the meshid of the mesh
func (c *CrdtMeshManager) GetMeshId() string { func (c *CrdtMeshManager) GetMeshId() string {
return c.MeshId return c.MeshId
} }
@ -135,6 +150,8 @@ func (c *CrdtMeshManager) Load(bytes []byte) error {
return nil return nil
} }
// NewCrdtNodeManagerParams: params to instantiate a new automerge
// datastore
type NewCrdtNodeMangerParams struct { type NewCrdtNodeMangerParams struct {
MeshId string MeshId string
DevName string DevName string
@ -143,7 +160,7 @@ type NewCrdtNodeMangerParams struct {
Client *wgctrl.Client Client *wgctrl.Client
} }
// NewCrdtNodeManager: Create a new crdt node manager // NewCrdtNodeManager: Create a new automerge crdt data store
func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, error) { func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, error) {
var manager CrdtMeshManager var manager CrdtMeshManager
manager.MeshId = params.MeshId manager.MeshId = params.MeshId
@ -155,12 +172,13 @@ func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, erro
return &manager, nil return &manager, nil
} }
// NodeExists: returns true if the node exists. Returns false // NodeExists: returns true if the node exists other returns false
func (m *CrdtMeshManager) NodeExists(key string) bool { func (m *CrdtMeshManager) NodeExists(key string) bool {
node, err := m.doc.Path("nodes").Map().Get(key) node, err := m.doc.Path("nodes").Map().Get(key)
return node.Kind() == automerge.KindMap && err == nil return node.Kind() == automerge.KindMap && err == nil
} }
// GetNode: gets a node from the mesh network.
func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) { func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
node, err := m.doc.Path("nodes").Map().Get(endpoint) node, err := m.doc.Path("nodes").Map().Get(endpoint)
@ -181,10 +199,12 @@ func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
return meshNode, nil return meshNode, nil
} }
// Length: returns the number of nodes in the store
func (m *CrdtMeshManager) Length() int { func (m *CrdtMeshManager) Length() int {
return m.doc.Path("nodes").Map().Len() return m.doc.Path("nodes").Map().Len()
} }
// GetDevice: get the underlying WireGuard device
func (m *CrdtMeshManager) GetDevice() (*wgtypes.Device, error) { func (m *CrdtMeshManager) GetDevice() (*wgtypes.Device, error) {
dev, err := m.Client.Device(m.IfName) dev, err := m.Client.Device(m.IfName)
@ -195,7 +215,7 @@ func (m *CrdtMeshManager) GetDevice() (*wgtypes.Device, error) {
return dev, nil return dev, nil
} }
// HasChanges returns true if we have changes since the last time we synced // HasChanges: returns true if there are changes since last time synchronised
func (m *CrdtMeshManager) HasChanges() bool { func (m *CrdtMeshManager) HasChanges() bool {
changes, err := m.doc.Changes(m.LastHash) changes, err := m.doc.Changes(m.LastHash)
@ -209,6 +229,7 @@ func (m *CrdtMeshManager) HasChanges() bool {
return len(changes) > 0 return len(changes) > 0
} }
// SaveChanges: save changes to the datastore
func (m *CrdtMeshManager) SaveChanges() { func (m *CrdtMeshManager) SaveChanges() {
hashes := m.doc.Heads() hashes := m.doc.Heads()
hash := hashes[len(hashes)-1] hash := hashes[len(hashes)-1]
@ -217,6 +238,7 @@ func (m *CrdtMeshManager) SaveChanges() {
m.LastHash = hash m.LastHash = hash
} }
// UpdateTimeStamp: updates the timestamp of the document
func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error { func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId) node, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -237,6 +259,7 @@ func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error {
return err return err
} }
// SetDescription: set the description of the given node
func (m *CrdtMeshManager) SetDescription(nodeId string, description string) error { func (m *CrdtMeshManager) SetDescription(nodeId string, description string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId) node, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -257,6 +280,7 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro
return err return err
} }
// SetAlias: set the alias of the given node
func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error { func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId) node, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -277,6 +301,7 @@ func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
return err return err
} }
// AddService: add a service to the given node
func (m *CrdtMeshManager) AddService(nodeId, key, value string) error { func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId) node, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -298,6 +323,7 @@ func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
return err return err
} }
// RemoveService: remove a service from a node
func (m *CrdtMeshManager) RemoveService(nodeId, key string) error { func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId) node, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -378,6 +404,7 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
return nil return nil
} }
// getRoutes: get the routes that the given node is directly advertising
func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) { func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -404,6 +431,8 @@ func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) {
return lib.MapValues(routes), err return lib.MapValues(routes), err
} }
// GetRoutes: get all the routes that the node can see. The routes that the node
// can say may not be direct but cann also be indirect
func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, error) { func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, error) {
node, err := m.GetNode(targetNode) node, err := m.GetNode(targetNode)
@ -447,12 +476,13 @@ func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, e
return routes, nil return routes, nil
} }
// RemoveNode: removes a node from the datastore
func (m *CrdtMeshManager) RemoveNode(nodeId string) error { func (m *CrdtMeshManager) RemoveNode(nodeId string) error {
err := m.doc.Path("nodes").Map().Delete(nodeId) err := m.doc.Path("nodes").Map().Delete(nodeId)
return err return err
} }
// DeleteRoutes deletes the specified routes // RemoveRoutes: withdraw all the routes the nodeID is advertising
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error { func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -486,30 +516,37 @@ func (m *CrdtMeshManager) GetConfiguration() *conf.WgConfiguration {
func (m *CrdtMeshManager) Mark(nodeId string) { func (m *CrdtMeshManager) Mark(nodeId string) {
} }
// GetSyncer: get the bi-directionally syncer to synchronise the document
func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer { func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer {
return NewAutomergeSync(m) return NewAutomergeSync(m)
} }
// Prune: prune all dead nodes
func (m *CrdtMeshManager) Prune() error { func (m *CrdtMeshManager) Prune() error {
return nil return nil
} }
// Compare: compare two mesh node for equality
func (m1 *MeshNodeCrdt) Compare(m2 *MeshNodeCrdt) int { func (m1 *MeshNodeCrdt) Compare(m2 *MeshNodeCrdt) int {
return strings.Compare(m1.PublicKey, m2.PublicKey) return strings.Compare(m1.PublicKey, m2.PublicKey)
} }
// GetHostEndpoint: get the ctrl endpoint of the host
func (m *MeshNodeCrdt) GetHostEndpoint() string { func (m *MeshNodeCrdt) GetHostEndpoint() string {
return m.HostEndpoint return m.HostEndpoint
} }
// GetPublicKey: get the public key of the node
func (m *MeshNodeCrdt) GetPublicKey() (wgtypes.Key, error) { func (m *MeshNodeCrdt) GetPublicKey() (wgtypes.Key, error) {
return wgtypes.ParseKey(m.PublicKey) return wgtypes.ParseKey(m.PublicKey)
} }
// GetWgEndpoint: get the outer WireGuard endpoint
func (m *MeshNodeCrdt) GetWgEndpoint() string { func (m *MeshNodeCrdt) GetWgEndpoint() string {
return m.WgEndpoint return m.WgEndpoint
} }
// GetWgHost: get the WireGuard IP address of the host
func (m *MeshNodeCrdt) GetWgHost() *net.IPNet { func (m *MeshNodeCrdt) GetWgHost() *net.IPNet {
_, ipnet, err := net.ParseCIDR(m.WgHost) _, ipnet, err := net.ParseCIDR(m.WgHost)
@ -520,10 +557,12 @@ func (m *MeshNodeCrdt) GetWgHost() *net.IPNet {
return ipnet return ipnet
} }
// GetTimeStamp: get timestamp if when the node was last updated
func (m *MeshNodeCrdt) GetTimeStamp() int64 { func (m *MeshNodeCrdt) GetTimeStamp() int64 {
return m.Timestamp return m.Timestamp
} }
// GetRoutes: get all the routes advertised by the node
func (m *MeshNodeCrdt) GetRoutes() []mesh.Route { func (m *MeshNodeCrdt) GetRoutes() []mesh.Route {
return lib.Map(lib.MapValues(m.Routes), func(r Route) mesh.Route { return lib.Map(lib.MapValues(m.Routes), func(r Route) mesh.Route {
return &Route{ return &Route{
@ -533,10 +572,12 @@ func (m *MeshNodeCrdt) GetRoutes() []mesh.Route {
}) })
} }
// GetDescription: get the description of the node
func (m *MeshNodeCrdt) GetDescription() string { func (m *MeshNodeCrdt) GetDescription() string {
return m.Description return m.Description
} }
// GetIdentifier: get the iderntifier section of the ipv6 address
func (m *MeshNodeCrdt) GetIdentifier() string { func (m *MeshNodeCrdt) GetIdentifier() string {
ipv6 := m.WgHost[:len(m.WgHost)-4] ipv6 := m.WgHost[:len(m.WgHost)-4]
@ -545,10 +586,12 @@ func (m *MeshNodeCrdt) GetIdentifier() string {
return strings.Join(constituents, ":") return strings.Join(constituents, ":")
} }
// GetAlias: get the alias of the node
func (m *MeshNodeCrdt) GetAlias() string { func (m *MeshNodeCrdt) GetAlias() string {
return m.Alias return m.Alias
} }
// GetServices: get all the services the node is advertising
func (m *MeshNodeCrdt) GetServices() map[string]string { func (m *MeshNodeCrdt) GetServices() map[string]string {
services := make(map[string]string) services := make(map[string]string)
@ -565,6 +608,7 @@ func (n *MeshNodeCrdt) GetType() conf.NodeType {
return conf.NodeType(n.Type) return conf.NodeType(n.Type)
} }
// GetNodes: get all the nodes in the network
func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode { func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
nodes := make(map[string]mesh.MeshNode) nodes := make(map[string]mesh.MeshNode)
@ -586,15 +630,18 @@ func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
return nodes return nodes
} }
// GetDestination: get destination of the route
func (r *Route) GetDestination() *net.IPNet { func (r *Route) GetDestination() *net.IPNet {
_, ipnet, _ := net.ParseCIDR(r.Destination) _, ipnet, _ := net.ParseCIDR(r.Destination)
return ipnet return ipnet
} }
// GetHopCount: get the number of hops to the destination
func (r *Route) GetHopCount() int { func (r *Route) GetHopCount() int {
return len(r.Path) return len(r.Path)
} }
// GetPath: get the total path which includes the number of hops
func (r *Route) GetPath() []string { func (r *Route) GetPath() []string {
return r.Path return r.Path
} }

View File

@ -1,15 +1,24 @@
// automerge: automerge is a CRDT library. Defines a CRDT
// datastore and methods to resolve conflicts
package automerge package automerge
import ( import (
"github.com/automerge/automerge-go" "github.com/automerge/automerge-go"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
) )
// AutomergeSync: defines a synchroniser to bi-directionally synchronise the
// two states
type AutomergeSync struct { type AutomergeSync struct {
state *automerge.SyncState // state: the automerge sync state to use
state *automerge.SyncState
// manager: the corresponding data store that we are merging
manager *CrdtMeshManager manager *CrdtMeshManager
} }
// GenerateMessage: geenrate a new automerge message to synchronise
// returns a byte of the message and a boolean of whether or not there
// are more messages in the sequence
func (a *AutomergeSync) GenerateMessage() ([]byte, bool) { func (a *AutomergeSync) GenerateMessage() ([]byte, bool) {
msg, valid := a.state.GenerateMessage() msg, valid := a.state.GenerateMessage()
@ -20,6 +29,8 @@ func (a *AutomergeSync) GenerateMessage() ([]byte, bool) {
return msg.Bytes(), true return msg.Bytes(), true
} }
// RecvMessage: receive an automerge message to merge in the datastore
// returns an error if unsuccessful
func (a *AutomergeSync) RecvMessage(msg []byte) error { func (a *AutomergeSync) RecvMessage(msg []byte) error {
_, err := a.state.ReceiveMessage(msg) _, err := a.state.ReceiveMessage(msg)
@ -30,11 +41,13 @@ func (a *AutomergeSync) RecvMessage(msg []byte) error {
return nil return nil
} }
// Complete: complete the synchronisation process
func (a *AutomergeSync) Complete() { func (a *AutomergeSync) Complete() {
logging.Log.WriteInfof("Sync Completed") logging.Log.WriteInfof("sync completed")
a.manager.SaveChanges() a.manager.SaveChanges()
} }
// NewAutomergeSync: instantiates a new automerge syncer
func NewAutomergeSync(manager *CrdtMeshManager) *AutomergeSync { func NewAutomergeSync(manager *CrdtMeshManager) *AutomergeSync {
return &AutomergeSync{ return &AutomergeSync{
state: automerge.NewSyncState(manager.doc), state: automerge.NewSyncState(manager.doc),

View File

@ -6,9 +6,9 @@ import (
"testing" "testing"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@ -83,7 +83,6 @@ func TestAddNodeAddRoute(t *testing.T) {
testParams.manager.AddNode(testNode) testParams.manager.AddNode(testNode)
testParams.manager.AddRoutes(pubKey.String(), &mesh.RouteStub{ testParams.manager.AddRoutes(pubKey.String(), &mesh.RouteStub{
Destination: destination, Destination: destination,
HopCount: 0,
Path: make([]string, 0), Path: make([]string, 0),
}) })
updatedNode, err := testParams.manager.GetNode(pubKey.String()) updatedNode, err := testParams.manager.GetNode(pubKey.String())
@ -297,7 +296,6 @@ func TestAddRoutesNodeDoesNotExist(t *testing.T) {
err := testParams.manager.AddRoutes("AAAAA", &mesh.RouteStub{ err := testParams.manager.AddRoutes("AAAAA", &mesh.RouteStub{
Destination: destination, Destination: destination,
HopCount: 0,
Path: make([]string, 0), Path: make([]string, 0),
}) })

View File

@ -3,13 +3,16 @@ package automerge
import ( import (
"fmt" "fmt"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
) )
// CrdtProviderFactory: abstracts the instantiation of an automerge
// datastore
type CrdtProviderFactory struct{} type CrdtProviderFactory struct{}
// CreateMesh: create a new mesh datastore
func (f *CrdtProviderFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) { func (f *CrdtProviderFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) {
return NewCrdtNodeManager(&NewCrdtNodeMangerParams{ return NewCrdtNodeManager(&NewCrdtNodeMangerParams{
MeshId: params.MeshId, MeshId: params.MeshId,
@ -19,11 +22,12 @@ func (f *CrdtProviderFactory) CreateMesh(params *mesh.MeshProviderFactoryParams)
}) })
} }
// MeshNodeFactory: abstracts the instnatiation of a node
type MeshNodeFactory struct { type MeshNodeFactory struct {
Config conf.DaemonConfiguration Config conf.DaemonConfiguration
} }
// Build builds the mesh node that represents the host machine to add // Build: builds the mesh node that represents the host machine to add
// to the mesh // to the mesh
func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode { func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode {
hostName := f.getAddress(params) hostName := f.getAddress(params)
@ -48,7 +52,7 @@ func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNod
} }
} }
// getAddress returns the routable address of the machine. // getAddress: returns the routable address of the machine.
func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string { func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string {
var hostName string = "" var hostName string = ""
@ -59,7 +63,7 @@ func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string
} else { } else {
ipFunc := lib.GetPublicIP ipFunc := lib.GetPublicIP
if *params.MeshConfig.IPDiscovery == conf.DNS_IP_DISCOVERY { if *params.MeshConfig.IPDiscovery == conf.OUTGOING_IP_DISCOVERY {
ipFunc = lib.GetOutboundIP ipFunc = lib.GetOutboundIP
} }

View File

@ -6,10 +6,12 @@ import (
"strings" "strings"
) )
// CmdRunner: run cmd commands when instantiating a network
type CmdRunner interface { type CmdRunner interface {
RunCommands(commands ...string) error RunCommands(commands ...string) error
} }
// UnixCmdRunner: Run UNIX commands
type UnixCmdRunner struct{} type UnixCmdRunner struct{}
// RunCommand: runs the unix command. It splits the command into fields // RunCommand: runs the unix command. It splits the command into fields
@ -20,6 +22,7 @@ func RunCommand(cmd string) error {
return c.Run() return c.Run()
} }
// RunCommands: run a series of commands
func (l *UnixCmdRunner) RunCommands(commands ...string) error { func (l *UnixCmdRunner) RunCommands(commands ...string) error {
for _, cmd := range commands { for _, cmd := range commands {
err := RunCommand(cmd) err := RunCommand(cmd)

View File

@ -8,14 +8,7 @@ import (
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
type WgMeshConfigurationError struct { // NodeType types of the node either peer or client
msg string
}
func (m *WgMeshConfigurationError) Error() string {
return m.msg
}
type NodeType string type NodeType string
const ( const (
@ -23,11 +16,23 @@ const (
CLIENT_ROLE NodeType = "client" CLIENT_ROLE NodeType = "client"
) )
// IPDiscovery: what IPDiscovery service to use
type IPDiscovery string type IPDiscovery string
const ( const (
// Public IP use an IP service to discover your IP
PUBLIC_IP_DISCOVERY IPDiscovery = "public" PUBLIC_IP_DISCOVERY IPDiscovery = "public"
DNS_IP_DISCOVERY IPDiscovery = "dns" // Outgonig: Use your labelled packet IP
OUTGOING_IP_DISCOVERY IPDiscovery = "outgoing"
)
// Loglevel: what log level to use either error info or warning
type LogLevel string
const (
ERROR LogLevel = "error"
WARNING LogLevel = "warning"
INFO LogLevel = "info"
) )
// WgConfiguration contains per-mesh WireGuard configuration. Contains poitner types only so we can // WgConfiguration contains per-mesh WireGuard configuration. Contains poitner types only so we can
@ -35,7 +40,7 @@ const (
type WgConfiguration struct { type WgConfiguration struct {
// IPDIscovery: how to discover your IP if not specified. Use your outgoing IP or use a public // IPDIscovery: how to discover your IP if not specified. Use your outgoing IP or use a public
// service for IPDiscoverability // service for IPDiscoverability
IPDiscovery *IPDiscovery `yaml:"ipDiscovery" validate:"required,eq=public|eq=dns"` IPDiscovery *IPDiscovery `yaml:"ipDiscovery" validate:"required,eq=public|eq=outgoing"`
// AdvertiseRoutes: specifies whether the node can act as a router routing packets between meshes // AdvertiseRoutes: specifies whether the node can act as a router routing packets between meshes
AdvertiseRoutes *bool `yaml:"advertiseRoute" validate:"required"` AdvertiseRoutes *bool `yaml:"advertiseRoute" validate:"required"`
// AdvertiseDefaultRoute: specifies whether or not this route should advertise a default route // AdvertiseDefaultRoute: specifies whether or not this route should advertise a default route
@ -47,7 +52,6 @@ type WgConfiguration struct {
// If the user is globaly accessible they specify themselves as a client. // If the user is globaly accessible they specify themselves as a client.
Role *NodeType `yaml:"role" validate:"required,eq=client|eq=peer"` Role *NodeType `yaml:"role" validate:"required,eq=client|eq=peer"`
// KeepAliveWg configures the implementation so that we send keep alive packets to peers. // KeepAliveWg configures the implementation so that we send keep alive packets to peers.
// KeepAlive can only be set if role is type client
KeepAliveWg *int `yaml:"keepAliveWg" validate:"omitempty,gte=0"` KeepAliveWg *int `yaml:"keepAliveWg" validate:"omitempty,gte=0"`
// PreUp are WireGuard commands to run before adding the WG interface // PreUp are WireGuard commands to run before adding the WG interface
PreUp []string `yaml:"preUp"` PreUp []string `yaml:"preUp"`
@ -77,22 +81,26 @@ type DaemonConfiguration struct {
Profile bool `yaml:"profile"` Profile bool `yaml:"profile"`
// StubWg whether or not to stub the WireGuard types // StubWg whether or not to stub the WireGuard types
StubWg bool `yaml:"stubWg"` StubWg bool `yaml:"stubWg"`
// SyncRate specifies how long the minimum time should be between synchronisation // SyncInterval specifies how long the minimum time should be between synchronisation
SyncRate int `yaml:"syncRate" validate:"required,gte=1"` SyncInterval int `yaml:"syncInterval" validate:"required,gte=1"`
// KeepAliveTime: number of seconds before the leader of the mesh sends an update to // PullInterval specifies the interval between checking for configuration changes
PullInterval int `yaml:"pullInterval" validate:"gte=0"`
// Heartbeat: number of seconds before the leader of the mesh sends an update to
// send to every member in the mesh // send to every member in the mesh
KeepAliveTime int `yaml:"keepAliveTime" validate:"required,gte=1"` Heartbeat int `yaml:"heartbeatInterval" validate:"required,gte=1"`
// ClusterSize specifies how many neighbours you should synchronise with per round // ClusterSize specifies how many neighbours you should synchronise with per round
ClusterSize int `yaml:"clusterSize" validate:"gte=1"` ClusterSize int `yaml:"clusterSize" validate:"gte=1"`
// InterClusterChance specifies the probabilityof inter-cluster communication in a sync round // InterClusterChance specifies the probabilityof inter-cluster communication in a sync round
InterClusterChance float64 `yaml:"interClusterChance" validate:"gt=0"` InterClusterChance float64 `yaml:"interClusterChance" validate:"gt=0"`
// BranchRate specifies the number of nodes to synchronise with when a node has // Branch specifies the number of nodes to synchronise with when a node has
// new changes to send to the mesh // new changes to send to the mesh
BranchRate int `yaml:"branchRate" validate:"required,gte=1"` Branch int `yaml:"branch" validate:"required,gte=1"`
// InfectionCount: number of time to sync before an update can no longer be 'caught' // InfectionCount: number of time to sync before an update can no longer be 'caught'
InfectionCount int `yaml:"infectionCount" validate:"required,gte=1"` InfectionCount int `yaml:"infectionCount" validate:"required,gte=1"`
// BaseConfiguration base WireGuard configuration to use, this is used when none is provided // BaseConfiguration base WireGuard configuration to use, this is used when none is provided
BaseConfiguration WgConfiguration `yaml:"baseConfiguration" validate:"required"` BaseConfiguration WgConfiguration `yaml:"baseConfiguration" validate:"required"`
// LogLevel specifies the log level to output, defaults is warning
LogLevel LogLevel `yaml:"logLevel" validate:"eq=info|eq=warning|eq=error"`
} }
// ValdiateMeshConfiguration: validates the mesh configuration // ValdiateMeshConfiguration: validates the mesh configuration
@ -120,32 +128,21 @@ func ValidateMeshConfiguration(conf *WgConfiguration) error {
} }
// ValidateDaemonConfiguration: validates the dameon configuration that is used. // ValidateDaemonConfiguration: validates the dameon configuration that is used.
func ValidateDaemonConfiguration(c *DaemonConfiguration) error { func ValidateDaemonConfiguration(conf *DaemonConfiguration) error {
if conf.BaseConfiguration.KeepAliveWg == nil {
var keepAlive int = 0
conf.BaseConfiguration.KeepAliveWg = &keepAlive
}
if conf.LogLevel == "" {
conf.LogLevel = WARNING
}
validate := validator.New(validator.WithRequiredStructEnabled()) validate := validator.New(validator.WithRequiredStructEnabled())
err := validate.Struct(c) err := validate.Struct(conf)
return err return err
} }
// ParseMeshConfiguration: parses the mesh network configuration. Parses parameters such as
// keepalive time, role and so forth.
func ParseMeshConfiguration(filePath string) (*WgConfiguration, error) {
var conf WgConfiguration
yamlBytes, err := os.ReadFile(filePath)
if err != nil {
return nil, err
}
err = yaml.Unmarshal(yamlBytes, &conf)
if err != nil {
return nil, err
}
return &conf, ValidateMeshConfiguration(&conf)
}
// ParseDaemonConfiguration parses the mesh configuration and validates the configuration // ParseDaemonConfiguration parses the mesh configuration and validates the configuration
func ParseDaemonConfiguration(filePath string) (*DaemonConfiguration, error) { func ParseDaemonConfiguration(filePath string) (*DaemonConfiguration, error) {
var conf DaemonConfiguration var conf DaemonConfiguration

View File

@ -21,11 +21,12 @@ func getExampleConfiguration() *DaemonConfiguration {
Timeout: 5, Timeout: 5,
Profile: false, Profile: false,
StubWg: false, StubWg: false,
SyncRate: 2, SyncInterval: 2,
KeepAliveTime: 2, Heartbeat: 2,
ClusterSize: 64, ClusterSize: 64,
InterClusterChance: 0.15, InterClusterChance: 0.15,
BranchRate: 3, Branch: 3,
PullInterval: 0,
InfectionCount: 2, InfectionCount: 2,
BaseConfiguration: WgConfiguration{ BaseConfiguration: WgConfiguration{
IPDiscovery: &discovery, IPDiscovery: &discovery,
@ -153,7 +154,7 @@ func TestRoleTypeNotSpecified(t *testing.T) {
func TestBranchRateZero(t *testing.T) { func TestBranchRateZero(t *testing.T) {
conf := getExampleConfiguration() conf := getExampleConfiguration()
conf.BranchRate = 0 conf.Branch = 0
err := ValidateDaemonConfiguration(conf) err := ValidateDaemonConfiguration(conf)
@ -162,9 +163,9 @@ func TestBranchRateZero(t *testing.T) {
} }
} }
func TestSyncRateZero(t *testing.T) { func TestsyncTimeZero(t *testing.T) {
conf := getExampleConfiguration() conf := getExampleConfiguration()
conf.SyncRate = 0 conf.SyncInterval = 0
err := ValidateDaemonConfiguration(conf) err := ValidateDaemonConfiguration(conf)
@ -175,7 +176,7 @@ func TestSyncRateZero(t *testing.T) {
func TestKeepAliveTimeZero(t *testing.T) { func TestKeepAliveTimeZero(t *testing.T) {
conf := getExampleConfiguration() conf := getExampleConfiguration()
conf.KeepAliveTime = 0 conf.Heartbeat = 0
err := ValidateDaemonConfiguration(conf) err := ValidateDaemonConfiguration(conf)
if err == nil { if err == nil {
@ -215,6 +216,17 @@ func TestInfectionCountOne(t *testing.T) {
} }
} }
func TestPullTimeNegative(t *testing.T) {
conf := getExampleConfiguration()
conf.PullInterval = -1
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestValidConfiguration(t *testing.T) { func TestValidConfiguration(t *testing.T) {
conf := getExampleConfiguration() conf := getExampleConfiguration()
err := ValidateDaemonConfiguration(conf) err := ValidateDaemonConfiguration(conf)

View File

@ -7,25 +7,30 @@ import (
"slices" "slices"
) )
// ConnCluster splits nodes into clusters where nodes in a cluster communicate // ConnCluster: splits nodes into clusters where nodes in a cluster communicate
// frequently and nodes outside of a cluster communicate infrequently // frequently and nodes outside of a cluster communicate infrequently
type ConnCluster interface { type ConnCluster interface {
// Getneighbours: get neighbours of the cluster the node is in
GetNeighbours(global []string, selfId string) []string GetNeighbours(global []string, selfId string) []string
// GetInterCluster: get the cluster to communicate with
GetInterCluster(global []string, selfId string) string GetInterCluster(global []string, selfId string) string
} }
// ConnnClusterImpl: implementation of the connection cluster
type ConnClusterImpl struct { type ConnClusterImpl struct {
clusterSize int clusterSize int
} }
// perform binary search to attain a size of a group
func binarySearch(global []string, selfId string, groupSize int) (int, int) { func binarySearch(global []string, selfId string, groupSize int) (int, int) {
slices.Sort(global) slices.Sort(global)
lower := 0 lower := 0
higher := len(global) - 1 higher := len(global) - 1
mid := (lower + higher) / 2
for (higher+1)-lower > groupSize { for (higher+1)-lower > groupSize {
mid := (lower + higher) / 2
if global[mid] < selfId { if global[mid] < selfId {
lower = mid + 1 lower = mid + 1
} else if global[mid] > selfId { } else if global[mid] > selfId {
@ -33,14 +38,12 @@ func binarySearch(global []string, selfId string, groupSize int) (int, int) {
} else { } else {
break break
} }
mid = (lower + higher) / 2
} }
return lower, int(math.Min(float64(lower+groupSize), float64(len(global)))) return lower, int(math.Min(float64(lower+groupSize), float64(len(global))))
} }
// GetNeighbours return the neighbours 'nearest' to you. In this implementation the // GetNeighbours: return the neighbours 'nearest' to you. In this implementation the
// neighbours aren't actually the ones nearest to you but just the ones nearest // neighbours aren't actually the ones nearest to you but just the ones nearest
// to you alphabetically. Perform binary search to get the total group // to you alphabetically. Perform binary search to get the total group
func (i *ConnClusterImpl) GetNeighbours(global []string, selfId string) []string { func (i *ConnClusterImpl) GetNeighbours(global []string, selfId string) []string {
@ -51,7 +54,7 @@ func (i *ConnClusterImpl) GetNeighbours(global []string, selfId string) []string
return global[lower:higher] return global[lower:higher]
} }
// GetInterCluster get nodes not in your cluster. Every round there is a given chance // GetInterCluster: get nodes not in your cluster. Every round there is a given chance
// you will communicate with a random node that is not in your cluster. // you will communicate with a random node that is not in your cluster.
func (i *ConnClusterImpl) GetInterCluster(global []string, selfId string) string { func (i *ConnClusterImpl) GetInterCluster(global []string, selfId string) string {
// Doesn't matter if not in it. Get index of where the node 'should' be // Doesn't matter if not in it. Get index of where the node 'should' be
@ -66,6 +69,7 @@ func (i *ConnClusterImpl) GetInterCluster(global []string, selfId string) string
return global[neighbourIndex] return global[neighbourIndex]
} }
// NewConnCluster: instantiate a new connection cluster of a given group size.
func NewConnCluster(clusterSize int) (ConnCluster, error) { func NewConnCluster(clusterSize int) (ConnCluster, error) {
log2Cluster := math.Log2(float64(clusterSize)) log2Cluster := math.Log2(float64(clusterSize))

View File

@ -6,7 +6,7 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
) )
@ -18,6 +18,7 @@ type PeerConnection interface {
GetClient() (*grpc.ClientConn, error) GetClient() (*grpc.ClientConn, error)
} }
// PeerConenctionFactory: create a new connection to a peer
type PeerConnectionFactory = func(clientConfig *tls.Config, server string) (PeerConnection, error) type PeerConnectionFactory = func(clientConfig *tls.Config, server string) (PeerConnection, error)
// WgCtrlConnection implements PeerConnection. // WgCtrlConnection implements PeerConnection.

View File

@ -7,7 +7,7 @@ import (
"os" "os"
"sync" "sync"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
) )
// ConnectionManager defines an interface for maintaining peer connections // ConnectionManager defines an interface for maintaining peer connections
@ -19,9 +19,11 @@ type ConnectionManager interface {
// If the endpoint does not exist then add the connection. Returns an error // If the endpoint does not exist then add the connection. Returns an error
// if something went wrong // if something went wrong
GetConnection(endPoint string) (PeerConnection, error) GetConnection(endPoint string) (PeerConnection, error)
// HasConnections returns true if a client has already registered at the givne // HasConnections returns true if a peer has already registered at the given
// endpoint or false otherwise. // endpoint or false otherwise.
HasConnection(endPoint string) bool HasConnection(endPoint string) bool
// Removes a connection if it exists
RemoveConnection(endPoint string) error
// Goes through all the connections and closes eachone // Goes through all the connections and closes eachone
Close() error Close() error
} }
@ -32,7 +34,6 @@ type ConnectionManagerImpl struct {
// clientConnections maps an endpoint to a connection // clientConnections maps an endpoint to a connection
conLoc sync.RWMutex conLoc sync.RWMutex
clientConnections map[string]PeerConnection clientConnections map[string]PeerConnection
serverConfig *tls.Config
clientConfig *tls.Config clientConfig *tls.Config
connFactory PeerConnectionFactory connFactory PeerConnectionFactory
} }
@ -61,37 +62,25 @@ func NewConnectionManager(params *NewConnectionManagerParams) (ConnectionManager
return nil, err return nil, err
} }
serverAuth := tls.RequireAndVerifyClientCert
if params.SkipCertVerification {
serverAuth = tls.RequireAnyClientCert
}
certPool := x509.NewCertPool() certPool := x509.NewCertPool()
if !params.SkipCertVerification { if params.CaCert == "" {
return nil, errors.New("CA Cert is not specified")
if params.CaCert == "" {
return nil, errors.New("CA Cert is not specified")
}
caCert, err := os.ReadFile(params.CaCert)
if err != nil {
return nil, err
}
certPool.AppendCertsFromPEM(caCert)
} }
serverConfig := &tls.Config{ caCert, err := os.ReadFile(params.CaCert)
ClientAuth: serverAuth,
Certificates: []tls.Certificate{cert}, if err != nil {
return nil, err
}
if ok := certPool.AppendCertsFromPEM(caCert); !ok {
return nil, errors.New("could not parse PEM")
} }
clientConfig := &tls.Config{ clientConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: params.SkipCertVerification, InsecureSkipVerify: params.SkipCertVerification,
Certificates: []tls.Certificate{cert},
RootCAs: certPool, RootCAs: certPool,
} }
@ -99,7 +88,6 @@ func NewConnectionManager(params *NewConnectionManagerParams) (ConnectionManager
connMgr := ConnectionManagerImpl{ connMgr := ConnectionManagerImpl{
sync.RWMutex{}, sync.RWMutex{},
connections, connections,
serverConfig,
clientConfig, clientConfig,
params.ConnFactory, params.ConnFactory,
} }
@ -150,6 +138,15 @@ func (m *ConnectionManagerImpl) HasConnection(endPoint string) bool {
return exists return exists
} }
// RemoveConnection removes the given connection if it exists
func (m *ConnectionManagerImpl) RemoveConnection(endPoint string) error {
m.conLoc.Lock()
err := m.clientConnections[endPoint].Close()
delete(m.clientConnections, endPoint)
m.conLoc.Unlock()
return err
}
func (m *ConnectionManagerImpl) Close() error { func (m *ConnectionManagerImpl) Close() error {
for _, conn := range m.clientConnections { for _, conn := range m.clientConnections {
if err := conn.Close(); err != nil { if err := conn.Close(); err != nil {

View File

@ -53,13 +53,13 @@ func TestNewConnectionManagerCACertDoesNotExistAndVerify(t *testing.T) {
func TestNewConnectionManagerCACertDoesNotExistAndNotVerify(t *testing.T) { func TestNewConnectionManagerCACertDoesNotExistAndNotVerify(t *testing.T) {
params := getConnectionManagerParams() params := getConnectionManagerParams()
params.CaCert = "" params.CaCert = "./cert/sdjsdjsdjk.pem"
params.SkipCertVerification = true params.SkipCertVerification = true
_, err := NewConnectionManager(params) _, err := NewConnectionManager(params)
if err != nil { if err == nil {
t.Fatal(`an error should not be thrown`) t.Fatalf(`an error should be thrown`)
} }
} }

View File

@ -2,22 +2,23 @@ package conn
import ( import (
"crypto/tls" "crypto/tls"
"crypto/x509"
"errors"
"fmt" "fmt"
"net" "net"
"os"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/smegmesh/pkg/rpc"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
) )
// ConnectionServer manages gRPC server peer connections // ConnectionServer manages gRPC server peer connections
type ConnectionServer struct { type ConnectionServer struct {
// tlsConfiguration of the server
serverConfig *tls.Config
// server an instance of the grpc server // server an instance of the grpc server
server *grpc.Server // the authentication service to authenticate nodes server *grpc.Server
// the ctrl service to manage node // the ctrl service to manage node
ctrlProvider rpc.MeshCtrlServerServer ctrlProvider rpc.MeshCtrlServerServer
// the sync service to synchronise nodes // the sync service to synchronise nodes
@ -48,9 +49,26 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
serverAuth = tls.RequireAnyClientCert serverAuth = tls.RequireAnyClientCert
} }
certPool := x509.NewCertPool()
if params.Conf.CaCertificatePath == "" {
return nil, errors.New("CA Cert is not specified")
}
caCert, err := os.ReadFile(params.Conf.CaCertificatePath)
if err != nil {
return nil, err
}
if ok := certPool.AppendCertsFromPEM(caCert); !ok {
return nil, errors.New("could not parse PEM")
}
serverConfig := &tls.Config{ serverConfig := &tls.Config{
ClientAuth: serverAuth, ClientAuth: serverAuth,
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
ClientCAs: certPool,
} }
server := grpc.NewServer( server := grpc.NewServer(
@ -61,7 +79,6 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
syncProvider := params.SyncProvider syncProvider := params.SyncProvider
connServer := ConnectionServer{ connServer := ConnectionServer{
serverConfig: serverConfig,
server: server, server: server,
ctrlProvider: ctrlProvider, ctrlProvider: ctrlProvider,
syncProvider: syncProvider, syncProvider: syncProvider,
@ -74,7 +91,6 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
// Listen for incoming requests. Returns an error if something went wrong. // Listen for incoming requests. Returns an error if something went wrong.
func (s *ConnectionServer) Listen() error { func (s *ConnectionServer) Listen() error {
rpc.RegisterMeshCtrlServerServer(s.server, s.ctrlProvider) rpc.RegisterMeshCtrlServerServer(s.server, s.ctrlProvider)
rpc.RegisterSyncServiceServer(s.server, s.syncProvider) rpc.RegisterSyncServiceServer(s.server, s.syncProvider)
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", s.Conf.GrpcPort)) lis, err := net.Listen("tcp", fmt.Sprintf(":%d", s.Conf.GrpcPort))

View File

@ -16,6 +16,11 @@ func (s *ConnectionManagerStub) AddConnection(endPoint string) (PeerConnection,
return mock, nil return mock, nil
} }
func (s *ConnectionManagerStub) RemoveConnection(endPoint string) error {
delete(s.Endpoints, endPoint)
return nil
}
func (s *ConnectionManagerStub) GetConnection(endPoint string) (PeerConnection, error) { func (s *ConnectionManagerStub) GetConnection(endPoint string) (PeerConnection, error) {
endpoint, ok := s.Endpoints[endPoint] endpoint, ok := s.Endpoints[endPoint]

View File

@ -9,17 +9,20 @@ import (
"strings" "strings"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// Route: represents a route within the data store
type Route struct { type Route struct {
// Destination the route is advertising
Destination string Destination string
Path []string // Path to the destination
Path []string
} }
// GetDestination implements mesh.Route. // GetDestination implements mesh.Route.
@ -158,8 +161,8 @@ type TwoPhaseStoreMeshManager struct {
IfName string IfName string
Client *wgctrl.Client Client *wgctrl.Client
LastClock uint64 LastClock uint64
conf *conf.WgConfiguration Conf *conf.WgConfiguration
daemonConf *conf.DaemonConfiguration DaemonConf *conf.DaemonConfiguration
store *TwoPhaseMap[string, MeshNode] store *TwoPhaseMap[string, MeshNode]
} }
@ -204,7 +207,6 @@ func (m *TwoPhaseStoreMeshManager) Save() []byte {
var buf bytes.Buffer var buf bytes.Buffer
enc := gob.NewEncoder(&buf) enc := gob.NewEncoder(&buf)
err := enc.Encode(*snapshot) err := enc.Encode(*snapshot)
if err != nil { if err != nil {
@ -249,7 +251,8 @@ func (m *TwoPhaseStoreMeshManager) SaveChanges() {
m.LastClock = clockValue m.LastClock = clockValue
} }
// UpdateTimeStamp: update the timestamp of the given node // UpdateTimeStamp: update the timestamp of the given node, causes a configuration refresh if the node
// is the leader causing all nodes to update their vector clocks
func (m *TwoPhaseStoreMeshManager) UpdateTimeStamp(nodeId string) error { func (m *TwoPhaseStoreMeshManager) UpdateTimeStamp(nodeId string) error {
if !m.store.Contains(nodeId) { if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId) return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
@ -265,7 +268,7 @@ func (m *TwoPhaseStoreMeshManager) UpdateTimeStamp(nodeId string) error {
peerToUpdate := peers[0] peerToUpdate := peers[0]
if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.daemonConf.KeepAliveTime) { if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.DaemonConf.Heartbeat) {
m.store.Mark(peerToUpdate) m.store.Mark(peerToUpdate)
if len(peers) < 2 { if len(peers) < 2 {
@ -313,6 +316,8 @@ func (m *TwoPhaseStoreMeshManager) AddRoutes(nodeId string, routes ...mesh.Route
} }
} }
// Only add nodes on changes. Otherwise the node will advertise new
// information whenever they get new routes
if changes { if changes {
m.store.Put(nodeId, node) m.store.Put(nodeId, node)
} }
@ -320,7 +325,7 @@ func (m *TwoPhaseStoreMeshManager) AddRoutes(nodeId string, routes ...mesh.Route
return nil return nil
} }
// DeleteRoutes: deletes the routes from the node // RemoveRoute: deletes the routes from the given node
func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error { func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error {
if !m.store.Contains(nodeId) { if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId) return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
@ -336,6 +341,7 @@ func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Ro
for _, route := range routes { for _, route := range routes {
changes = true changes = true
logging.Log.WriteInfof("deleting: %s", route.GetDestination().String())
delete(node.Routes, route.GetDestination().String()) delete(node.Routes, route.GetDestination().String())
} }
@ -346,12 +352,12 @@ func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Ro
return nil return nil
} }
// GetSyncer: returns the automerge syncer for sync // GetSyncer: returns the bi-directionally synchroniser to merge documents
func (m *TwoPhaseStoreMeshManager) GetSyncer() mesh.MeshSyncer { func (m *TwoPhaseStoreMeshManager) GetSyncer() mesh.MeshSyncer {
return NewTwoPhaseSyncer(m) return NewTwoPhaseSyncer(m)
} }
// GetNode get a particular not within the mesh // GetNode: get a particular not within the mesh network
func (m *TwoPhaseStoreMeshManager) GetNode(nodeId string) (mesh.MeshNode, error) { func (m *TwoPhaseStoreMeshManager) GetNode(nodeId string) (mesh.MeshNode, error) {
if !m.store.Contains(nodeId) { if !m.store.Contains(nodeId) {
return nil, fmt.Errorf("datastore: %s does not exist in the mesh", nodeId) return nil, fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
@ -379,7 +385,7 @@ func (m *TwoPhaseStoreMeshManager) SetDescription(nodeId string, description str
return nil return nil
} }
// SetAlias: set the alias of the nodeId // SetAlias: set the alias of the given node
func (m *TwoPhaseStoreMeshManager) SetAlias(nodeId string, alias string) error { func (m *TwoPhaseStoreMeshManager) SetAlias(nodeId string, alias string) error {
if !m.store.Contains(nodeId) { if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId) return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
@ -392,7 +398,7 @@ func (m *TwoPhaseStoreMeshManager) SetAlias(nodeId string, alias string) error {
return nil return nil
} }
// AddService: adds the service to the given node // AddService: adds a service to the given node
func (m *TwoPhaseStoreMeshManager) AddService(nodeId string, key string, value string) error { func (m *TwoPhaseStoreMeshManager) AddService(nodeId string, key string, value string) error {
if !m.store.Contains(nodeId) { if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId) return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
@ -404,19 +410,25 @@ func (m *TwoPhaseStoreMeshManager) AddService(nodeId string, key string, value s
return nil return nil
} }
// RemoveService: removes the service form the node. throws an error if the service does not exist // RemoveService: removes the service form a node, throws an error if the service does not exist
func (m *TwoPhaseStoreMeshManager) RemoveService(nodeId string, key string) error { func (m *TwoPhaseStoreMeshManager) RemoveService(nodeId string, key string) error {
if !m.store.Contains(nodeId) { if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId) return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
} }
node := m.store.Get(nodeId) node := m.store.Get(nodeId)
if _, ok := node.Services[key]; !ok {
return fmt.Errorf("datastore: node does not contain service %s", key)
}
delete(node.Services, key) delete(node.Services, key)
m.store.Put(nodeId, node) m.store.Put(nodeId, node)
return nil return nil
} }
// Prune: prunes all nodes that have not updated their timestamp in // Prune: prunes all nodes that have not updated their vector clock in a given amount
// of time
func (m *TwoPhaseStoreMeshManager) Prune() error { func (m *TwoPhaseStoreMeshManager) Prune() error {
m.store.Prune() m.store.Prune()
return nil return nil
@ -445,6 +457,7 @@ func (m *TwoPhaseStoreMeshManager) GetPeers() []string {
}) })
} }
// getRoutes: get all routes the target node is advertising
func (m *TwoPhaseStoreMeshManager) getRoutes(targetNode string) (map[string]Route, error) { func (m *TwoPhaseStoreMeshManager) getRoutes(targetNode string) (map[string]Route, error) {
if !m.store.Contains(targetNode) { if !m.store.Contains(targetNode) {
return nil, fmt.Errorf("getRoute: cannot get route %s does not exist", targetNode) return nil, fmt.Errorf("getRoute: cannot get route %s does not exist", targetNode)
@ -454,7 +467,8 @@ func (m *TwoPhaseStoreMeshManager) getRoutes(targetNode string) (map[string]Rout
return node.Routes, nil return node.Routes, nil
} }
// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen // GetRoutes: Get all unique routes the target node is advertising.
// on conflicts the route with the least hop count is chosen
func (m *TwoPhaseStoreMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, error) { func (m *TwoPhaseStoreMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, error) {
node, err := m.GetNode(targetNode) node, err := m.GetNode(targetNode)
@ -498,7 +512,7 @@ func (m *TwoPhaseStoreMeshManager) GetRoutes(targetNode string) (map[string]mesh
return routes, nil return routes, nil
} }
// RemoveNode(): remove the node from the mesh // RemoveNode: remove the node from the mesh
func (m *TwoPhaseStoreMeshManager) RemoveNode(nodeId string) error { func (m *TwoPhaseStoreMeshManager) RemoveNode(nodeId string) error {
if !m.store.Contains(nodeId) { if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId) return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
@ -508,7 +522,8 @@ func (m *TwoPhaseStoreMeshManager) RemoveNode(nodeId string) error {
return nil return nil
} }
// GetConfiguration implements mesh.MeshProvider. // GetConfiguration gets the WireGuard configuration to use for this
// network
func (m *TwoPhaseStoreMeshManager) GetConfiguration() *conf.WgConfiguration { func (m *TwoPhaseStoreMeshManager) GetConfiguration() *conf.WgConfiguration {
return m.conf return m.Conf
} }

440
pkg/crdt/datastore_test.go Normal file
View File

@ -0,0 +1,440 @@
package crdt
import (
"net"
"slices"
"testing"
"time"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type TestParams struct {
manager mesh.MeshProvider
publicKey *wgtypes.Key
}
func setUpTests() *TestParams {
advertiseRoutes := false
advertiseDefaultRoute := false
role := conf.PEER_ROLE
discovery := conf.OUTGOING_IP_DISCOVERY
factory := &TwoPhaseMapFactory{
Config: &conf.DaemonConfiguration{
CertificatePath: "/somecertificatepath",
PrivateKeyPath: "/someprivatekeypath",
CaCertificatePath: "/somecacertificatepath",
SkipCertVerification: true,
GrpcPort: 0,
Timeout: 20,
Profile: false,
SyncInterval: 2,
Heartbeat: 10,
ClusterSize: 32,
InterClusterChance: 0.15,
Branch: 3,
InfectionCount: 3,
BaseConfiguration: conf.WgConfiguration{
IPDiscovery: &discovery,
AdvertiseRoutes: &advertiseRoutes,
AdvertiseDefaultRoute: &advertiseDefaultRoute,
Role: &role,
},
},
}
key, _ := wgtypes.GeneratePrivateKey()
mesh, _ := factory.CreateMesh(&mesh.MeshProviderFactoryParams{
DevName: "bob",
MeshId: "meshid123",
Client: nil,
Conf: &factory.Config.BaseConfiguration,
DaemonConf: factory.Config,
NodeID: "bob",
})
publicKey := key.PublicKey()
return &TestParams{
manager: mesh,
publicKey: &publicKey,
}
}
func getOurNode(testParams *TestParams) *MeshNode {
return &MeshNode{
HostEndpoint: "public-endpoint:8080",
WgEndpoint: "public-endpoint:21906",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d218/128",
PublicKey: testParams.publicKey.String(),
Timestamp: time.Now().Unix(),
Description: "A node that we are adding",
Type: "peer",
}
}
func getRandomNode() *MeshNode {
key, _ := wgtypes.GeneratePrivateKey()
publicKey := key.PublicKey()
return &MeshNode{
HostEndpoint: "public-endpoint:8081",
WgEndpoint: "public-endpoint:21907",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d234/128",
PublicKey: publicKey.String(),
Timestamp: time.Now().Unix(),
Description: "A node that we are adding",
Type: "peer",
}
}
func TestAddNodeAddsTheNodesToTheStore(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
if !testParams.manager.NodeExists(testParams.publicKey.String()) {
t.Fatalf(`node %s should have been added to the mesh network`, testParams.publicKey.String())
}
}
func TestAddNodeNodeAlreadyExistsReplacesTheNode(t *testing.T) {
TestAddNodeAddsTheNodesToTheStore(t)
TestAddNodeAddsTheNodesToTheStore(t)
}
func TestSaveThenLoad(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
bytes := testParams.manager.Save()
if err := testParams.manager.Load(bytes); err != nil {
t.Fatalf(`error caused by loading datastore: %s`, err.Error())
}
}
func TestHasChangesReturnsTrueWhenThereAreChangesInTheMesh(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
if !testParams.manager.HasChanges() {
t.Fatalf(`mesh has change but HasChanges returned false`)
}
testParams.manager.SetDescription(testParams.publicKey.String(), "Bob marley")
if !testParams.manager.HasChanges() {
t.Fatalf(`mesh has change but HasChanges returned false`)
}
testParams.manager.SaveChanges()
}
func TestHasChangesWhenThereAreNoChangesInTheMeshReturnsFalse(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.SaveChanges()
if testParams.manager.HasChanges() {
t.Fatalf(`mesh has no changes but HasChanges was true`)
}
testParams.manager.SetDescription(testParams.publicKey.String(), "Bob marley")
testParams.manager.SaveChanges()
if testParams.manager.HasChanges() {
t.Fatalf(`mesh has no changes but HasChanges was true`)
}
}
func TestUpdateTimeStampUpdatesTheTimeStampOfTheGivenNodeIfItIsTheLeader(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
before, _ := testParams.manager.GetNode(testParams.publicKey.String())
time.Sleep(1 * time.Second)
testParams.manager.UpdateTimeStamp(testParams.publicKey.String())
after, _ := testParams.manager.GetNode(testParams.publicKey.String())
if before.GetTimeStamp() >= after.GetTimeStamp() {
t.Fatalf(`before should not be after after`)
}
}
func TestUpdateTimeStampUpdatesTheTimeStampOfTheGivenNodeIfItIsNotLeader(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
newNode := getRandomNode()
newNode.PublicKey = "aaaaaaaaaa"
testParams.manager.AddNode(newNode)
before, _ := testParams.manager.GetNode(testParams.publicKey.String())
time.Sleep(1 * time.Second)
after, _ := testParams.manager.GetNode(testParams.publicKey.String())
if before.GetTimeStamp() != after.GetTimeStamp() {
t.Fatalf(`before and after should be the same`)
}
}
func TestAddRoutesAddsARouteToTheGivenMesh(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
_, destination, _ := net.ParseCIDR("0353:1da7:7f33:acc0:7a3f:6e55:912b:bc1f/64")
testParams.manager.AddRoutes(testParams.publicKey.String(), &mesh.RouteStub{
Destination: destination,
Path: make([]string, 0),
})
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
containsDestination := lib.Contains(node.GetRoutes(), func(r mesh.Route) bool {
return r.GetDestination().Contains(destination.IP)
})
if !containsDestination {
t.Fatalf(`route has not been added to the node`)
}
}
func TestRemoveRoutesWithdrawsRoutesFromTheMesh(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
_, destination, _ := net.ParseCIDR("0353:1da7:7f33:acc0:7a3f:6e55:912b:bc1f/64")
route := &mesh.RouteStub{
Destination: destination,
Path: make([]string, 0),
}
testParams.manager.AddRoutes(testParams.publicKey.String(), route)
testParams.manager.RemoveRoutes(testParams.publicKey.String(), route)
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
containsDestination := lib.Contains(node.GetRoutes(), func(r mesh.Route) bool {
return r.GetDestination().Contains(destination.IP)
})
if containsDestination {
t.Fatalf(`route has not been removed from the node`)
}
}
func TestGetNodeGetsTheNodeWhenItExists(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
if node == nil {
t.Fatalf(`node not found returned nil`)
}
}
func TestGetNodeReturnsNilWhenItDoesNotExist(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.RemoveNode(testParams.publicKey.String())
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
if node != nil {
t.Fatalf(`node found but should be nil`)
}
}
func TestNodeExistsReturnsFalseWhenNotExists(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.RemoveNode(testParams.publicKey.String())
if testParams.manager.NodeExists(testParams.publicKey.String()) {
t.Fatalf(`nodeexists should be false`)
}
}
func TestSetDescriptionReturnsErrorWhenNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.SetDescription("djdjdj", "djdsjkd")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestSetDescriptionSetsTheDescription(t *testing.T) {
testParams := setUpTests()
descriptionToSet := "djdsjkd"
testParams.manager.AddNode(getOurNode(testParams))
err := testParams.manager.SetDescription(testParams.publicKey.String(), descriptionToSet)
if err != nil {
t.Fatalf(`error %s thrown`, err.Error())
}
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
description := node.GetDescription()
if description != descriptionToSet {
t.Fatalf(`description was %s should be %s`, description, descriptionToSet)
}
}
func TestAliasNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.SetAlias("djdjdj", "djdsjkd")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestSetAliasSetsAlias(t *testing.T) {
testParams := setUpTests()
aliasToSet := "djdsjkd"
testParams.manager.AddNode(getOurNode(testParams))
err := testParams.manager.SetAlias(testParams.publicKey.String(), aliasToSet)
if err != nil {
t.Fatalf(`error %s thrown`, err.Error())
}
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
alias := node.GetAlias()
if alias != aliasToSet {
t.Fatalf(`description was %s should be %s`, alias, aliasToSet)
}
}
func TestAddServiceNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.AddService("djdjdj", "djdsjkd", "sddsds")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestAddServiceNodeExists(t *testing.T) {
testParams := setUpTests()
service := "djdsjkd"
serviceValue := "dsdsds"
testParams.manager.AddNode(getOurNode(testParams))
err := testParams.manager.AddService(testParams.publicKey.String(), service, serviceValue)
if err != nil {
t.Fatalf(`error %s thrown`, err.Error())
}
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
services := node.GetServices()
if value, ok := services[service]; !ok || value != serviceValue {
t.Fatalf(`service not added to the data store`)
}
}
func TestRemoveServiceDoesNotExists(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.RemoveService("djdjdj", "dsdssd")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestRemoveServiceServiceDoesNotExist(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
if err := testParams.manager.RemoveService(testParams.publicKey.String(), "dhsdh"); err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestGetPeersReturnsAllPeersInTheMesh(t *testing.T) {
testParams := setUpTests()
peer1 := getRandomNode()
peer2 := getRandomNode()
client := getRandomNode()
client.Type = "client"
testParams.manager.AddNode(peer1)
testParams.manager.AddNode(peer2)
testParams.manager.AddNode(client)
peers := testParams.manager.GetPeers()
slices.Sort(peers)
if len(peers) != 2 {
t.Fatalf(`there should be two peers in the mesh`)
}
peer1Pub, _ := peer1.GetPublicKey()
if !slices.Contains(peers, peer1Pub.String()) {
t.Fatalf(`peer1 not in the list`)
}
peer2Pub, _ := peer2.GetPublicKey()
if !slices.Contains(peers, peer2Pub.String()) {
t.Fatalf(`peer2 not in the list`)
}
}
func TestRemoveNodeReturnsErrorIfNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.RemoveNode("dsjdssjk")
if err == nil {
t.Fatalf(`error should have returned`)
}
}

View File

@ -4,34 +4,39 @@ import (
"fmt" "fmt"
"hash/fnv" "hash/fnv"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
) )
// TwoPhaseMapFactory: instantiate a new twophasemap
// datastore
type TwoPhaseMapFactory struct { type TwoPhaseMapFactory struct {
Config *conf.DaemonConfiguration Config *conf.DaemonConfiguration
} }
// CreateMesh: create a new mesh network
func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) { func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) {
return &TwoPhaseStoreMeshManager{ return &TwoPhaseStoreMeshManager{
MeshId: params.MeshId, MeshId: params.MeshId,
IfName: params.DevName, IfName: params.DevName,
Client: params.Client, Client: params.Client,
conf: params.Conf, Conf: params.Conf,
daemonConf: params.DaemonConf, DaemonConf: params.DaemonConf,
store: NewTwoPhaseMap[string, MeshNode](params.NodeID, func(s string) uint64 { store: NewTwoPhaseMap[string, MeshNode](params.NodeID, func(s string) uint64 {
h := fnv.New64a() h := fnv.New64a()
h.Write([]byte(s)) h.Write([]byte(s))
return h.Sum64() return h.Sum64()
}, uint64(3*f.Config.KeepAliveTime)), }, uint64(3*f.Config.Heartbeat)),
}, nil }, nil
} }
// MeshNodeFactory: create a new node in the mesh network
type MeshNodeFactory struct { type MeshNodeFactory struct {
Config conf.DaemonConfiguration Config conf.DaemonConfiguration
} }
// Build: build a new mesh network
func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode { func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode {
hostName := f.getAddress(params) hostName := f.getAddress(params)
@ -66,7 +71,7 @@ func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string
} else { } else {
ipFunc := lib.GetPublicIP ipFunc := lib.GetPublicIP
if *params.MeshConfig.IPDiscovery == conf.DNS_IP_DISCOVERY { if *params.MeshConfig.IPDiscovery == conf.OUTGOING_IP_DISCOVERY {
ipFunc = lib.GetOutboundIP ipFunc = lib.GetOutboundIP
} }

View File

@ -1,4 +1,4 @@
// crdt is a golang implementation of a crdt // crdt provides go implementations for crdts
package crdt package crdt
import ( import (
@ -6,6 +6,7 @@ import (
"sync" "sync"
) )
// Bucket: bucket represents a value in the grow only map
type Bucket[D any] struct { type Bucket[D any] struct {
Vector uint64 Vector uint64
Contents D Contents D
@ -19,6 +20,7 @@ type GMap[K cmp.Ordered, D any] struct {
clock *VectorClock[K] clock *VectorClock[K]
} }
// Put: put a new entry in the grow-only-map
func (g *GMap[K, D]) Put(key K, value D) { func (g *GMap[K, D]) Put(key K, value D) {
g.lock.Lock() g.lock.Lock()
@ -32,6 +34,8 @@ func (g *GMap[K, D]) Put(key K, value D) {
g.lock.Unlock() g.lock.Unlock()
} }
// Contains: returns whether or not the key is contained
// in the g-map
func (g *GMap[K, D]) Contains(key K) bool { func (g *GMap[K, D]) Contains(key K) bool {
return g.contains(g.clock.hashFunc(key)) return g.contains(g.clock.hashFunc(key))
} }
@ -64,11 +68,23 @@ func (g *GMap[K, D]) get(key uint64) Bucket[D] {
return bucket return bucket
} }
// Get: get the value associated with the given key
func (g *GMap[K, D]) Get(key K) D { func (g *GMap[K, D]) Get(key K) D {
if !g.Contains(key) {
var def D
return def
}
return g.get(g.clock.hashFunc(key)).Contents return g.get(g.clock.hashFunc(key)).Contents
} }
// Mark: marks the node, this means the status of the node
// is an undefined state
func (g *GMap[K, D]) Mark(key K) { func (g *GMap[K, D]) Mark(key K) {
if !g.Contains(key) {
return
}
g.lock.Lock() g.lock.Lock()
bucket := g.contents[g.clock.hashFunc(key)] bucket := g.contents[g.clock.hashFunc(key)]
bucket.Gravestone = true bucket.Gravestone = true
@ -76,7 +92,7 @@ func (g *GMap[K, D]) Mark(key K) {
g.lock.Unlock() g.lock.Unlock()
} }
// IsMarked: returns true if the node is marked // IsMarked: returns true if the node is marked (in an undefined state)
func (g *GMap[K, D]) IsMarked(key K) bool { func (g *GMap[K, D]) IsMarked(key K) bool {
marked := false marked := false
@ -89,10 +105,10 @@ func (g *GMap[K, D]) IsMarked(key K) bool {
} }
g.lock.RUnlock() g.lock.RUnlock()
return marked return marked
} }
// Keys: return all the keys in the grow-only map
func (g *GMap[K, D]) Keys() []uint64 { func (g *GMap[K, D]) Keys() []uint64 {
g.lock.RLock() g.lock.RLock()
@ -108,6 +124,7 @@ func (g *GMap[K, D]) Keys() []uint64 {
return contents return contents
} }
// Save: saves the grow only map
func (g *GMap[K, D]) Save() map[uint64]Bucket[D] { func (g *GMap[K, D]) Save() map[uint64]Bucket[D] {
buckets := make(map[uint64]Bucket[D]) buckets := make(map[uint64]Bucket[D])
g.lock.RLock() g.lock.RLock()
@ -120,6 +137,7 @@ func (g *GMap[K, D]) Save() map[uint64]Bucket[D] {
return buckets return buckets
} }
// SaveWithKeys: get all the values corresponding with the provided keys
func (g *GMap[K, D]) SaveWithKeys(keys []uint64) map[uint64]Bucket[D] { func (g *GMap[K, D]) SaveWithKeys(keys []uint64) map[uint64]Bucket[D] {
buckets := make(map[uint64]Bucket[D]) buckets := make(map[uint64]Bucket[D])
g.lock.RLock() g.lock.RLock()
@ -132,6 +150,7 @@ func (g *GMap[K, D]) SaveWithKeys(keys []uint64) map[uint64]Bucket[D] {
return buckets return buckets
} }
// GetClock: get all the vector clocks in the g_map
func (g *GMap[K, D]) GetClock() map[uint64]uint64 { func (g *GMap[K, D]) GetClock() map[uint64]uint64 {
clock := make(map[uint64]uint64) clock := make(map[uint64]uint64)
g.lock.RLock() g.lock.RLock()
@ -144,6 +163,7 @@ func (g *GMap[K, D]) GetClock() map[uint64]uint64 {
return clock return clock
} }
// GetHash: get the hash of the g_map representing its state
func (g *GMap[K, D]) GetHash() uint64 { func (g *GMap[K, D]) GetHash() uint64 {
hash := uint64(0) hash := uint64(0)
@ -157,6 +177,7 @@ func (g *GMap[K, D]) GetHash() uint64 {
return hash return hash
} }
// Prune: prune all stale entries
func (g *GMap[K, D]) Prune() { func (g *GMap[K, D]) Prune() {
stale := g.clock.getStale() stale := g.clock.getStale()
g.lock.Lock() g.lock.Lock()

224
pkg/crdt/g_map_test.go Normal file
View File

@ -0,0 +1,224 @@
// crdt_test unit tests the crdt implementations
package crdt
import (
"hash/fnv"
"slices"
"testing"
"time"
"github.com/tim-beatham/smegmesh/pkg/lib"
)
func NewGmap() *GMap[string, bool] {
vectorClock := NewVectorClock("a", func(key string) uint64 {
hash := fnv.New64a()
hash.Write([]byte(key))
return hash.Sum64()
}, 1) // 1 second stale time
gMap := NewGMap[string, bool](vectorClock)
return gMap
}
func TestGMapPutInsertsItems(t *testing.T) {
gMap := NewGmap()
gMap.Put("bruh1234", true)
if !gMap.Contains("bruh1234") {
t.Fatalf(`value not added to map`)
}
}
func TestGMapPutReplacesItems(t *testing.T) {
gMap := NewGmap()
gMap.Put("bruh1234", true)
gMap.Put("bruh1234", false)
value := gMap.Get("bruh1234")
if value {
t.Fatalf(`value should ahve been replaced to false`)
}
}
func TestContainsValueNotPresent(t *testing.T) {
gMap := NewGmap()
if gMap.Contains("sdhjsdhsdj") {
t.Fatalf(`value should not be present in the map`)
}
}
func TestContainsValuePresent(t *testing.T) {
gMap := NewGmap()
key := "hehehehe"
gMap.Put(key, false)
if !gMap.Contains(key) {
t.Fatalf(`%s should not be present in the map`, key)
}
}
func TestGMapGetNotPresentReturnsError(t *testing.T) {
gMap := NewGmap()
value := gMap.Get("bruh123")
if value != false {
t.Fatalf(`value should be default type false`)
}
}
func TestGMapGetReturnsValue(t *testing.T) {
gMap := NewGmap()
gMap.Put("bobdylan", true)
value := gMap.Get("bobdylan")
if !value {
t.Fatalf("value should be true but was false")
}
}
func TestMarkMarksTheValue(t *testing.T) {
gMap := NewGmap()
gMap.Put("hello123", true)
gMap.Mark("hello123")
if !gMap.IsMarked("hello123") {
t.Fatal(`hello123 should be marked`)
}
}
func TestMarkValueNotPresent(t *testing.T) {
gMap := NewGmap()
gMap.Mark("ok123456")
}
func TestKeysMapEmpty(t *testing.T) {
gMap := NewGmap()
keys := gMap.Keys()
if len(keys) != 0 {
t.Fatal(`list of keys was not empty but should be empty`)
}
}
func TestKeysMapReturnsKeysInMap(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
keys := gMap.Keys()
if len(keys) != 3 {
t.Fatal(`key length should be 3`)
}
}
func TestSaveMapEmptyReturnsEmptyMap(t *testing.T) {
gMap := NewGmap()
saveMap := gMap.Save()
if len(saveMap) != 0 {
t.Fatal(`saves should be empty`)
}
}
func TestSaveMapReturnsMapOfBuckets(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
saveMap := gMap.Save()
if len(saveMap) != 3 {
t.Fatalf(`save length should be 3`)
}
}
func TestSaveWithKeysNoKeysReturnsEmptyBucket(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
saveMap := gMap.SaveWithKeys([]uint64{})
if len(saveMap) != 0 {
t.Fatalf(`save map should be empty`)
}
}
func TestSaveWithKeysReturnsIntersection(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
clock := lib.MapKeys(gMap.GetClock())
clock = clock[:len(clock)-1]
values := gMap.SaveWithKeys(clock)
if len(values) != len(clock) {
t.Fatalf(`intersection not returned`)
}
}
func TestGetClockMapEmptyReturnsEmptyClock(t *testing.T) {
gMap := NewGmap()
clocks := gMap.GetClock()
if len(clocks) != 0 {
t.Fatalf(`vector clock is not empty`)
}
}
func TestGetClockReturnsAllCLocks(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
clocks := lib.MapValues(gMap.GetClock())
slices.Sort(clocks)
if !slices.Equal([]uint64{0, 1, 2}, clocks) {
t.Fatalf(`clocks are invalid`)
}
}
func TestGetHashChangesHashOnValueAdded(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
prevHash := gMap.GetHash()
gMap.Put("b", true)
if prevHash == gMap.GetHash() {
t.Fatalf(`hash should be different`)
}
}
func TestPruneGarbageCollectsValuesThatHaveNotBeenUpdated(t *testing.T) {
gMap := NewGmap()
gMap.clock.Put("c", 12)
gMap.Put("c", false)
gMap.Put("a", false)
time.Sleep(4 * time.Second)
gMap.Put("a", true)
gMap.Prune()
if gMap.Contains("c") {
t.Fatalf(`a should have been pruned`)
}
}

View File

@ -3,9 +3,10 @@ package crdt
import ( import (
"cmp" "cmp"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
) )
// TwoPhaseMap: comprises of two grow-only maps
type TwoPhaseMap[K cmp.Ordered, D any] struct { type TwoPhaseMap[K cmp.Ordered, D any] struct {
addMap *GMap[K, D] addMap *GMap[K, D]
removeMap *GMap[K, bool] removeMap *GMap[K, bool]
@ -23,7 +24,7 @@ func (m *TwoPhaseMap[K, D]) Contains(key K) bool {
return m.contains(m.Clock.hashFunc(key)) return m.contains(m.Clock.hashFunc(key))
} }
// Contains checks whether the value exists in the map // contains: checks whether the key exists in the map
func (m *TwoPhaseMap[K, D]) contains(key uint64) bool { func (m *TwoPhaseMap[K, D]) contains(key uint64) bool {
if !m.addMap.contains(key) { if !m.addMap.contains(key) {
return false return false
@ -40,6 +41,7 @@ func (m *TwoPhaseMap[K, D]) contains(key uint64) bool {
return addValue.Vector >= removeValue.Vector return addValue.Vector >= removeValue.Vector
} }
// Get: get the value corresponding with the given key
func (m *TwoPhaseMap[K, D]) Get(key K) D { func (m *TwoPhaseMap[K, D]) Get(key K) D {
var result D var result D
@ -60,18 +62,19 @@ func (m *TwoPhaseMap[K, D]) get(key uint64) D {
return m.addMap.get(key).Contents return m.addMap.get(key).Contents
} }
// Put places the key K in the map // Put: places the key K in the map with the associated data D
func (m *TwoPhaseMap[K, D]) Put(key K, data D) { func (m *TwoPhaseMap[K, D]) Put(key K, data D) {
msgSequence := m.Clock.IncrementClock() msgSequence := m.Clock.IncrementClock()
m.Clock.Put(key, msgSequence) m.Clock.Put(key, msgSequence)
m.addMap.Put(key, data) m.addMap.Put(key, data)
} }
// Mark: marks the status of the node as undetermiend
func (m *TwoPhaseMap[K, D]) Mark(key K) { func (m *TwoPhaseMap[K, D]) Mark(key K) {
m.addMap.Mark(key) m.addMap.Mark(key)
} }
// Remove removes the value from the map // Remove: removes the value from the map
func (m *TwoPhaseMap[K, D]) Remove(key K) { func (m *TwoPhaseMap[K, D]) Remove(key K) {
m.removeMap.Put(key, true) m.removeMap.Put(key, true)
} }
@ -92,6 +95,7 @@ func (m *TwoPhaseMap[K, D]) keys() []uint64 {
return keys return keys
} }
// AsList: convert the map to a list
func (m *TwoPhaseMap[K, D]) AsList() []D { func (m *TwoPhaseMap[K, D]) AsList() []D {
theList := make([]D, 0) theList := make([]D, 0)
@ -104,6 +108,8 @@ func (m *TwoPhaseMap[K, D]) AsList() []D {
return theList return theList
} }
// Snapshot: convert the map into an immutable snapshot.
// contains the contents of the add and remove map
func (m *TwoPhaseMap[K, D]) Snapshot() *TwoPhaseMapSnapshot[K, D] { func (m *TwoPhaseMap[K, D]) Snapshot() *TwoPhaseMapSnapshot[K, D] {
return &TwoPhaseMapSnapshot[K, D]{ return &TwoPhaseMapSnapshot[K, D]{
Add: m.addMap.Save(), Add: m.addMap.Save(),
@ -111,6 +117,8 @@ func (m *TwoPhaseMap[K, D]) Snapshot() *TwoPhaseMapSnapshot[K, D] {
} }
} }
// SnapshotFromState: create a snapshot of the intersection of values provided
// in the given state
func (m *TwoPhaseMap[K, D]) SnapShotFromState(state *TwoPhaseMapState[K]) *TwoPhaseMapSnapshot[K, D] { func (m *TwoPhaseMap[K, D]) SnapShotFromState(state *TwoPhaseMapState[K]) *TwoPhaseMapSnapshot[K, D] {
addKeys := lib.MapKeys(state.AddContents) addKeys := lib.MapKeys(state.AddContents)
removeKeys := lib.MapKeys(state.RemoveContents) removeKeys := lib.MapKeys(state.RemoveContents)
@ -121,12 +129,18 @@ func (m *TwoPhaseMap[K, D]) SnapShotFromState(state *TwoPhaseMapState[K]) *TwoPh
} }
} }
// TwoPhaseMapState: encapsulates the state of the map
// without specifying the data that is stored
type TwoPhaseMapState[K cmp.Ordered] struct { type TwoPhaseMapState[K cmp.Ordered] struct {
// Vectors: the vector ID of each process
Vectors map[uint64]uint64 Vectors map[uint64]uint64
// AddContents: the contents of the add map
AddContents map[uint64]uint64 AddContents map[uint64]uint64
// RemoveContents: the contents of the remove map
RemoveContents map[uint64]uint64 RemoveContents map[uint64]uint64
} }
// IsMarked: returns true if the given value is marked in an undetermined state
func (m *TwoPhaseMap[K, D]) IsMarked(key K) bool { func (m *TwoPhaseMap[K, D]) IsMarked(key K) bool {
return m.addMap.IsMarked(key) return m.addMap.IsMarked(key)
} }
@ -151,6 +165,8 @@ func (m *TwoPhaseMap[K, D]) GenerateMessage() *TwoPhaseMapState[K] {
} }
} }
// Difference: compute the set difference between the two states.
// highestStale represents the highest vector clock that has been marked as stale
func (m *TwoPhaseMapState[K]) Difference(highestStale uint64, state *TwoPhaseMapState[K]) *TwoPhaseMapState[K] { func (m *TwoPhaseMapState[K]) Difference(highestStale uint64, state *TwoPhaseMapState[K]) *TwoPhaseMapState[K] {
mapState := &TwoPhaseMapState[K]{ mapState := &TwoPhaseMapState[K]{
AddContents: make(map[uint64]uint64), AddContents: make(map[uint64]uint64),
@ -176,6 +192,7 @@ func (m *TwoPhaseMapState[K]) Difference(highestStale uint64, state *TwoPhaseMap
return mapState return mapState
} }
// Merge: merge a snapshot into the map
func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) { func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) {
for key, value := range snapshot.Add { for key, value := range snapshot.Add {
// Gravestone is local only to that node. // Gravestone is local only to that node.
@ -190,6 +207,7 @@ func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) {
} }
} }
// Prune: garbage collect all stale entries in the map
func (m *TwoPhaseMap[K, D]) Prune() { func (m *TwoPhaseMap[K, D]) Prune() {
m.addMap.Prune() m.addMap.Prune()
m.removeMap.Prune() m.removeMap.Prune()

View File

@ -4,7 +4,7 @@ import (
"bytes" "bytes"
"encoding/gob" "encoding/gob"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
) )
type SyncState int type SyncState int

View File

@ -0,0 +1,214 @@
package crdt
import (
"hash/fnv"
"slices"
"testing"
)
func NewMap(processId string) *TwoPhaseMap[string, string] {
theMap := NewTwoPhaseMap[string, string](processId, func(key string) uint64 {
hash := fnv.New64a()
hash.Write([]byte(key))
return hash.Sum64()
}, 1)
return theMap
}
func TestTwoPhaseMapEmpty(t *testing.T) {
theMap := NewMap("a")
if theMap.Contains("a") {
t.Fatalf(`a should not be present in the map`)
}
}
func TestTwoPhaseMapValuePresent(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "")
if !theMap.Contains("a") {
t.Fatalf(`should be present within the map`)
}
}
func TestTwoPhaseMapValueNotPresent(t *testing.T) {
theMap := NewMap("a")
theMap.Put("b", "")
if theMap.Contains("a") {
t.Fatalf(`a should not be present in the map`)
}
}
func TestTwoPhaseMapPutThenRemove(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "")
theMap.Remove("a")
if theMap.Contains("a") {
t.Fatalf(`a should not be present within the map`)
}
}
func TestTwoPhaseMapPutThenRemoveThenPut(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "")
theMap.Remove("a")
theMap.Put("a", "")
if !theMap.Contains("a") {
t.Fatalf(`a should be present within the map`)
}
}
func TestMarkMarksTheValueIn2PMap(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "")
theMap.Mark("a")
if !theMap.IsMarked("a") {
t.Fatalf(`a should be marked`)
}
}
func TestAsListReturnsItemsInList(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "bob")
theMap.Put("b", "dylan")
keys := theMap.AsList()
slices.Sort(keys)
if !slices.Equal([]string{"bob", "dylan"}, keys) {
t.Fatalf(`values should be bob, dylan`)
}
}
func TestSnapShotRemoveMapEmpty(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "bob")
theMap.Put("b", "dylan")
snapshot := theMap.Snapshot()
if len(snapshot.Add) != 2 {
t.Fatalf(`add values length should be 2`)
}
if len(snapshot.Remove) != 0 {
t.Fatalf(`remove map length should be 0`)
}
}
func TestSnapshotMapEmpty(t *testing.T) {
theMap := NewMap("a")
snapshot := theMap.Snapshot()
if len(snapshot.Add) != 0 || len(snapshot.Remove) != 0 {
t.Fatalf(`snapshot length should be 0`)
}
}
func TestSnapShotFromStateReturnsIntersection(t *testing.T) {
map1 := NewMap("a")
map1.Put("a", "heyy")
map2 := NewMap("b")
map2.Put("b", "hmmm")
message := map2.GenerateMessage()
snapShot := map1.SnapShotFromState(message)
if len(snapShot.Add) != 1 {
t.Fatalf(`add length should be 1`)
}
if len(snapShot.Remove) != 0 {
t.Fatalf(`remove length should be 0`)
}
}
func TestGetHashDifferentOnChange(t *testing.T) {
theMap := NewMap("a")
prevHash := theMap.GetHash()
theMap.Put("b", "hmmhmhmh")
if prevHash == theMap.GetHash() {
t.Fatalf(`hashes should not be the same`)
}
}
func TestGenerateMessageReturnsClocks(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "hmm")
theMap.Put("b", "hmm")
theMap.Remove("a")
message := theMap.GenerateMessage()
if len(message.AddContents) != 2 {
t.Fatalf(`two items added add should be 2`)
}
if len(message.RemoveContents) != 1 {
t.Fatalf(`a was removed remove map should be length 1`)
}
}
func TestDifferenceReturnsDifferenceOfMaps(t *testing.T) {
map1 := NewMap("a")
map1.Put("a", "ssms")
map1.Put("b", "sdmdsmd")
map2 := NewMap("b")
map2.Put("d", "eek")
map2.Put("c", "meh")
message1 := map1.GenerateMessage()
message2 := map2.GenerateMessage()
difference := message1.Difference(0, message2)
if len(difference.AddContents) != 2 {
t.Fatalf(`d and c are not in map1 they should be in add contents`)
}
if len(difference.RemoveContents) != 0 {
t.Fatalf(`remove should be empty`)
}
}
func TestMergeMergesValuesThatAreGreaterThanCurrentClock(t *testing.T) {
map1 := NewMap("a")
map1.Put("a", "ssms")
map1.Put("b", "sdmdsmd")
map2 := NewMap("b")
map2.Put("d", "eek")
map2.Put("c", "meh")
message1 := map1.GenerateMessage()
message2 := map2.GenerateMessage()
difference := message1.Difference(0, message2)
state := map2.SnapShotFromState(difference)
map1.Merge(*state)
if !map1.Contains("d") {
t.Fatalf(`d should be in the map`)
}
if !map2.Contains("c") {
t.Fatalf(`c should be in the map`)
}
}

View File

@ -5,9 +5,12 @@ import (
"sync" "sync"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
) )
// VectorBucket: represents a vector clock in the bucket
// recording both the time changes were last seen
// and when the lastUpdate epoch was recorded
type VectorBucket struct { type VectorBucket struct {
// clock current value of the node's clock // clock current value of the node's clock
clock uint64 clock uint64
@ -15,8 +18,9 @@ type VectorBucket struct {
lastUpdate uint64 lastUpdate uint64
} }
// Vector clock defines an abstract data type // VectorClock: defines an abstract data type
// for a vector clock implementation // for a vector clock implementation. Including a mechanism to
// garbage collect stale entries
type VectorClock[K cmp.Ordered] struct { type VectorClock[K cmp.Ordered] struct {
vectors map[uint64]*VectorBucket vectors map[uint64]*VectorBucket
lock sync.RWMutex lock sync.RWMutex
@ -62,6 +66,7 @@ func (m *VectorClock[K]) GetHash() uint64 {
return hash return hash
} }
// Merge: merge two clocks together
func (m *VectorClock[K]) Merge(vectors map[uint64]uint64) { func (m *VectorClock[K]) Merge(vectors map[uint64]uint64) {
for key, value := range vectors { for key, value := range vectors {
m.put(key, value) m.put(key, value)
@ -97,6 +102,7 @@ func (m *VectorClock[K]) GetStaleCount() uint64 {
return staleCount return staleCount
} }
// Prune: prunes all stale entries in the vector clock
func (m *VectorClock[K]) Prune() { func (m *VectorClock[K]) Prune() {
stale := m.getStale() stale := m.getStale()
@ -109,6 +115,8 @@ func (m *VectorClock[K]) Prune() {
m.lock.Unlock() m.lock.Unlock()
} }
// GetTimeStamp: get the last time the node was updated in UNIX
// epoch time
func (m *VectorClock[K]) GetTimestamp(processId K) uint64 { func (m *VectorClock[K]) GetTimestamp(processId K) uint64 {
m.lock.RLock() m.lock.RLock()
@ -118,6 +126,8 @@ func (m *VectorClock[K]) GetTimestamp(processId K) uint64 {
return lastUpdate return lastUpdate
} }
// Put: places the key with vector clock in the clock of the given
// process
func (m *VectorClock[K]) Put(key K, value uint64) { func (m *VectorClock[K]) Put(key K, value uint64) {
m.put(m.hashFunc(key), value) m.put(m.hashFunc(key), value)
} }
@ -133,7 +143,8 @@ func (m *VectorClock[K]) put(key uint64, value uint64) {
} }
// Make sure that entries that were garbage collected don't get // Make sure that entries that were garbage collected don't get
// addded back // highestStale represents the highest vector clock that has been
// invalidated
if value > clockValue && value > m.highestStale { if value > clockValue && value > m.highestStale {
newBucket := VectorBucket{ newBucket := VectorBucket{
clock: value, clock: value,
@ -145,6 +156,7 @@ func (m *VectorClock[K]) put(key uint64, value uint64) {
m.lock.Unlock() m.lock.Unlock()
} }
// GetClock: serialize the vector clock into an immutable map
func (m *VectorClock[K]) GetClock() map[uint64]uint64 { func (m *VectorClock[K]) GetClock() map[uint64]uint64 {
clock := make(map[uint64]uint64) clock := make(map[uint64]uint64)

View File

@ -1,16 +1,17 @@
package ctrlserver package ctrlserver
import ( import (
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn" "github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/crdt" "github.com/tim-beatham/smegmesh/pkg/crdt"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query" "github.com/tim-beatham/smegmesh/pkg/query"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/smegmesh/pkg/rpc"
"github.com/tim-beatham/wgmesh/pkg/wg" "github.com/tim-beatham/smegmesh/pkg/sync"
"github.com/tim-beatham/smegmesh/pkg/wg"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
) )
@ -21,7 +22,6 @@ type NewCtrlServerParams struct {
CtrlProvider rpc.MeshCtrlServerServer CtrlProvider rpc.MeshCtrlServerServer
SyncProvider rpc.SyncServiceServer SyncProvider rpc.SyncServiceServer
Querier query.Querier Querier query.Querier
OnDelete func(mesh.MeshProvider)
} }
// Create a new instance of the MeshCtrlServer or error if the // Create a new instance of the MeshCtrlServer or error if the
@ -34,12 +34,16 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
nodeFactory := &crdt.MeshNodeFactory{ nodeFactory := &crdt.MeshNodeFactory{
Config: *params.Conf, Config: *params.Conf,
} }
idGenerator := &lib.IDNameGenerator{} idGenerator := &lib.ShortIDGenerator{}
ipAllocator := &ip.ULABuilder{} ipAllocator := &ip.ULABuilder{}
interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client) interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client)
ctrlServer.timers = make([]*lib.Timer, 0)
configApplyer := mesh.NewWgMeshConfigApplyer() configApplyer := mesh.NewWgMeshConfigApplyer()
var syncer sync.Syncer
meshManagerParams := &mesh.NewMeshManagerParams{ meshManagerParams := &mesh.NewMeshManagerParams{
Conf: *params.Conf, Conf: *params.Conf,
Client: params.Client, Client: params.Client,
@ -49,7 +53,13 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
IPAllocator: ipAllocator, IPAllocator: ipAllocator,
InterfaceManipulator: interfaceManipulator, InterfaceManipulator: interfaceManipulator,
ConfigApplyer: configApplyer, ConfigApplyer: configApplyer,
OnDelete: params.OnDelete, OnDelete: func(mesh mesh.MeshProvider) {
_, err := syncer.Sync(mesh)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
},
} }
ctrlServer.MeshManager = mesh.NewMeshManager(meshManagerParams) ctrlServer.MeshManager = mesh.NewMeshManager(meshManagerParams)
@ -83,13 +93,40 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
return nil, err return nil, err
} }
syncer = sync.NewSyncer(&sync.NewSyncerParams{
MeshManager: ctrlServer.MeshManager,
ConnectionManager: ctrlServer.ConnectionManager,
Configuration: params.Conf,
})
// Check any syncs every 1 second
syncTimer := lib.NewTimer(func() error {
err = syncer.SyncMeshes()
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
return nil
}, 1)
heartbeatTimer := lib.NewTimer(func() error {
logging.Log.WriteInfof("checking heartbeat")
return ctrlServer.MeshManager.UpdateTimeStamp()
}, params.Conf.Heartbeat)
ctrlServer.timers = append(ctrlServer.timers, syncTimer, heartbeatTimer)
ctrlServer.Querier = query.NewJmesQuerier(ctrlServer.MeshManager) ctrlServer.Querier = query.NewJmesQuerier(ctrlServer.MeshManager)
ctrlServer.ConnectionServer = connServer ctrlServer.ConnectionServer = connServer
for _, timer := range ctrlServer.timers {
go timer.Run()
}
return ctrlServer, nil return ctrlServer, nil
} }
func (s *MeshCtrlServer) GetConfiguration() *conf.DaemonConfiguration { func (s *MeshCtrlServer) GetConfiguration() *conf.DaemonConfiguration {
return s.Conf return s.Conf
} }
@ -124,5 +161,13 @@ func (s *MeshCtrlServer) Close() error {
logging.Log.WriteErrorf(err.Error()) logging.Log.WriteErrorf(err.Error())
} }
for _, timer := range s.timers {
err := timer.Stop()
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
}
return nil return nil
} }

View File

@ -4,21 +4,23 @@ import (
"net" "net"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn" "github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query" "github.com/tim-beatham/smegmesh/pkg/query"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// MeshRoute: represents a route in the mesh that is
// available to client applications
type MeshRoute struct { type MeshRoute struct {
Destination string Destination string
Path []string Path []string
} }
// Represents the WireGuard configuration attached to the node // WireGuardStats: Represents the WireGuard configuration attached to the node
type WireGuardStats struct { type WireGuardStats struct {
AllowedIPs []string AllowedIPs []string
TransmitBytes int64 TransmitBytes int64
@ -26,7 +28,8 @@ type WireGuardStats struct {
PersistentKeepAliveInterval time.Duration PersistentKeepAliveInterval time.Duration
} }
// Represents a WireGuard MeshNode // MeshNode: represents a node in the WireGuard mesh that can be
// sent to ip chandlers
type MeshNode struct { type MeshNode struct {
HostEndpoint string HostEndpoint string
WgEndpoint string WgEndpoint string
@ -40,12 +43,13 @@ type MeshNode struct {
Stats WireGuardStats Stats WireGuardStats
} }
// Represents a WireGuard Mesh // Mesh: Represents a WireGuard Mesh network that can be sent
// along ipc to client frameworks
type Mesh struct { type Mesh struct {
SharedKey *wgtypes.Key
Nodes map[string]MeshNode Nodes map[string]MeshNode
} }
// CtrlServer: Encapsulates th ctrlserver
type CtrlServer interface { type CtrlServer interface {
GetConfiguration() *conf.DaemonConfiguration GetConfiguration() *conf.DaemonConfiguration
GetClient() *wgctrl.Client GetClient() *wgctrl.Client
@ -55,7 +59,7 @@ type CtrlServer interface {
GetConnectionManager() conn.ConnectionManager GetConnectionManager() conn.ConnectionManager
} }
// Represents a ctrlserver to be used in WireGuard // MeshCtrlServer: Represents a ctrlserver to be used in WireGuard
type MeshCtrlServer struct { type MeshCtrlServer struct {
Client *wgctrl.Client Client *wgctrl.Client
MeshManager mesh.MeshManager MeshManager mesh.MeshManager
@ -63,6 +67,7 @@ type MeshCtrlServer struct {
ConnectionServer *conn.ConnectionServer ConnectionServer *conn.ConnectionServer
Conf *conf.DaemonConfiguration Conf *conf.DaemonConfiguration
Querier query.Querier Querier query.Querier
timers []*lib.Timer
} }
// NewCtrlNode create an instance of a ctrl node to send over an // NewCtrlNode create an instance of a ctrl node to send over an

View File

@ -1,10 +1,10 @@
package ctrlserver package ctrlserver
import ( import (
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn" "github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query" "github.com/tim-beatham/smegmesh/pkg/query"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
) )

View File

@ -1,24 +1,22 @@
// smegdns: example of how to implement dns in the mesh
package smegdns package smegdns
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net" "net"
"net/rpc"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/smegmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/query" "github.com/tim-beatham/smegmesh/pkg/query"
) )
const SockAddr = "/tmp/wgmesh_ipc.sock"
const MeshRegularExpression = `(?P<meshId>.+)\.(?P<alias>.+)\.smeg\.` const MeshRegularExpression = `(?P<meshId>.+)\.(?P<alias>.+)\.smeg\.`
type DNSHandler struct { type DNSHandler struct {
client *rpc.Client client *ipc.SmegmeshIpc
server *dns.Server server *dns.Server
} }
@ -27,7 +25,7 @@ type DNSHandler struct {
func (d *DNSHandler) queryMesh(meshId, alias string) net.IP { func (d *DNSHandler) queryMesh(meshId, alias string) net.IP {
var reply string var reply string
err := d.client.Call("IpcHandler.Query", &ipc.QueryMesh{ err := d.client.Query(ipc.QueryMesh{
MeshId: meshId, MeshId: meshId,
Query: fmt.Sprintf("[?alias == '%s'] | [0]", alias), Query: fmt.Sprintf("[?alias == '%s'] | [0]", alias),
}, &reply) }, &reply)
@ -48,6 +46,7 @@ func (d *DNSHandler) queryMesh(meshId, alias string) net.IP {
return ip return ip
} }
// handleQuery: handles a DNS query
func (d *DNSHandler) handleQuery(m *dns.Msg) { func (d *DNSHandler) handleQuery(m *dns.Msg) {
for _, q := range m.Question { for _, q := range m.Question {
switch q.Qtype { switch q.Qtype {
@ -75,6 +74,7 @@ func (d *DNSHandler) handleQuery(m *dns.Msg) {
} }
} }
// handleDNS query: handle a DNS request
func (h *DNSHandler) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { func (h *DNSHandler) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
msg := new(dns.Msg) msg := new(dns.Msg)
msg.SetReply(r) msg.SetReply(r)
@ -97,7 +97,7 @@ func (h *DNSHandler) Close() error {
} }
func NewDns(udpPort int) (*DNSHandler, error) { func NewDns(udpPort int) (*DNSHandler, error) {
client, err := rpc.DialHTTP("unix", SockAddr) client, err := ipc.NewClientIpc()
if err != nil { if err != nil {
return nil, err return nil, err

249
pkg/dot/dot.go Normal file
View File

@ -0,0 +1,249 @@
// Graph allows the definition of a DOT graph in golang
package graph
import (
"fmt"
"hash/fnv"
"strings"
"github.com/tim-beatham/smegmesh/pkg/lib"
)
type GraphType string
type Shape string
const (
GRAPH GraphType = "graph"
DIGRAPH GraphType = "digraph"
)
const (
CIRCLE Shape = "circle"
STAR Shape = "star"
HEXAGON Shape = "hexagon"
PARALLELOGRAM Shape = "parallelogram"
)
type Graph interface {
Dottable
GetType() GraphType
}
// Cluster: represents a subgraph in the graphs
type Cluster struct {
Type GraphType
Name string
Label string
nodes map[string]*Node
edges map[string]Edge
}
// RootGraph: Represents the top level graph
type RootGraph struct {
Type GraphType
Label string
nodes map[string]*Node
clusters map[string]*Cluster
edges map[string]Edge
}
// Node: represents a graphviz not
type Node struct {
Name string
Label string
Shape Shape
Size int
}
// Edge: represents an edge between adjacent nodes
type Edge interface {
Dottable
}
// DirectEdge: contains a directed edge between any two nodes
type DirectedEdge struct {
Name string
Label string
From string
To string
}
// UndirectedEdge: contains an undirected edge between any two
// nodes
type UndirectedEdge struct {
Name string
Label string
From string
To string
}
// Dottable means an implementer can convert the struct to DOT representation
type Dottable interface {
GetDOT() (string, error)
}
// PutNode: puts a node in the root graph
func (g *RootGraph) PutNode(name, label string, size int, shape Shape) error {
_, exists := g.nodes[name]
if exists {
// If exists no need to add the ndoe
return nil
}
g.nodes[name] = &Node{Name: name, Label: label, Size: size, Shape: shape}
return nil
}
// PutCluster: puts a cluster in the root graph
func (g *RootGraph) PutCluster(graph *Cluster) {
g.clusters[graph.Label] = graph
}
func writeContituents[D Dottable](result *strings.Builder, elements ...D) error {
for _, node := range elements {
dot, err := node.GetDOT()
if err != nil {
return err
}
_, err = result.WriteString(dot)
if err != nil {
return err
}
}
return nil
}
// GetDOT: convert the root graph into dot format
func (g *RootGraph) GetDOT() (string, error) {
var result strings.Builder
result.WriteString(fmt.Sprintf("%s {\n", g.Type))
result.WriteString("node [colorscheme=set312];\n")
result.WriteString("layout = fdp;\n")
nodes := lib.MapValues(g.nodes)
edges := lib.MapValues(g.edges)
writeContituents(&result, nodes...)
writeContituents(&result, edges...)
for _, cluster := range g.clusters {
clusterDOT, err := cluster.GetDOT()
if err != nil {
return "", err
}
result.WriteString(clusterDOT)
}
result.WriteString("}")
return result.String(), nil
}
// GetType: get the graph type. DIRECTED|UNDIRECTED
func (r *RootGraph) GetType() GraphType {
return r.Type
}
func constructEdge(graph Graph, name, label, from, to string) Edge {
switch graph.GetType() {
case DIGRAPH:
return &DirectedEdge{Name: name, Label: label, From: from, To: to}
default:
return &UndirectedEdge{Name: name, Label: label, From: from, To: to}
}
}
// AddEdge: adds an edge between two nodes in the root graph
func (g *RootGraph) AddEdge(name string, label string, from string, to string) error {
g.edges[name] = constructEdge(g, name, label, from, to)
return nil
}
const numColours = 12
func (n *Node) hash() int {
h := fnv.New32a()
h.Write([]byte(n.Name))
return (int(h.Sum32()) % numColours) + 1
}
// GetDOT: convert the node into DOT format
func (n *Node) GetDOT() (string, error) {
return fmt.Sprintf("node[label=\"%s\",shape=%s, style=\"filled\", fillcolor=%d, width=%d, height=%d, fixedsize=true] \"%s\";\n",
n.Label, n.Shape, n.hash(), n.Size, n.Size, n.Name), nil
}
// GetDOT: Convert a directed edge into dot format
func (e *DirectedEdge) GetDOT() (string, error) {
return fmt.Sprintf("\"%s\" -> \"%s\" [label=\"%s\"];\n", e.From, e.To, e.Label), nil
}
// GetDOT: convert an undirected edge into dot format
func (e *UndirectedEdge) GetDOT() (string, error) {
return fmt.Sprintf("\"%s\" -- \"%s\" [label=\"%s\"];\n", e.From, e.To, e.Label), nil
}
// AddEdge: adds an edge between two nodes in the graph
func (g *Cluster) AddEdge(name string, label string, from string, to string) error {
g.edges[name] = constructEdge(g, name, label, from, to)
return nil
}
// PutNode: puts a node in the graph
func (g *Cluster) PutNode(name, label string, size int, shape Shape) error {
_, exists := g.nodes[name]
if exists {
// If exists no need to add the ndoe
return nil
}
g.nodes[name] = &Node{Name: name, Label: label, Shape: shape, Size: size}
return nil
}
// GetDOT: convert the cluster into dot format
func (g *Cluster) GetDOT() (string, error) {
var builder strings.Builder
builder.WriteString(fmt.Sprintf("subgraph \"cluster%s\" {\n", g.Label))
builder.WriteString(fmt.Sprintf("label = \"%s\"\n", g.Label))
nodes := lib.MapValues(g.nodes)
edges := lib.MapValues(g.edges)
writeContituents(&builder, nodes...)
writeContituents(&builder, edges...)
builder.WriteString("}\n")
return builder.String(), nil
}
// GetType: get the type of the subgraph (directed|undirected)
func (g *Cluster) GetType() GraphType {
return g.Type
}
// NewSubGraph: instantiate a new subgraph
func NewSubGraph(name string, label string, graphType GraphType) *Cluster {
return &Cluster{
Label: name,
Type: graphType,
Name: name,
nodes: make(map[string]*Node),
edges: make(map[string]Edge),
}
}
// NewGraph: create a new root graph
func NewGraph(label string, graphType GraphType) *RootGraph {
return &RootGraph{
Type: graphType,
Label: label,
clusters: map[string]*Cluster{},
nodes: make(map[string]*Node),
edges: make(map[string]Edge),
}
}

116
pkg/dot/wg.go Normal file
View File

@ -0,0 +1,116 @@
package graph
import (
"fmt"
"slices"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
)
// MeshGraphConverter converts a mesh to a graph
type MeshGraphConverter interface {
// convert the mesh to textual form
Generate() (string, error)
}
type MeshDOTConverter struct {
meshes map[string][]ctrlserver.MeshNode
destinations map[string]interface{}
}
func (c *MeshDOTConverter) Generate() (string, error) {
g := NewGraph("Smegmesh", GRAPH)
for meshId := range c.meshes {
err := c.generateMesh(g, meshId)
if err != nil {
return "", err
}
}
for mesh := range c.meshes {
g.PutNode(mesh, mesh, 1, CIRCLE)
}
for destination := range c.destinations {
g.PutNode(destination, destination, 1, HEXAGON)
}
return g.GetDOT()
}
func (c *MeshDOTConverter) generateMesh(g *RootGraph, meshId string) error {
nodes := c.meshes[meshId]
g.PutNode(meshId, meshId, 1, CIRCLE)
for _, node := range nodes {
c.graphNode(g, node, meshId)
}
for _, node := range nodes {
g.AddEdge(fmt.Sprintf("%s to %s", node.PublicKey, meshId), "", node.PublicKey, meshId)
}
return nil
}
// graphNode: graphs a node within the mesh
func (c *MeshDOTConverter) graphNode(g *RootGraph, node ctrlserver.MeshNode, meshId string) {
alias := node.Alias
if alias == "" {
alias = node.WgHost[1:len(node.WgHost)-20] + "\\n" + node.WgHost[len(node.WgHost)-20:len(node.WgHost)]
}
g.PutNode(node.PublicKey, alias, 2, CIRCLE)
for _, route := range node.Routes {
if len(route.Path) == 0 {
g.AddEdge(route.Destination, "", node.PublicKey, route.Destination)
continue
}
reversedPath := slices.Clone(route.Path)
slices.Reverse(reversedPath)
g.AddEdge(fmt.Sprintf("%s to %s", node.PublicKey, reversedPath[0]), "", node.PublicKey, reversedPath[0])
for _, mesh := range route.Path {
if _, ok := c.meshes[mesh]; !ok {
c.destinations[mesh] = struct{}{}
}
}
for index := range reversedPath[0 : len(reversedPath)-1] {
routeID := fmt.Sprintf("%s to %s", reversedPath[index], reversedPath[index+1])
g.AddEdge(routeID, "", reversedPath[index], reversedPath[index+1])
}
if route.Destination == "::/0" {
c.destinations[route.Destination] = struct{}{}
lastMesh := reversedPath[len(reversedPath)-1]
routeID := fmt.Sprintf("%s to %s", lastMesh, route.Destination)
g.AddEdge(routeID, "", lastMesh, route.Destination)
}
}
for service := range node.Services {
c.putService(g, service, meshId, node)
}
}
// putService: construct a service node and a link between the nodes
func (c *MeshDOTConverter) putService(g *RootGraph, key, meshId string, node ctrlserver.MeshNode) {
serviceID := fmt.Sprintf("%s%s%s", key, node.PublicKey, meshId)
g.PutNode(serviceID, key, 1, PARALLELOGRAM)
g.AddEdge(fmt.Sprintf("%s to %s", node.PublicKey, serviceID), "", node.PublicKey, serviceID)
}
func NewMeshGraphConverter(meshes map[string][]ctrlserver.MeshNode) MeshGraphConverter {
return &MeshDOTConverter{
meshes: meshes,
destinations: make(map[string]interface{}),
}
}

View File

@ -1,178 +0,0 @@
// Graph allows the definition of a DOT graph in golang
package graph
import (
"errors"
"fmt"
"hash/fnv"
"strings"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
type GraphType string
type Shape string
const (
GRAPH GraphType = "graph"
DIGRAPH = "digraph"
)
const (
CIRCLE Shape = "circle"
STAR Shape = "star"
HEXAGON Shape = "hexagon"
)
type Graph struct {
Type GraphType
Label string
nodes map[string]*Node
edges []Edge
}
type Node struct {
Name string
Shape Shape
}
type Edge interface {
Dottable
}
type DirectedEdge struct {
Label string
From *Node
To *Node
}
type UndirectedEdge struct {
Label string
From *Node
To *Node
}
// Dottable means an implementer can convert the struct to DOT representation
type Dottable interface {
GetDOT() (string, error)
}
func NewGraph(label string, graphType GraphType) *Graph {
return &Graph{Type: graphType, Label: label, nodes: make(map[string]*Node), edges: make([]Edge, 0)}
}
// PutNode: puts a node in the graph
func (g *Graph) PutNode(label string, shape Shape) error {
_, exists := g.nodes[label]
if exists {
// If exists no need to add the ndoe
return nil
}
g.nodes[label] = &Node{Name: label, Shape: shape}
return nil
}
func writeContituents[D Dottable](result *strings.Builder, elements ...D) error {
for _, node := range elements {
dot, err := node.GetDOT()
if err != nil {
return err
}
_, err = result.WriteString(dot)
if err != nil {
return err
}
}
return nil
}
func (g *Graph) GetDOT() (string, error) {
var result strings.Builder
_, err := result.WriteString(fmt.Sprintf("%s {\n", g.Type))
if err != nil {
return "", err
}
_, err = result.WriteString("node [colorscheme=set312];\n")
if err != nil {
return "", err
}
nodes := lib.MapValues(g.nodes)
err = writeContituents(&result, nodes...)
if err != nil {
return "", err
}
err = writeContituents(&result, g.edges...)
if err != nil {
return "", err
}
_, err = result.WriteString("}")
if err != nil {
return "", err
}
return result.String(), nil
}
func (g *Graph) constructEdge(label string, from *Node, to *Node) Edge {
switch g.Type {
case DIGRAPH:
return &DirectedEdge{Label: label, From: from, To: to}
default:
return &UndirectedEdge{Label: label, From: from, To: to}
}
}
// AddEdge: adds an edge between two nodes in the graph
func (g *Graph) AddEdge(label string, from string, to string) error {
fromNode, exists := g.nodes[from]
if !exists {
return errors.New(fmt.Sprintf("Node %s does not exist", from))
}
toNode, exists := g.nodes[to]
if !exists {
return errors.New(fmt.Sprintf("Node %s does not exist", to))
}
g.edges = append(g.edges, g.constructEdge(label, fromNode, toNode))
return nil
}
const numColours = 12
func (n *Node) hash() int {
h := fnv.New32a()
h.Write([]byte(n.Name))
return (int(h.Sum32()) % numColours) + 1
}
func (n *Node) GetDOT() (string, error) {
return fmt.Sprintf("node[shape=%s, style=\"filled\", fillcolor=%d] %s;\n",
n.Shape, n.hash(), n.Name), nil
}
func (e *DirectedEdge) GetDOT() (string, error) {
return fmt.Sprintf("%s -> %s;\n", e.From.Name, e.To.Name), nil
}
func (e *UndirectedEdge) GetDOT() (string, error) {
return fmt.Sprintf("%s -- %s;\n", e.From.Name, e.To.Name), nil
}

212
pkg/grpc/ctrlserver.pb.go Normal file
View File

@ -0,0 +1,212 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.1
// protoc v3.21.12
// source: pkg/grpc/ctrlserver.proto
package rpc
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type GetMeshRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
MeshId string `protobuf:"bytes,1,opt,name=meshId,proto3" json:"meshId,omitempty"`
}
func (x *GetMeshRequest) Reset() {
*x = GetMeshRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *GetMeshRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GetMeshRequest) ProtoMessage() {}
func (x *GetMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GetMeshRequest.ProtoReflect.Descriptor instead.
func (*GetMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_proto_rawDescGZIP(), []int{0}
}
func (x *GetMeshRequest) GetMeshId() string {
if x != nil {
return x.MeshId
}
return ""
}
type GetMeshReply struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Mesh []byte `protobuf:"bytes,1,opt,name=mesh,proto3" json:"mesh,omitempty"`
}
func (x *GetMeshReply) Reset() {
*x = GetMeshReply{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *GetMeshReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GetMeshReply) ProtoMessage() {}
func (x *GetMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GetMeshReply.ProtoReflect.Descriptor instead.
func (*GetMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_proto_rawDescGZIP(), []int{1}
}
func (x *GetMeshReply) GetMesh() []byte {
if x != nil {
return x.Mesh
}
return nil
}
var File_pkg_grpc_ctrlserver_proto protoreflect.FileDescriptor
var file_pkg_grpc_ctrlserver_proto_rawDesc = []byte{
0x0a, 0x19, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73,
0x65, 0x72, 0x76, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x22, 0x28, 0x0a, 0x0e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68,
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49,
0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x22,
0x22, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12,
0x12, 0x0a, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x6d,
0x65, 0x73, 0x68, 0x32, 0x4f, 0x0a, 0x0e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x74, 0x72, 0x6c, 0x53,
0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x3d, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68,
0x12, 0x18, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d,
0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70,
0x6c, 0x79, 0x22, 0x00, 0x42, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62,
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_pkg_grpc_ctrlserver_proto_rawDescOnce sync.Once
file_pkg_grpc_ctrlserver_proto_rawDescData = file_pkg_grpc_ctrlserver_proto_rawDesc
)
func file_pkg_grpc_ctrlserver_proto_rawDescGZIP() []byte {
file_pkg_grpc_ctrlserver_proto_rawDescOnce.Do(func() {
file_pkg_grpc_ctrlserver_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_grpc_ctrlserver_proto_rawDescData)
})
return file_pkg_grpc_ctrlserver_proto_rawDescData
}
var file_pkg_grpc_ctrlserver_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_pkg_grpc_ctrlserver_proto_goTypes = []interface{}{
(*GetMeshRequest)(nil), // 0: rpctypes.GetMeshRequest
(*GetMeshReply)(nil), // 1: rpctypes.GetMeshReply
}
var file_pkg_grpc_ctrlserver_proto_depIdxs = []int32{
0, // 0: rpctypes.MeshCtrlServer.GetMesh:input_type -> rpctypes.GetMeshRequest
1, // 1: rpctypes.MeshCtrlServer.GetMesh:output_type -> rpctypes.GetMeshReply
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_pkg_grpc_ctrlserver_proto_init() }
func file_pkg_grpc_ctrlserver_proto_init() {
if File_pkg_grpc_ctrlserver_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_pkg_grpc_ctrlserver_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetMeshRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_ctrlserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetMeshReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_pkg_grpc_ctrlserver_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_pkg_grpc_ctrlserver_proto_goTypes,
DependencyIndexes: file_pkg_grpc_ctrlserver_proto_depIdxs,
MessageInfos: file_pkg_grpc_ctrlserver_proto_msgTypes,
}.Build()
File_pkg_grpc_ctrlserver_proto = out.File
file_pkg_grpc_ctrlserver_proto_rawDesc = nil
file_pkg_grpc_ctrlserver_proto_goTypes = nil
file_pkg_grpc_ctrlserver_proto_depIdxs = nil
}

View File

@ -1,18 +0,0 @@
syntax = "proto3";
package rpctypes;
option go_package = "pkg/rpc";
service Authentication {
rpc JoinMesh(JoinAuthMeshRequest) returns (JoinAuthMeshReply) {}
}
message JoinAuthMeshRequest {
string meshId = 1;
string alias = 2;
}
message JoinAuthMeshReply {
bool success = 1;
optional string token = 2;
}

View File

@ -0,0 +1,105 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.2.0
// - protoc v3.21.12
// source: pkg/grpc/ctrlserver.proto
package rpc
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// MeshCtrlServerClient is the client API for MeshCtrlServer service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type MeshCtrlServerClient interface {
GetMesh(ctx context.Context, in *GetMeshRequest, opts ...grpc.CallOption) (*GetMeshReply, error)
}
type meshCtrlServerClient struct {
cc grpc.ClientConnInterface
}
func NewMeshCtrlServerClient(cc grpc.ClientConnInterface) MeshCtrlServerClient {
return &meshCtrlServerClient{cc}
}
func (c *meshCtrlServerClient) GetMesh(ctx context.Context, in *GetMeshRequest, opts ...grpc.CallOption) (*GetMeshReply, error) {
out := new(GetMeshReply)
err := c.cc.Invoke(ctx, "/rpctypes.MeshCtrlServer/GetMesh", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// MeshCtrlServerServer is the server API for MeshCtrlServer service.
// All implementations must embed UnimplementedMeshCtrlServerServer
// for forward compatibility
type MeshCtrlServerServer interface {
GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error)
mustEmbedUnimplementedMeshCtrlServerServer()
}
// UnimplementedMeshCtrlServerServer must be embedded to have forward compatible implementations.
type UnimplementedMeshCtrlServerServer struct {
}
func (UnimplementedMeshCtrlServerServer) GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetMesh not implemented")
}
func (UnimplementedMeshCtrlServerServer) mustEmbedUnimplementedMeshCtrlServerServer() {}
// UnsafeMeshCtrlServerServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to MeshCtrlServerServer will
// result in compilation errors.
type UnsafeMeshCtrlServerServer interface {
mustEmbedUnimplementedMeshCtrlServerServer()
}
func RegisterMeshCtrlServerServer(s grpc.ServiceRegistrar, srv MeshCtrlServerServer) {
s.RegisterService(&MeshCtrlServer_ServiceDesc, srv)
}
func _MeshCtrlServer_GetMesh_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(GetMeshRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(MeshCtrlServerServer).GetMesh(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/rpctypes.MeshCtrlServer/GetMesh",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(MeshCtrlServerServer).GetMesh(ctx, req.(*GetMeshRequest))
}
return interceptor(ctx, in, info, handler)
}
// MeshCtrlServer_ServiceDesc is the grpc.ServiceDesc for MeshCtrlServer service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var MeshCtrlServer_ServiceDesc = grpc.ServiceDesc{
ServiceName: "rpctypes.MeshCtrlServer",
HandlerType: (*MeshCtrlServerServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "GetMesh",
Handler: _MeshCtrlServer_GetMesh_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "pkg/grpc/ctrlserver.proto",
}

233
pkg/grpc/syncservice.pb.go Normal file
View File

@ -0,0 +1,233 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.1
// protoc v3.21.12
// source: pkg/grpc/syncservice.proto
package rpc
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type SyncMeshRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
MeshId string `protobuf:"bytes,1,opt,name=meshId,proto3" json:"meshId,omitempty"`
Changes []byte `protobuf:"bytes,2,opt,name=changes,proto3" json:"changes,omitempty"`
}
func (x *SyncMeshRequest) Reset() {
*x = SyncMeshRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_syncservice_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SyncMeshRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SyncMeshRequest) ProtoMessage() {}
func (x *SyncMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_syncservice_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SyncMeshRequest.ProtoReflect.Descriptor instead.
func (*SyncMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_syncservice_proto_rawDescGZIP(), []int{0}
}
func (x *SyncMeshRequest) GetMeshId() string {
if x != nil {
return x.MeshId
}
return ""
}
func (x *SyncMeshRequest) GetChanges() []byte {
if x != nil {
return x.Changes
}
return nil
}
type SyncMeshReply struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"`
Changes []byte `protobuf:"bytes,2,opt,name=changes,proto3" json:"changes,omitempty"`
}
func (x *SyncMeshReply) Reset() {
*x = SyncMeshReply{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_syncservice_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SyncMeshReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SyncMeshReply) ProtoMessage() {}
func (x *SyncMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_syncservice_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SyncMeshReply.ProtoReflect.Descriptor instead.
func (*SyncMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_syncservice_proto_rawDescGZIP(), []int{1}
}
func (x *SyncMeshReply) GetSuccess() bool {
if x != nil {
return x.Success
}
return false
}
func (x *SyncMeshReply) GetChanges() []byte {
if x != nil {
return x.Changes
}
return nil
}
var File_pkg_grpc_syncservice_proto protoreflect.FileDescriptor
var file_pkg_grpc_syncservice_proto_rawDesc = []byte{
0x0a, 0x1a, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x73, 0x79, 0x6e, 0x63, 0x73,
0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0b, 0x73, 0x79,
0x6e, 0x63, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x22, 0x43, 0x0a, 0x0f, 0x53, 0x79, 0x6e,
0x63, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06,
0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65,
0x73, 0x68, 0x49, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18,
0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x22, 0x43,
0x0a, 0x0d, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12,
0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08,
0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x68, 0x61,
0x6e, 0x67, 0x65, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x63, 0x68, 0x61, 0x6e,
0x67, 0x65, 0x73, 0x32, 0x59, 0x0a, 0x0b, 0x53, 0x79, 0x6e, 0x63, 0x53, 0x65, 0x72, 0x76, 0x69,
0x63, 0x65, 0x12, 0x4a, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x1c,
0x2e, 0x73, 0x79, 0x6e, 0x63, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x53, 0x79, 0x6e,
0x63, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x73,
0x79, 0x6e, 0x63, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x4d,
0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x09,
0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x33,
}
var (
file_pkg_grpc_syncservice_proto_rawDescOnce sync.Once
file_pkg_grpc_syncservice_proto_rawDescData = file_pkg_grpc_syncservice_proto_rawDesc
)
func file_pkg_grpc_syncservice_proto_rawDescGZIP() []byte {
file_pkg_grpc_syncservice_proto_rawDescOnce.Do(func() {
file_pkg_grpc_syncservice_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_grpc_syncservice_proto_rawDescData)
})
return file_pkg_grpc_syncservice_proto_rawDescData
}
var file_pkg_grpc_syncservice_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_pkg_grpc_syncservice_proto_goTypes = []interface{}{
(*SyncMeshRequest)(nil), // 0: syncservice.SyncMeshRequest
(*SyncMeshReply)(nil), // 1: syncservice.SyncMeshReply
}
var file_pkg_grpc_syncservice_proto_depIdxs = []int32{
0, // 0: syncservice.SyncService.SyncMesh:input_type -> syncservice.SyncMeshRequest
1, // 1: syncservice.SyncService.SyncMesh:output_type -> syncservice.SyncMeshReply
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_pkg_grpc_syncservice_proto_init() }
func file_pkg_grpc_syncservice_proto_init() {
if File_pkg_grpc_syncservice_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_pkg_grpc_syncservice_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SyncMeshRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_syncservice_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SyncMeshReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_pkg_grpc_syncservice_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_pkg_grpc_syncservice_proto_goTypes,
DependencyIndexes: file_pkg_grpc_syncservice_proto_depIdxs,
MessageInfos: file_pkg_grpc_syncservice_proto_msgTypes,
}.Build()
File_pkg_grpc_syncservice_proto = out.File
file_pkg_grpc_syncservice_proto_rawDesc = nil
file_pkg_grpc_syncservice_proto_goTypes = nil
file_pkg_grpc_syncservice_proto_depIdxs = nil
}

View File

@ -4,18 +4,9 @@ package syncservice;
option go_package = "pkg/rpc"; option go_package = "pkg/rpc";
service SyncService { service SyncService {
rpc GetConf(GetConfRequest) returns (GetConfReply) {}
rpc SyncMesh(stream SyncMeshRequest) returns (stream SyncMeshReply) {} rpc SyncMesh(stream SyncMeshRequest) returns (stream SyncMeshReply) {}
} }
message GetConfRequest {
string meshId = 1;
}
message GetConfReply {
bytes mesh = 1;
}
message SyncMeshRequest { message SyncMeshRequest {
string meshId = 1; string meshId = 1;
bytes changes = 2; bytes changes = 2;

View File

@ -0,0 +1,137 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.2.0
// - protoc v3.21.12
// source: pkg/grpc/syncservice.proto
package rpc
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// SyncServiceClient is the client API for SyncService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type SyncServiceClient interface {
SyncMesh(ctx context.Context, opts ...grpc.CallOption) (SyncService_SyncMeshClient, error)
}
type syncServiceClient struct {
cc grpc.ClientConnInterface
}
func NewSyncServiceClient(cc grpc.ClientConnInterface) SyncServiceClient {
return &syncServiceClient{cc}
}
func (c *syncServiceClient) SyncMesh(ctx context.Context, opts ...grpc.CallOption) (SyncService_SyncMeshClient, error) {
stream, err := c.cc.NewStream(ctx, &SyncService_ServiceDesc.Streams[0], "/syncservice.SyncService/SyncMesh", opts...)
if err != nil {
return nil, err
}
x := &syncServiceSyncMeshClient{stream}
return x, nil
}
type SyncService_SyncMeshClient interface {
Send(*SyncMeshRequest) error
Recv() (*SyncMeshReply, error)
grpc.ClientStream
}
type syncServiceSyncMeshClient struct {
grpc.ClientStream
}
func (x *syncServiceSyncMeshClient) Send(m *SyncMeshRequest) error {
return x.ClientStream.SendMsg(m)
}
func (x *syncServiceSyncMeshClient) Recv() (*SyncMeshReply, error) {
m := new(SyncMeshReply)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// SyncServiceServer is the server API for SyncService service.
// All implementations must embed UnimplementedSyncServiceServer
// for forward compatibility
type SyncServiceServer interface {
SyncMesh(SyncService_SyncMeshServer) error
mustEmbedUnimplementedSyncServiceServer()
}
// UnimplementedSyncServiceServer must be embedded to have forward compatible implementations.
type UnimplementedSyncServiceServer struct {
}
func (UnimplementedSyncServiceServer) SyncMesh(SyncService_SyncMeshServer) error {
return status.Errorf(codes.Unimplemented, "method SyncMesh not implemented")
}
func (UnimplementedSyncServiceServer) mustEmbedUnimplementedSyncServiceServer() {}
// UnsafeSyncServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to SyncServiceServer will
// result in compilation errors.
type UnsafeSyncServiceServer interface {
mustEmbedUnimplementedSyncServiceServer()
}
func RegisterSyncServiceServer(s grpc.ServiceRegistrar, srv SyncServiceServer) {
s.RegisterService(&SyncService_ServiceDesc, srv)
}
func _SyncService_SyncMesh_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(SyncServiceServer).SyncMesh(&syncServiceSyncMeshServer{stream})
}
type SyncService_SyncMeshServer interface {
Send(*SyncMeshReply) error
Recv() (*SyncMeshRequest, error)
grpc.ServerStream
}
type syncServiceSyncMeshServer struct {
grpc.ServerStream
}
func (x *syncServiceSyncMeshServer) Send(m *SyncMeshReply) error {
return x.ServerStream.SendMsg(m)
}
func (x *syncServiceSyncMeshServer) Recv() (*SyncMeshRequest, error) {
m := new(SyncMeshRequest)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// SyncService_ServiceDesc is the grpc.ServiceDesc for SyncService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var SyncService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "syncservice.SyncService",
HandlerType: (*SyncServiceServer)(nil),
Methods: []grpc.MethodDesc{},
Streams: []grpc.StreamDesc{
{
StreamName: "SyncMesh",
Handler: _SyncService_SyncMesh_Handler,
ServerStreams: true,
ClientStreams: true,
},
},
Metadata: "pkg/grpc/syncservice.proto",
}

View File

@ -1,8 +1,7 @@
package ip package ip
/* // Generates a CGA see RFC 3972
* Use a WireGuard public key to generate a unique interface ID // https://datatracker.ietf.org/doc/html/rfc3972
*/
import ( import (
"crypto/rand" "crypto/rand"
@ -22,19 +21,23 @@ const (
InterfaceIdLen = 8 InterfaceIdLen = 8
) )
/* // CGAParameters: parameters used to create a new cryotpgraphically generated
* Cga parameters used to generate an IPV6 interface ID // address
*/
type CgaParameters struct { type CgaParameters struct {
Modifier [ModifierLength]byte Modifier [ModifierLength]byte
// SubnetPrefix: prefix of the subnetwork
SubnetPrefix [2 * InterfaceIdLen]byte SubnetPrefix [2 * InterfaceIdLen]byte
// CollisionCount: total number of times we have atempted to generate a porefix
CollisionCount uint8 CollisionCount uint8
// PublicKey: WireGuard public key of our interface
PublicKey wgtypes.Key PublicKey wgtypes.Key
// interfaceId: the generated interfaceId
interfaceId [2 * InterfaceIdLen]byte interfaceId [2 * InterfaceIdLen]byte
// flag: represents whether or not an IP address has been generated
flag byte flag byte
} }
func NewCga(key wgtypes.Key, subnetPrefix [2 * InterfaceIdLen]byte) (*CgaParameters, error) { func NewCga(key wgtypes.Key, collisionCount uint8, subnetPrefix [2 * InterfaceIdLen]byte) (*CgaParameters, error) {
var params CgaParameters var params CgaParameters
_, err := rand.Read(params.Modifier[:]) _, err := rand.Read(params.Modifier[:])
@ -45,25 +48,10 @@ func NewCga(key wgtypes.Key, subnetPrefix [2 * InterfaceIdLen]byte) (*CgaParamet
params.PublicKey = key params.PublicKey = key
params.SubnetPrefix = subnetPrefix params.SubnetPrefix = subnetPrefix
params.CollisionCount = collisionCount
return &params, nil return &params, nil
} }
func (c *CgaParameters) generateHash2() []byte {
var byteVal [hash2Length]byte
for i := 0; i < ModifierLength; i++ {
byteVal[i] = c.Modifier[i]
}
for i := 0; i < wgtypes.KeyLen; i++ {
byteVal[ModifierLength+ZeroLength+i] = c.PublicKey[i]
}
hash := sha1.Sum(byteVal[:])
return hash[:Hash2Prefix]
}
func (c *CgaParameters) generateHash1() []byte { func (c *CgaParameters) generateHash1() []byte {
var byteVal [hash1Length]byte var byteVal [hash1Length]byte
@ -78,7 +66,6 @@ func (c *CgaParameters) generateHash1() []byte {
byteVal[hash1Length-1] = c.CollisionCount byteVal[hash1Length-1] = c.CollisionCount
hash := sha1.Sum(byteVal[:]) hash := sha1.Sum(byteVal[:])
return hash[:Hash1Prefix] return hash[:Hash1Prefix]
} }
@ -90,9 +77,6 @@ func clearBit(num, pos int) byte {
} }
func (c *CgaParameters) generateInterface() []byte { func (c *CgaParameters) generateInterface() []byte {
// TODO: On duplicate address detection increment collision.
// Also incorporate SEC
hash1 := c.generateHash1() hash1 := c.generateHash1()
var interfaceId []byte = make([]byte, InterfaceIdLen) var interfaceId []byte = make([]byte, InterfaceIdLen)
@ -101,7 +85,6 @@ func (c *CgaParameters) generateInterface() []byte {
interfaceId[0] = clearBit(int(interfaceId[0]), 6) interfaceId[0] = clearBit(int(interfaceId[0]), 6)
interfaceId[0] = clearBit(int(interfaceId[1]), 7) interfaceId[0] = clearBit(int(interfaceId[1]), 7)
return interfaceId return interfaceId
} }

View File

@ -6,6 +6,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// IPAllocator: abstracts the process of creating an IP address
type IPAllocator interface { type IPAllocator interface {
GetIP(key wgtypes.Key, meshId string) (net.IP, error) GetIP(key wgtypes.Key, meshId string, collisionCount uint8) (net.IP, error)
} }

View File

@ -8,6 +8,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// ULABuilder: Create a new ULA in WireGuard
type ULABuilder struct{} type ULABuilder struct{}
func getMeshPrefix(meshId string) [16]byte { func getMeshPrefix(meshId string) [16]byte {
@ -39,10 +40,10 @@ func (u *ULABuilder) GetIPNet(meshId string) (*net.IPNet, error) {
return net, nil return net, nil
} }
func (u *ULABuilder) GetIP(key wgtypes.Key, meshId string) (net.IP, error) { func (u *ULABuilder) GetIP(key wgtypes.Key, meshId string, collisionCount uint8) (net.IP, error) {
ulaPrefix := getMeshPrefix(meshId) ulaPrefix := getMeshPrefix(meshId)
c, err := NewCga(key, ulaPrefix) c, err := NewCga(key, collisionCount, ulaPrefix)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -5,11 +5,27 @@ import (
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
ipcRPC "net/rpc"
"os" "os"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
) )
const SockAddr = "/tmp/smeg.sock"
type MeshIpc interface {
CreateMesh(args *NewMeshArgs, reply *string) error
ListMeshes(name string, reply *ListMeshReply) error
JoinMesh(args *JoinMeshArgs, reply *string) error
LeaveMesh(meshId string, reply *string) error
GetMesh(meshId string, reply *GetMeshReply) error
Query(query QueryMesh, reply *string) error
PutDescription(args PutDescriptionArgs, reply *string) error
PutAlias(args PutAliasArgs, reply *string) error
PutService(args PutServiceArgs, reply *string) error
DeleteService(args DeleteServiceArgs, reply *string) error
}
// WireGuardArgs are provided args specific to WireGuard // WireGuardArgs are provided args specific to WireGuard
type WireGuardArgs struct { type WireGuardArgs struct {
// WgPort is the WireGuard port to expose // WgPort is the WireGuard port to expose
@ -39,44 +55,141 @@ type JoinMeshArgs struct {
// MeshId is the ID of the mesh to join // MeshId is the ID of the mesh to join
MeshId string MeshId string
// IpAddress is a routable IP in another mesh // IpAddress is a routable IP in another mesh
IpAdress string IpAddress string
// WgArgs is the WireGuard parameters to use. // WgArgs is the WireGuard parameters to use.
WgArgs WireGuardArgs WgArgs WireGuardArgs
} }
// PutServiceArgs: args to place a service into the data store
type PutServiceArgs struct { type PutServiceArgs struct {
Service string Service string
Value string Value string
MeshId string
} }
// DeleteServiceArgs: args to remove a service from the data store
type DeleteServiceArgs struct {
Service string
MeshId string
}
// PutAliasArgs: args to assign an alias to a node
type PutAliasArgs struct {
// Alias: represents the alias of the node
Alias string
// MeshId: represents the meshID of the node
MeshId string
}
// PutDescriptionArgs: args to assign a description to a node
type PutDescriptionArgs struct {
// Description: descriptio to add to the network
Description string
// MeshID to add to the mesh network
MeshId string
}
// GetMeshReply: ipc reply to get the mesh network
type GetMeshReply struct { type GetMeshReply struct {
Nodes []ctrlserver.MeshNode Nodes []ctrlserver.MeshNode
} }
// ListMeshReply: ipc reply of the networks the node is part of
type ListMeshReply struct { type ListMeshReply struct {
Meshes []string Meshes []string
} }
// Querymesh: ipc args to query a mesh network
type QueryMesh struct { type QueryMesh struct {
// MeshId: id of the mesh to query
MeshId string MeshId string
Query string // JMESPath: query string to query
Query string
} }
type MeshIpc interface { // ClientIpc: Framework to invoke ipc calls to the daemon
type ClientIpc interface {
// CreateMesh: create a mesh network, return an error if the operation failed
CreateMesh(args *NewMeshArgs, reply *string) error CreateMesh(args *NewMeshArgs, reply *string) error
ListMeshes(name string, reply *ListMeshReply) error // ListMesh: list mesh network the node is a part of, return an error if the operation failed
ListMeshes(args *ListMeshReply, reply *string) error
// JoinMesh: join a mesh network return an error if the operation failed
JoinMesh(args JoinMeshArgs, reply *string) error JoinMesh(args JoinMeshArgs, reply *string) error
// LeaveMesh: leave a mesh network, return an error if the operation failed
LeaveMesh(meshId string, reply *string) error LeaveMesh(meshId string, reply *string) error
// GetMesh: get the given mesh network, return an error if the operation failed
GetMesh(meshId string, reply *GetMeshReply) error GetMesh(meshId string, reply *GetMeshReply) error
GetDOT(meshId string, reply *string) error // Query: query the given mesh network
Query(query QueryMesh, reply *string) error Query(query QueryMesh, reply *string) error
PutDescription(description string, reply *string) error // PutDescription: assign a description to yourself
PutAlias(alias string, reply *string) error PutDescription(args PutDescriptionArgs, reply *string) error
// PutAlias: assign an alias to yourself
PutAlias(args PutAliasArgs, reply *string) error
// PutService: assign a service to yourself
PutService(args PutServiceArgs, reply *string) error PutService(args PutServiceArgs, reply *string) error
DeleteService(service string, reply *string) error // DeleteService: retract a service
DeleteService(args DeleteServiceArgs, reply *string) error
} }
const SockAddr = "/tmp/wgmesh_ipc.sock" type SmegmeshIpc struct {
client *ipcRPC.Client
}
func NewClientIpc() (*SmegmeshIpc, error) {
client, err := ipcRPC.DialHTTP("unix", SockAddr)
if err != nil {
return nil, err
}
return &SmegmeshIpc{
client: client,
}, nil
}
func (c *SmegmeshIpc) CreateMesh(args *NewMeshArgs, reply *string) error {
return c.client.Call("IpcHandler.CreateMesh", args, reply)
}
func (c *SmegmeshIpc) ListMeshes(reply *ListMeshReply) error {
return c.client.Call("IpcHandler.ListMeshes", "", reply)
}
func (c *SmegmeshIpc) JoinMesh(args JoinMeshArgs, reply *string) error {
return c.client.Call("IpcHandler.JoinMesh", &args, reply)
}
func (c *SmegmeshIpc) LeaveMesh(meshId string, reply *string) error {
return c.client.Call("IpcHandler.LeaveMesh", &meshId, reply)
}
func (c *SmegmeshIpc) GetMesh(meshId string, reply *GetMeshReply) error {
return c.client.Call("IpcHandler.GetMesh", &meshId, reply)
}
func (c *SmegmeshIpc) Query(query QueryMesh, reply *string) error {
return c.client.Call("IpcHandler.Query", &query, reply)
}
func (c *SmegmeshIpc) PutDescription(args PutDescriptionArgs, reply *string) error {
return c.client.Call("IpcHandler.PutDescription", &args, reply)
}
func (c *SmegmeshIpc) PutAlias(args PutAliasArgs, reply *string) error {
return c.client.Call("IpcHandler.PutAlias", &args, reply)
}
func (c *SmegmeshIpc) PutService(args PutServiceArgs, reply *string) error {
return c.client.Call("IpcHandler.PutService", &args, reply)
}
func (c *SmegmeshIpc) DeleteService(args DeleteServiceArgs, reply *string) error {
return c.client.Call("IpcHandler.DeleteService", &args, reply)
}
func (c *SmegmeshIpc) Close() error {
return c.client.Close()
}
func RunIpcHandler(server MeshIpc) error { func RunIpcHandler(server MeshIpc) error {
if err := os.RemoveAll(SockAddr); err != nil { if err := os.RemoveAll(SockAddr); err != nil {

View File

@ -36,11 +36,13 @@ func ConsistentHash[V any, K any](values []V, client K, bucketFunc func(V) int,
ourKey := keyFunc(client) ourKey := keyFunc(client)
for _, record := range vs { idx := sort.Search(len(vs), func(i int) bool {
if ourKey < record.value { return vs[i].value >= ourKey
return record.record })
}
if idx == len(vs) {
return vs[0].record
} }
return vs[0].record return vs[idx].record
} }

View File

@ -3,6 +3,7 @@ package lib
import ( import (
"github.com/anandvarma/namegen" "github.com/anandvarma/namegen"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/lithammer/shortuuid"
) )
// IdGenerator generates unique ids // IdGenerator generates unique ids
@ -19,6 +20,14 @@ func (g *UUIDGenerator) GetId() (string, error) {
return id.String(), nil return id.String(), nil
} }
type ShortIDGenerator struct {
}
func (g *ShortIDGenerator) GetId() (string, error) {
id := shortuuid.New()
return id, nil
}
type IDNameGenerator struct { type IDNameGenerator struct {
} }

View File

@ -6,27 +6,21 @@ import (
"net" "net"
"github.com/jsimonetti/rtnetlink" "github.com/jsimonetti/rtnetlink"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// Maximum MTU to assin to WireGuard
// This isn't configurable
const WIREGUARD_MTU = 1420
// RtNetlinkConfig: represents an rtnetlkink configuration instance
type RtNetlinkConfig struct { type RtNetlinkConfig struct {
// conn: connection to the rtnetlink API
conn *rtnetlink.Conn conn *rtnetlink.Conn
} }
func NewRtNetlinkConfig() (*RtNetlinkConfig, error) { // CreateLink: Create a netlink interface if it does not exist. ifName is the name of the netlink interface
conn, err := rtnetlink.Dial(nil)
if err != nil {
return nil, err
}
return &RtNetlinkConfig{conn: conn}, nil
}
const WIREGUARD_MTU = 1420
// Create a netlink interface if it does not exist. ifName is the name of the netlink interface
func (c *RtNetlinkConfig) CreateLink(ifName string) error { func (c *RtNetlinkConfig) CreateLink(ifName string) error {
_, err := net.InterfaceByName(ifName) _, err := net.InterfaceByName(ifName)
@ -51,7 +45,7 @@ func (c *RtNetlinkConfig) CreateLink(ifName string) error {
return nil return nil
} }
// Delete link delete the specified interface // DeleteLink: delete the specified interface
func (c *RtNetlinkConfig) DeleteLink(ifName string) error { func (c *RtNetlinkConfig) DeleteLink(ifName string) error {
iface, err := net.InterfaceByName(ifName) iface, err := net.InterfaceByName(ifName)
@ -68,7 +62,7 @@ func (c *RtNetlinkConfig) DeleteLink(ifName string) error {
return nil return nil
} }
// AddAddress adds an address to the given interface. // AddAddress: adds an address to the given interface.
func (c *RtNetlinkConfig) AddAddress(ifName string, address string) error { func (c *RtNetlinkConfig) AddAddress(ifName string, address string) error {
iface, err := net.InterfaceByName(ifName) iface, err := net.InterfaceByName(ifName)
@ -177,7 +171,7 @@ func (c *RtNetlinkConfig) AddRoute(ifName string, route Route) error {
return nil return nil
} }
// DeleteRoute deletes routes with the gateway and destination // DeleteRoute: deletes routes with the gateway and destination
func (c *RtNetlinkConfig) DeleteRoute(ifName string, route Route) error { func (c *RtNetlinkConfig) DeleteRoute(ifName string, route Route) error {
iface, err := net.InterfaceByName(ifName) iface, err := net.InterfaceByName(ifName)
@ -219,6 +213,7 @@ func (c *RtNetlinkConfig) DeleteRoute(ifName string, route Route) error {
return nil return nil
} }
// route: represents a rout to add to the RIB
type Route struct { type Route struct {
Gateway net.IP Gateway net.IP
Destination net.IPNet Destination net.IPNet
@ -232,7 +227,7 @@ func (r1 Route) equal(r2 Route) bool {
(mask1Ones == 0 && mask2Ones == 0 || r1.Destination.IP.Equal(r2.Destination.IP)) (mask1Ones == 0 && mask2Ones == 0 || r1.Destination.IP.Equal(r2.Destination.IP))
} }
// DeleteRoutes deletes all routes not in exclude // DeleteRoutes: deletes all routes not in exclude on the given interface
func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...Route) error { func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...Route) error {
routes, err := c.listRoutes(ifName, family) routes, err := c.listRoutes(ifName, family)
@ -282,7 +277,7 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
return nil return nil
} }
// listRoutes lists all routes on the interface // listRoutes: lists all routes on the interface
func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8) ([]rtnetlink.RouteMessage, error) { func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8) ([]rtnetlink.RouteMessage, error) {
iface, err := net.InterfaceByName(ifName) iface, err := net.InterfaceByName(ifName)
@ -304,6 +299,18 @@ func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8) ([]rtnetlink.R
return routes, nil return routes, nil
} }
// Close: close the Rtnetlink API
func (c *RtNetlinkConfig) Close() error { func (c *RtNetlinkConfig) Close() error {
return c.conn.Close() return c.conn.Close()
} }
// newRtNetlinkConfig: connect to the RtnetlinkAPI
func NewRtNetlinkConfig() (*RtNetlinkConfig, error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return nil, err
}
return &RtNetlinkConfig{conn: conn}, nil
}

View File

@ -6,6 +6,7 @@ import (
"os" "os"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/tim-beatham/smegmesh/pkg/conf"
) )
var ( var (
@ -39,17 +40,29 @@ func (l *LogrusLogger) Writer() io.Writer {
return l.logger.Writer() return l.logger.Writer()
} }
func NewLogrusLogger() *LogrusLogger { func NewLogrusLogger(confLevel conf.LogLevel) *LogrusLogger {
var level logrus.Level
switch confLevel {
case conf.ERROR:
level = logrus.ErrorLevel
case conf.WARNING:
level = logrus.WarnLevel
case conf.INFO:
level = logrus.InfoLevel
}
logger := logrus.New() logger := logrus.New()
logger.SetFormatter(&logrus.TextFormatter{FullTimestamp: true}) logger.SetFormatter(&logrus.TextFormatter{FullTimestamp: true})
logger.SetOutput(os.Stdout) logger.SetOutput(os.Stdout)
logger.SetLevel(logrus.InfoLevel) logger.SetLevel(level)
return &LogrusLogger{logger: logger} return &LogrusLogger{logger: logger}
} }
func init() { func init() {
SetLogger(NewLogrusLogger()) SetLogger(NewLogrusLogger(conf.INFO))
} }
func SetLogger(l Logger) { func SetLogger(l Logger) {

View File

@ -7,21 +7,22 @@ import (
"strings" "strings"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/route" "github.com/tim-beatham/smegmesh/pkg/route"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// MeshConfigApplyer abstracts applying the mesh configuration // MeshConfigApplyer abstracts applying the mesh configuration
type MeshConfigApplyer interface { type MeshConfigApplyer interface {
// ApplyConfig: apply the configurtation
ApplyConfig() error ApplyConfig() error
RemovePeers(meshId string) error // SetMeshManager: sets the associated manager
SetMeshManager(manager MeshManager) SetMeshManager(manager MeshManager)
} }
// WgMeshConfigApplyer applies WireGuard configuration // WgMeshConfigApplyer: applies WireGuard configuration
type WgMeshConfigApplyer struct { type WgMeshConfigApplyer struct {
meshManager MeshManager meshManager MeshManager
routeInstaller route.RouteInstaller routeInstaller route.RouteInstaller
@ -91,7 +92,11 @@ func (m *WgMeshConfigApplyer) convertMeshNode(params convertMeshNodeParams) (*wg
return p.PublicKey.String() == pubKey.String() return p.PublicKey.String() == pubKey.String()
}) })
endpoint, err := net.ResolveUDPAddr("udp", params.node.GetWgEndpoint()) var endpoint *net.UDPAddr = nil
if params.node.GetType() == conf.PEER_ROLE {
endpoint, err = net.ResolveUDPAddr("udp", params.node.GetWgEndpoint())
}
if err != nil { if err != nil {
return nil, err return nil, err
@ -115,8 +120,13 @@ func (m *WgMeshConfigApplyer) convertMeshNode(params convertMeshNodeParams) (*wg
// getRoutes: finds the routes with the least hop distance. If more than one route exists // getRoutes: finds the routes with the least hop distance. If more than one route exists
// consistently hash to evenly spread the distribution of traffic // consistently hash to evenly spread the distribution of traffic
func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]routeNode { func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) (map[string][]routeNode, error) {
mesh, _ := meshProvider.GetMesh() mesh, err := meshProvider.GetMesh()
if err != nil {
return nil, err
}
routes := make(map[string][]routeNode) routes := make(map[string][]routeNode)
peers := lib.Filter(lib.MapValues(mesh.GetNodes()), func(p MeshNode) bool { peers := lib.Filter(lib.MapValues(mesh.GetNodes()), func(p MeshNode) bool {
@ -154,17 +164,18 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
// Client's only acessible by another peer // Client's only acessible by another peer
if node.GetType() == conf.CLIENT_ROLE { if node.GetType() == conf.CLIENT_ROLE {
peer := m.getCorrespondingPeer(peers, node) peer := m.getCorrespondingPeer(peers, node)
self, _ := m.meshManager.GetSelf(meshProvider.GetMeshId()) self, err := meshProvider.GetNode(m.meshManager.GetPublicKey().String())
if err != nil {
return nil, err
}
// If the node isn't the self use that peer as the gateway
if !NodeEquals(peer, self) { if !NodeEquals(peer, self) {
peerPub, _ := peer.GetPublicKey() peerPub, _ := peer.GetPublicKey()
rn.gateway = peerPub.String() rn.gateway = peerPub.String()
rn.route = &RouteStub{ rn.route = &RouteStub{
Destination: rn.route.GetDestination(), Destination: rn.route.GetDestination(),
HopCount: rn.route.GetHopCount() + 1, Path: append(rn.route.GetPath(), peer.GetWgHost().IP.String()),
// Append the path to this peer
Path: append(rn.route.GetPath(), peer.GetWgHost().IP.String()),
} }
} }
} }
@ -181,7 +192,7 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
} }
} }
return routes return routes, nil
} }
// getCorrespondignPeer: gets the peer corresponding to the client // getCorrespondignPeer: gets the peer corresponding to the client
@ -190,6 +201,7 @@ func (m *WgMeshConfigApplyer) getCorrespondingPeer(peers []MeshNode, client Mesh
return peer return peer
} }
// getPeerCfgsToRemove: remove peer configurations that are no longer in the mesh
func (m *WgMeshConfigApplyer) getPeerCfgsToRemove(dev *wgtypes.Device, newPeers []wgtypes.PeerConfig) []wgtypes.PeerConfig { func (m *WgMeshConfigApplyer) getPeerCfgsToRemove(dev *wgtypes.Device, newPeers []wgtypes.PeerConfig) []wgtypes.PeerConfig {
peers := dev.Peers peers := dev.Peers
peers = lib.Filter(peers, func(p1 wgtypes.Peer) bool { peers = lib.Filter(peers, func(p1 wgtypes.Peer) bool {
@ -214,27 +226,37 @@ type GetConfigParams struct {
routes map[string][]routeNode routes map[string][]routeNode
} }
// getClientConfig: if the node is a client get their configuration
func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes.Config, error) { func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes.Config, error) {
self, err := m.meshManager.GetSelf(params.mesh.GetMeshId())
ula := &ip.ULABuilder{} ula := &ip.ULABuilder{}
meshNet, _ := ula.GetIPNet(params.mesh.GetMeshId()) meshNet, _ := ula.GetIPNet(params.mesh.GetMeshId())
routesForMesh := lib.Map(lib.MapValues(params.routes), func(rns []routeNode) []routeNode { routesForMesh := lib.Map(lib.MapValues(params.routes), func(rns []routeNode) []routeNode {
return lib.Filter(rns, func(rn routeNode) bool { return lib.Filter(rns, func(rn routeNode) bool {
ip, _, _ := net.ParseCIDR(rn.gateway) node, err := params.mesh.GetNode(rn.gateway)
return meshNet.Contains(ip) return node != nil && err == nil
}) })
}) })
routesForMesh = lib.Filter(routesForMesh, func(rns []routeNode) bool {
return len(rns) != 0
})
routes := lib.Map(routesForMesh, func(rs []routeNode) net.IPNet { routes := lib.Map(routesForMesh, func(rs []routeNode) net.IPNet {
return *rs[0].route.GetDestination() return *rs[0].route.GetDestination()
}) })
routes = append(routes, *meshNet) routes = append(routes, *meshNet)
self, err := params.mesh.GetNode(m.meshManager.GetPublicKey().String())
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(params.peers) == 0 {
return nil, fmt.Errorf("no peers in the mesh")
}
peer := m.getCorrespondingPeer(params.peers, self) peer := m.getCorrespondingPeer(params.peers, self)
pubKey, _ := peer.GetPublicKey() pubKey, _ := peer.GetPublicKey()
@ -260,10 +282,14 @@ func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes
installedRoutes := make([]lib.Route, 0) installedRoutes := make([]lib.Route, 0)
for _, route := range peerCfgs[0].AllowedIPs { for _, route := range peerCfgs[0].AllowedIPs {
installedRoutes = append(installedRoutes, lib.Route{ // Don't install routes that we are directly apart
Gateway: peer.GetWgHost().IP, // Dont install default route wgctrl handles this for us
Destination: route, if !meshNet.Contains(route.IP) {
}) installedRoutes = append(installedRoutes, lib.Route{
Gateway: peer.GetWgHost().IP,
Destination: route,
})
}
} }
cfg := wgtypes.Config{ cfg := wgtypes.Config{
@ -274,6 +300,8 @@ func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes
return &cfg, err return &cfg, err
} }
// getRoutesToInstall: work out if the given node is advertising routes that should be installed into the
// RIB
func (m *WgMeshConfigApplyer) getRoutesToInstall(wgNode *wgtypes.PeerConfig, mesh MeshProvider, node MeshNode) []lib.Route { func (m *WgMeshConfigApplyer) getRoutesToInstall(wgNode *wgtypes.PeerConfig, mesh MeshProvider, node MeshNode) []lib.Route {
routes := make([]lib.Route, 0) routes := make([]lib.Route, 0)
@ -281,9 +309,8 @@ func (m *WgMeshConfigApplyer) getRoutesToInstall(wgNode *wgtypes.PeerConfig, mes
ula := &ip.ULABuilder{} ula := &ip.ULABuilder{}
ipNet, _ := ula.GetIPNet(mesh.GetMeshId()) ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
_, defaultRoute, _ := net.ParseCIDR("::/0") // Check there is no overlap in network and its not the default route
if !ipNet.Contains(route.IP) {
if !ipNet.Contains(route.IP) && !ipNet.IP.Equal(defaultRoute.IP) {
routes = append(routes, lib.Route{ routes = append(routes, lib.Route{
Gateway: node.GetWgHost().IP, Gateway: node.GetWgHost().IP,
Destination: route, Destination: route,
@ -294,11 +321,12 @@ func (m *WgMeshConfigApplyer) getRoutesToInstall(wgNode *wgtypes.PeerConfig, mes
return routes return routes
} }
// getPeerConfig: creates the WireGuard configuration for a peer
func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.Config, error) { func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.Config, error) {
peerToClients := make(map[string][]net.IPNet) peerToClients := make(map[string][]net.IPNet)
installedRoutes := make([]lib.Route, 0) installedRoutes := make([]lib.Route, 0)
peerConfigs := make([]wgtypes.PeerConfig, 0) peerConfigs := make([]wgtypes.PeerConfig, 0)
self, err := m.meshManager.GetSelf(params.mesh.GetMeshId()) self, err := params.mesh.GetNode(m.meshManager.GetPublicKey().String())
if err != nil { if err != nil {
return nil, err return nil, err
@ -367,6 +395,7 @@ func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.C
return &cfg, err return &cfg, err
} }
// updateWgConf: update the WireGuard configuration
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string][]routeNode) error { func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string][]routeNode) error {
snap, err := mesh.GetMesh() snap, err := mesh.GetMesh()
@ -389,7 +418,7 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string]
return mn.GetType() == conf.CLIENT_ROLE return mn.GetType() == conf.CLIENT_ROLE
}) })
self, err := m.meshManager.GetSelf(mesh.GetMeshId()) self, err := mesh.GetNode(m.meshManager.GetPublicKey().String())
if err != nil { if err != nil {
return err return err
@ -428,11 +457,17 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string]
return nil return nil
} }
func (m *WgMeshConfigApplyer) getAllRoutes() map[string][]routeNode { // getAllRoutes: works out all the routes to install out of all the routes in the
// set of networks the node is a part of
func (m *WgMeshConfigApplyer) getAllRoutes() (map[string][]routeNode, error) {
allRoutes := make(map[string][]routeNode) allRoutes := make(map[string][]routeNode)
for _, mesh := range m.meshManager.GetMeshes() { for _, mesh := range m.meshManager.GetMeshes() {
routes := m.getRoutes(mesh) routes, err := m.getRoutes(mesh)
if err != nil {
return nil, err
}
for destination, route := range routes { for destination, route := range routes {
_, ok := allRoutes[destination] _, ok := allRoutes[destination]
@ -450,11 +485,16 @@ func (m *WgMeshConfigApplyer) getAllRoutes() map[string][]routeNode {
} }
} }
return allRoutes return allRoutes, nil
} }
// ApplyConfig: apply the WireGuard configuration
func (m *WgMeshConfigApplyer) ApplyConfig() error { func (m *WgMeshConfigApplyer) ApplyConfig() error {
allRoutes := m.getAllRoutes() allRoutes, err := m.getAllRoutes()
if err != nil {
return err
}
for _, mesh := range m.meshManager.GetMeshes() { for _, mesh := range m.meshManager.GetMeshes() {
err := m.updateWgConf(mesh, allRoutes) err := m.updateWgConf(mesh, allRoutes)
@ -467,27 +507,6 @@ func (m *WgMeshConfigApplyer) ApplyConfig() error {
return nil return nil
} }
func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error {
mesh := m.meshManager.GetMesh(meshId)
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
dev, err := mesh.GetDevice()
if err != nil {
return err
}
m.meshManager.GetClient().ConfigureDevice(dev.Name, wgtypes.Config{
Peers: make([]wgtypes.PeerConfig, 0),
ReplacePeers: true,
})
return nil
}
func (m *WgMeshConfigApplyer) SetMeshManager(manager MeshManager) { func (m *WgMeshConfigApplyer) SetMeshManager(manager MeshManager) {
m.meshManager = manager m.meshManager = manager
} }

View File

@ -1,77 +0,0 @@
package mesh
import (
"errors"
"fmt"
"github.com/tim-beatham/wgmesh/pkg/graph"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
// MeshGraphConverter converts a mesh to a graph
type MeshGraphConverter interface {
// convert the mesh to textual form
Generate(meshId string) (string, error)
}
type MeshDOTConverter struct {
manager MeshManager
}
func (c *MeshDOTConverter) Generate(meshId string) (string, error) {
mesh := c.manager.GetMesh(meshId)
if mesh == nil {
return "", errors.New("mesh does not exist")
}
g := graph.NewGraph(meshId, graph.GRAPH)
snapshot, err := mesh.GetMesh()
if err != nil {
return "", err
}
for _, node := range snapshot.GetNodes() {
c.graphNode(g, node, meshId)
}
nodes := lib.MapValues(snapshot.GetNodes())
for i, node1 := range nodes[:len(nodes)-1] {
for _, node2 := range nodes[i+1:] {
if node1.GetWgEndpoint() == node2.GetWgEndpoint() {
continue
}
node1Id := fmt.Sprintf("\"%s\"", node1.GetIdentifier())
node2Id := fmt.Sprintf("\"%s\"", node2.GetIdentifier())
g.AddEdge(fmt.Sprintf("%s to %s", node1Id, node2Id), node1Id, node2Id)
}
}
return g.GetDOT()
}
// graphNode: graphs a node within the mesh
func (c *MeshDOTConverter) graphNode(g *graph.Graph, node MeshNode, meshId string) {
nodeId := fmt.Sprintf("\"%s\"", node.GetIdentifier())
g.PutNode(nodeId, graph.CIRCLE)
self, _ := c.manager.GetSelf(meshId)
if NodeEquals(self, node) {
return
}
for _, route := range node.GetRoutes() {
routeId := fmt.Sprintf("\"%s\"", route)
g.PutNode(routeId, graph.HEXAGON)
g.AddEdge(fmt.Sprintf("%s to %s", nodeId, routeId), nodeId, routeId)
}
}
func NewMeshDotConverter(m MeshManager) MeshGraphConverter {
return &MeshDOTConverter{manager: m}
}

View File

@ -3,17 +3,21 @@ package mesh
import ( import (
"errors" "errors"
"fmt" "fmt"
"net"
"sync" "sync"
"github.com/tim-beatham/wgmesh/pkg/cmd" "github.com/tim-beatham/smegmesh/pkg/cmd"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/wg" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/wg"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// MeshManager: abstracts maanging meshes, including installing the WireGuard configuration
// to the device, and adding and removing nodes
type MeshManager interface { type MeshManager interface {
CreateMesh(params *CreateMeshParams) (string, error) CreateMesh(params *CreateMeshParams) (string, error)
AddMesh(params *AddMeshParams) error AddMesh(params *AddMeshParams) error
@ -24,10 +28,10 @@ type MeshManager interface {
LeaveMesh(meshId string) error LeaveMesh(meshId string) error
GetSelf(meshId string) (MeshNode, error) GetSelf(meshId string) (MeshNode, error)
ApplyConfig() error ApplyConfig() error
SetDescription(description string) error SetDescription(meshId, description string) error
SetAlias(alias string) error SetAlias(meshId, alias string) error
SetService(service string, value string) error SetService(meshId, service, value string) error
RemoveService(service string) error RemoveService(meshId, service string) error
UpdateTimeStamp() error UpdateTimeStamp() error
GetClient() *wgctrl.Client GetClient() *wgctrl.Client
GetMeshes() map[string]MeshProvider GetMeshes() map[string]MeshProvider
@ -37,12 +41,10 @@ type MeshManager interface {
} }
type MeshManagerImpl struct { type MeshManagerImpl struct {
lock sync.RWMutex meshLock sync.RWMutex
Meshes map[string]MeshProvider meshes map[string]MeshProvider
RouteManager RouteManager RouteManager RouteManager
Client *wgctrl.Client Client *wgctrl.Client
// HostParameters contains information that uniquely locates
// the node in the mesh network.
HostParameters *HostParameters HostParameters *HostParameters
conf *conf.DaemonConfiguration conf *conf.DaemonConfiguration
meshProviderFactory MeshProviderFactory meshProviderFactory MeshProviderFactory
@ -55,39 +57,43 @@ type MeshManagerImpl struct {
OnDelete func(MeshProvider) OnDelete func(MeshProvider)
} }
// GetRouteManager implements MeshManager.
func (m *MeshManagerImpl) GetRouteManager() RouteManager { func (m *MeshManagerImpl) GetRouteManager() RouteManager {
return m.RouteManager return m.RouteManager
} }
// RemoveService implements MeshManager. // RemoveService: remove a service from the given mesh.
func (m *MeshManagerImpl) RemoveService(service string) error { func (m *MeshManagerImpl) RemoveService(meshId, service string) error {
for _, mesh := range m.Meshes { mesh := m.GetMesh(meshId)
err := mesh.RemoveService(m.HostParameters.GetPublicKey(), service)
if err != nil { if mesh == nil {
return err return fmt.Errorf("mesh %s does not exist", meshId)
}
} }
return nil if !mesh.NodeExists(m.HostParameters.GetPublicKey()) {
} return fmt.Errorf("node %s does not exist in the mesh", meshId)
// SetService implements MeshManager.
func (m *MeshManagerImpl) SetService(service string, value string) error {
for _, mesh := range m.Meshes {
err := mesh.AddService(m.HostParameters.GetPublicKey(), service, value)
if err != nil {
return err
}
} }
return nil return mesh.RemoveService(m.HostParameters.GetPublicKey(), service)
} }
// SetService: add a service to the given mesh
func (m *MeshManagerImpl) SetService(meshId, service, value string) error {
mesh := m.GetMesh(meshId)
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
if !mesh.NodeExists(m.HostParameters.GetPublicKey()) {
return fmt.Errorf("node %s does not exist in the mesh", meshId)
}
return mesh.AddService(m.HostParameters.GetPublicKey(), service, value)
}
// GetNode: gets the node with given id in the mesh network
func (m *MeshManagerImpl) GetNode(meshid, nodeId string) MeshNode { func (m *MeshManagerImpl) GetNode(meshid, nodeId string) MeshNode {
mesh, ok := m.Meshes[meshid] mesh, ok := m.meshes[meshid]
if !ok { if !ok {
return nil return nil
@ -134,6 +140,10 @@ func (m *MeshManagerImpl) CreateMesh(args *CreateMeshParams) (string, error) {
return "", err return "", err
} }
if *meshConfiguration.Role == conf.CLIENT_ROLE {
return "", fmt.Errorf("cannot create mesh as a client")
}
meshId, err := m.idGenerator.GetId() meshId, err := m.idGenerator.GetId()
var ifName string = "" var ifName string = ""
@ -166,9 +176,9 @@ func (m *MeshManagerImpl) CreateMesh(args *CreateMeshParams) (string, error) {
return "", fmt.Errorf("error creating mesh: %w", err) return "", fmt.Errorf("error creating mesh: %w", err)
} }
m.lock.Lock() m.meshLock.Lock()
m.Meshes[meshId] = nodeManager m.meshes[meshId] = nodeManager
m.lock.Unlock() m.meshLock.Unlock()
m.cmdRunner.RunCommands(m.conf.BaseConfiguration.PostUp...) m.cmdRunner.RunCommands(m.conf.BaseConfiguration.PostUp...)
@ -182,7 +192,7 @@ type AddMeshParams struct {
Conf *conf.WgConfiguration Conf *conf.WgConfiguration
} }
// AddMesh: Add the mesh to the list of meshes // AddMesh: Add a new mesh network to the list of addresses
func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error { func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
var ifName string var ifName string
var err error var err error
@ -225,20 +235,20 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
return err return err
} }
m.lock.Lock() m.meshLock.Lock()
m.Meshes[params.MeshId] = meshProvider m.meshes[params.MeshId] = meshProvider
m.lock.Unlock() m.meshLock.Unlock()
return nil return nil
} }
// HasChanges returns true if the mesh has changes // HasChanges: returns true if the mesh has changes
func (m *MeshManagerImpl) HasChanges(meshId string) bool { func (m *MeshManagerImpl) HasChanges(meshId string) bool {
return m.Meshes[meshId].HasChanges() return m.meshes[meshId].HasChanges()
} }
// GetMesh returns the mesh with the given meshid // GetMesh: returns the mesh with the given meshid
func (m *MeshManagerImpl) GetMesh(meshId string) MeshProvider { func (m *MeshManagerImpl) GetMesh(meshId string) MeshProvider {
theMesh := m.Meshes[meshId] theMesh := m.meshes[meshId]
return theMesh return theMesh
} }
@ -248,6 +258,8 @@ func (s *MeshManagerImpl) GetPublicKey() *wgtypes.Key {
return &key return &key
} }
// AddSelfParams: parameters required to add yourself to a mesh
// network
type AddSelfParams struct { type AddSelfParams struct {
// MeshId is the ID of the mesh to add this instance to // MeshId is the ID of the mesh to add this instance to
MeshId string MeshId string
@ -257,7 +269,7 @@ type AddSelfParams struct {
Endpoint string Endpoint string
} }
// AddSelf adds this host to the mesh // AddSelf: adds this host to the mesh
func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error { func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
mesh := s.GetMesh(params.MeshId) mesh := s.GetMesh(params.MeshId)
@ -277,10 +289,36 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
pubKey := s.HostParameters.PrivateKey.PublicKey() pubKey := s.HostParameters.PrivateKey.PublicKey()
nodeIP, err := s.ipAllocator.GetIP(pubKey, params.MeshId) collisionCount := uint8(0)
if err != nil { var nodeIP net.IP
return err
// Perform Duplicate Address Detection with the nodes
// that are already in the network
for {
generatedIP, err := s.ipAllocator.GetIP(pubKey, params.MeshId, collisionCount)
if err != nil {
return err
}
snapshot, err := mesh.GetMesh()
if err != nil {
return err
}
proposition := func(node MeshNode) bool {
ipNet := node.GetWgHost()
return ipNet.IP.Equal(nodeIP)
}
if lib.Contains(lib.MapValues(snapshot.GetNodes()), proposition) {
collisionCount++
} else {
nodeIP = generatedIP
break
}
} }
node := s.nodeFactory.Build(&MeshNodeFactoryParams{ node := s.nodeFactory.Build(&MeshNodeFactoryParams{
@ -305,11 +343,11 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
} }
} }
s.Meshes[params.MeshId].AddNode(node) s.meshes[params.MeshId].AddNode(node)
return nil return nil
} }
// LeaveMesh leaves the mesh network // LeaveMesh: leaves the mesh network and force a synchronsiation
func (s *MeshManagerImpl) LeaveMesh(meshId string) error { func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
mesh := s.GetMesh(meshId) mesh := s.GetMesh(meshId)
@ -320,16 +358,16 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
err := mesh.RemoveNode(s.HostParameters.GetPublicKey()) err := mesh.RemoveNode(s.HostParameters.GetPublicKey())
if err != nil { if err != nil {
return err logging.Log.WriteErrorf(err.Error())
} }
if s.OnDelete != nil { if s.OnDelete != nil {
s.OnDelete(mesh) s.OnDelete(mesh)
} }
s.lock.Lock() s.meshLock.Lock()
delete(s.Meshes, meshId) delete(s.meshes, meshId)
s.lock.Unlock() s.meshLock.Unlock()
s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PreDown...) s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PreDown...)
@ -348,12 +386,11 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
} }
s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PostDown...) s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PostDown...)
return err return err
} }
func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) { func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
meshInstance, ok := s.Meshes[meshId] meshInstance, ok := s.meshes[meshId]
if !ok { if !ok {
return nil, fmt.Errorf("mesh %s does not exist", meshId) return nil, fmt.Errorf("mesh %s does not exist", meshId)
@ -368,51 +405,46 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
return node, nil return node, nil
} }
// ApplyConfig: applies the WireGuard configuration
// adds routes to the RIB and so forth.
func (s *MeshManagerImpl) ApplyConfig() error { func (s *MeshManagerImpl) ApplyConfig() error {
if s.conf.StubWg { if s.conf.StubWg {
return nil return nil
} }
return s.configApplyer.ApplyConfig()
err := s.configApplyer.ApplyConfig()
if err != nil {
return err
}
return nil
} }
func (s *MeshManagerImpl) SetDescription(description string) error { func (s *MeshManagerImpl) SetDescription(meshId, description string) error {
meshes := s.GetMeshes() mesh := s.GetMesh(meshId)
for _, mesh := range meshes {
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetDescription(s.HostParameters.GetPublicKey(), description)
if err != nil { if mesh == nil {
return err return fmt.Errorf("mesh %s does not exist", meshId)
}
}
} }
return nil if !mesh.NodeExists(s.HostParameters.GetPublicKey()) {
} return fmt.Errorf("node %s does not exist in the mesh", meshId)
// SetAlias implements MeshManager.
func (s *MeshManagerImpl) SetAlias(alias string) error {
meshes := s.GetMeshes()
for _, mesh := range meshes {
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetAlias(s.HostParameters.GetPublicKey(), alias)
if err != nil {
return err
}
}
} }
return nil
return mesh.SetDescription(s.HostParameters.GetPublicKey(), description)
} }
// UpdateTimeStamp updates the timestamp of this node in all meshes // SetAlias sets the alias of the node for the given meshid
func (s *MeshManagerImpl) SetAlias(meshId, alias string) error {
mesh := s.GetMesh(meshId)
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
if !mesh.NodeExists(s.HostParameters.GetPublicKey()) {
return fmt.Errorf("node %s does not exist in the mesh", meshId)
}
return mesh.SetAlias(s.HostParameters.GetPublicKey(), alias)
}
// UpdateTimeStamp: updates the timestamp of this node in all meshes
// essentially performs heartbeat if the node is the leader
func (s *MeshManagerImpl) UpdateTimeStamp() error { func (s *MeshManagerImpl) UpdateTimeStamp() error {
meshes := s.GetMeshes() meshes := s.GetMeshes()
for _, mesh := range meshes { for _, mesh := range meshes {
@ -432,26 +464,30 @@ func (s *MeshManagerImpl) GetClient() *wgctrl.Client {
return s.Client return s.Client
} }
// GetMeshes: get all meshes the node is part of
func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider { func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider {
meshes := make(map[string]MeshProvider) meshes := make(map[string]MeshProvider)
s.lock.RLock() // GetMesh: copies the map of meshes to a new map
// to prevent a whole range of concurrency issues
// due to iteration and modification
s.meshLock.RLock()
for id, mesh := range s.Meshes { for id, mesh := range s.meshes {
meshes[id] = mesh meshes[id] = mesh
} }
s.lock.RUnlock() s.meshLock.RUnlock()
return meshes return meshes
} }
// Close the mesh manager // Close: close the mesh manager
func (s *MeshManagerImpl) Close() error { func (s *MeshManagerImpl) Close() error {
if s.conf.StubWg { if s.conf.StubWg {
return nil return nil
} }
for _, mesh := range s.Meshes { for _, mesh := range s.meshes {
dev, err := mesh.GetDevice() dev, err := mesh.GetDevice()
if err != nil { if err != nil {
@ -468,7 +504,7 @@ func (s *MeshManagerImpl) Close() error {
return nil return nil
} }
// NewMeshManagerParams params required to create an instance of a mesh manager // NewMeshManagerParams: params required to create an instance of a mesh manager
type NewMeshManagerParams struct { type NewMeshManagerParams struct {
Conf conf.DaemonConfiguration Conf conf.DaemonConfiguration
Client *wgctrl.Client Client *wgctrl.Client
@ -483,7 +519,7 @@ type NewMeshManagerParams struct {
OnDelete func(MeshProvider) OnDelete func(MeshProvider)
} }
// Creates a new instance of a mesh manager with the given parameters // NewMeshManager: Creates a new instance of a mesh manager with the given parameters
func NewMeshManager(params *NewMeshManagerParams) MeshManager { func NewMeshManager(params *NewMeshManagerParams) MeshManager {
privateKey, _ := wgtypes.GeneratePrivateKey() privateKey, _ := wgtypes.GeneratePrivateKey()
hostParams := HostParameters{ hostParams := HostParameters{
@ -491,7 +527,7 @@ func NewMeshManager(params *NewMeshManagerParams) MeshManager {
} }
m := &MeshManagerImpl{ m := &MeshManagerImpl{
Meshes: make(map[string]MeshProvider), meshes: make(map[string]MeshProvider),
HostParameters: &hostParams, HostParameters: &hostParams,
meshProviderFactory: params.MeshProvider, meshProviderFactory: params.MeshProvider,
nodeFactory: params.NodeFactory, nodeFactory: params.NodeFactory,

View File

@ -3,10 +3,10 @@ package mesh
import ( import (
"testing" "testing"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/wg" "github.com/tim-beatham/smegmesh/pkg/wg"
) )
func getMeshConfiguration() *conf.DaemonConfiguration { func getMeshConfiguration() *conf.DaemonConfiguration {
@ -24,11 +24,11 @@ func getMeshConfiguration() *conf.DaemonConfiguration {
Timeout: 5, Timeout: 5,
Profile: false, Profile: false,
StubWg: true, StubWg: true,
SyncRate: 2, SyncInterval: 2,
KeepAliveTime: 60, Heartbeat: 60,
ClusterSize: 64, ClusterSize: 64,
InterClusterChance: 0.15, InterClusterChance: 0.15,
BranchRate: 3, Branch: 3,
InfectionCount: 3, InfectionCount: 3,
BaseConfiguration: conf.WgConfiguration{ BaseConfiguration: conf.WgConfiguration{
IPDiscovery: &ipDiscovery, IPDiscovery: &ipDiscovery,
@ -213,7 +213,7 @@ func TestLeaveMeshDeletesMesh(t *testing.T) {
} }
} }
func TestSetAlias(t *testing.T) { func TestSetAliasUpdatesAliasOfNode(t *testing.T) {
manager := getMeshManager() manager := getMeshManager()
alias := "Firpo" alias := "Firpo"
@ -221,14 +221,13 @@ func TestSetAlias(t *testing.T) {
Port: 5000, Port: 5000,
Conf: &conf.WgConfiguration{}, Conf: &conf.WgConfiguration{},
}) })
manager.AddSelf(&AddSelfParams{ manager.AddSelf(&AddSelfParams{
MeshId: meshId, MeshId: meshId,
WgPort: 5000, WgPort: 5000,
Endpoint: "abc.com:8080", Endpoint: "abc.com:8080",
}) })
err := manager.SetAlias(alias) err := manager.SetAlias(meshId, alias)
if err != nil { if err != nil {
t.Fatalf(`failed to set the alias`) t.Fatalf(`failed to set the alias`)
@ -245,7 +244,7 @@ func TestSetAlias(t *testing.T) {
} }
} }
func TestSetDescription(t *testing.T) { func TestSetDescriptionSetsTheDescriptionOfTheNode(t *testing.T) {
manager := getMeshManager() manager := getMeshManager()
description := "wooooo" description := "wooooo"
@ -254,23 +253,13 @@ func TestSetDescription(t *testing.T) {
Conf: &conf.WgConfiguration{}, Conf: &conf.WgConfiguration{},
}) })
meshId2, _ := manager.CreateMesh(&CreateMeshParams{
Port: 5001,
Conf: &conf.WgConfiguration{},
})
manager.AddSelf(&AddSelfParams{ manager.AddSelf(&AddSelfParams{
MeshId: meshId1, MeshId: meshId1,
WgPort: 5000, WgPort: 5000,
Endpoint: "abc.com:8080", Endpoint: "abc.com:8080",
}) })
manager.AddSelf(&AddSelfParams{
MeshId: meshId2,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
err := manager.SetDescription(description) err := manager.SetDescription(meshId1, description)
if err != nil { if err != nil {
t.Fatalf(`failed to set the descriptions`) t.Fatalf(`failed to set the descriptions`)
@ -285,18 +274,7 @@ func TestSetDescription(t *testing.T) {
if description != self1.GetDescription() { if description != self1.GetDescription() {
t.Fatalf(`description should be %s was %s`, description, self1.GetDescription()) t.Fatalf(`description should be %s was %s`, description, self1.GetDescription())
} }
self2, err := manager.GetSelf(meshId2)
if err != nil {
t.Fatalf(`failed to set the description`)
}
if description != self2.GetDescription() {
t.Fatalf(`description should be %s was %s`, description, self2.GetDescription())
}
} }
func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) { func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) {
manager := getMeshManager() manager := getMeshManager()
@ -327,3 +305,68 @@ func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) {
t.Fatalf(`failed to update the timestamp`) t.Fatalf(`failed to update the timestamp`)
} }
} }
func TestAddServiceAddsServiceToTheMesh(t *testing.T) {
manager := getMeshManager()
meshId1, _ := manager.CreateMesh(&CreateMeshParams{
Port: 5000,
Conf: &conf.WgConfiguration{},
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId1,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
serviceName := "hello"
manager.SetService(meshId1, serviceName, "dave")
self, err := manager.GetSelf(meshId1)
if err != nil {
t.Fatalf(`error thrown %s:`, err.Error())
}
if _, ok := self.GetServices()[serviceName]; !ok {
t.Fatalf(`service not added`)
}
}
func TestRemoveServiceRemovesTheServiceFromTheMesh(t *testing.T) {
manager := getMeshManager()
meshId1, _ := manager.CreateMesh(&CreateMeshParams{
Port: 5000,
Conf: &conf.WgConfiguration{},
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId1,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
serviceName := "hello"
manager.SetService(meshId1, serviceName, "dave")
self, err := manager.GetSelf(meshId1)
if err != nil {
t.Fatalf(`error thrown %s:`, err.Error())
}
if _, ok := self.GetServices()[serviceName]; !ok {
t.Fatalf(`service not added`)
}
manager.RemoveService(meshId1, serviceName)
self, err = manager.GetSelf(meshId1)
if err != nil {
t.Fatalf(`error thrown %s:`, err.Error())
}
if _, ok := self.GetServices()[serviceName]; ok {
t.Fatalf(`service still exists`)
}
}

View File

@ -3,11 +3,13 @@ package mesh
import ( import (
"net" "net"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
) )
// RouteManager: manager that leaks routes between meshes
type RouteManager interface { type RouteManager interface {
// UpdateRoutes: leak all routes in each mesh
UpdateRoutes() error UpdateRoutes() error
} }
@ -19,12 +21,17 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
meshes := r.meshManager.GetMeshes() meshes := r.meshManager.GetMeshes()
routes := make(map[string][]Route) routes := make(map[string][]Route)
for _, mesh := range meshes {
// Make empty routes so that routes are retracted
routes[mesh.GetMeshId()] = make([]Route, 0)
}
for _, mesh1 := range meshes { for _, mesh1 := range meshes {
if !*mesh1.GetConfiguration().AdvertiseRoutes { if !*mesh1.GetConfiguration().AdvertiseRoutes {
continue continue
} }
self, err := r.meshManager.GetSelf(mesh1.GetMeshId()) self, err := mesh1.GetNode(r.meshManager.GetPublicKey().String())
if err != nil { if err != nil {
return err return err
@ -39,7 +46,6 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
defaultRoute := &RouteStub{ defaultRoute := &RouteStub{
Destination: ipv6Default, Destination: ipv6Default,
HopCount: 0,
Path: []string{mesh1.GetMeshId()}, Path: []string{mesh1.GetMeshId()},
} }
@ -68,7 +74,6 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
routeValues = append(routeValues, &RouteStub{ routeValues = append(routeValues, &RouteStub{
Destination: mesh1IpNet, Destination: mesh1IpNet,
HopCount: 0,
Path: []string{mesh1.GetMeshId()}, Path: []string{mesh1.GetMeshId()},
}) })
@ -90,15 +95,21 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
// Calculate the set different of each, working out routes to remove and to keep. // Calculate the set different of each, working out routes to remove and to keep.
for meshId, meshRoutes := range routes { for meshId, meshRoutes := range routes {
mesh := r.meshManager.GetMesh(meshId) mesh := meshes[meshId]
self, _ := r.meshManager.GetSelf(meshId)
self, err := mesh.GetNode(r.meshManager.GetPublicKey().String())
if err != nil {
return err
}
toRemove := make([]Route, 0) toRemove := make([]Route, 0)
prevRoutes, _ := mesh.GetRoutes(NodeID(self)) prevRoutes := self.GetRoutes()
for _, route := range prevRoutes { for _, route := range prevRoutes {
if !lib.Contains(meshRoutes, func(r Route) bool { if !lib.Contains(meshRoutes, func(r Route) bool {
return RouteEquals(r, route) return RouteEqual(r, route)
}) { }) {
toRemove = append(toRemove, route) toRemove = append(toRemove, route)
} }

View File

@ -5,8 +5,8 @@ import (
"net" "net"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@ -30,8 +30,8 @@ func (*MeshNodeStub) GetType() conf.NodeType {
} }
// GetServices implements MeshNode. // GetServices implements MeshNode.
func (*MeshNodeStub) GetServices() map[string]string { func (m *MeshNodeStub) GetServices() map[string]string {
return make(map[string]string) return m.services
} }
// GetAlias implements MeshNode. // GetAlias implements MeshNode.
@ -249,6 +249,7 @@ func (s *StubNodeFactory) Build(params *MeshNodeFactoryParams) MeshNode {
routes: make([]Route, 0), routes: make([]Route, 0),
identifier: "abc", identifier: "abc",
description: "A Mesh Node Stub", description: "A Mesh Node Stub",
services: make(map[string]string),
} }
} }
@ -271,32 +272,32 @@ type MeshManagerStub struct {
// GetRouteManager implements MeshManager. // GetRouteManager implements MeshManager.
func (*MeshManagerStub) GetRouteManager() RouteManager { func (*MeshManagerStub) GetRouteManager() RouteManager {
panic("unimplemented") return nil
} }
// GetNode implements MeshManager. // GetNode implements MeshManager.
func (*MeshManagerStub) GetNode(string, string) MeshNode { func (*MeshManagerStub) GetNode(meshId, nodeId string) MeshNode {
panic("unimplemented") return nil
} }
// RemoveService implements MeshManager. // RemoveService implements MeshManager.
func (*MeshManagerStub) RemoveService(service string) error { func (*MeshManagerStub) RemoveService(meshId, service string) error {
panic("unimplemented") return nil
} }
// SetService implements MeshManager. // SetService implements MeshManager.
func (*MeshManagerStub) SetService(service string, value string) error { func (*MeshManagerStub) SetService(meshId, service, value string) error {
panic("unimplemented") return nil
} }
// SetAlias implements MeshManager. // SetAlias implements MeshManager.
func (*MeshManagerStub) SetAlias(alias string) error { func (*MeshManagerStub) SetAlias(meshId, alias string) error {
panic("unimplemented") return nil
} }
// Close implements MeshManager. // Close implements MeshManager.
func (*MeshManagerStub) Close() error { func (*MeshManagerStub) Close() error {
panic("unimplemented") return nil
} }
// Prune implements MeshManager. // Prune implements MeshManager.
@ -348,7 +349,7 @@ func (m *MeshManagerStub) ApplyConfig() error {
return nil return nil
} }
func (m *MeshManagerStub) SetDescription(description string) error { func (m *MeshManagerStub) SetDescription(meshId, description string) error {
return nil return nil
} }

View File

@ -6,7 +6,7 @@ import (
"net" "net"
"slices" "slices"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@ -21,12 +21,6 @@ type Route interface {
} }
func RouteEqual(r1 Route, r2 Route) bool { func RouteEqual(r1 Route, r2 Route) bool {
return r1.GetDestination().IP.Equal(r2.GetDestination().IP) &&
r1.GetHopCount() == r2.GetHopCount() &&
slices.Equal(r1.GetPath(), r2.GetPath())
}
func RouteEquals(r1, r2 Route) bool {
return r1.GetDestination().String() == r2.GetDestination().String() && return r1.GetDestination().String() == r2.GetDestination().String() &&
r1.GetHopCount() == r2.GetHopCount() && r1.GetHopCount() == r2.GetHopCount() &&
slices.Equal(r1.GetPath(), r2.GetPath()) slices.Equal(r1.GetPath(), r2.GetPath())
@ -34,7 +28,6 @@ func RouteEquals(r1, r2 Route) bool {
type RouteStub struct { type RouteStub struct {
Destination *net.IPNet Destination *net.IPNet
HopCount int
Path []string Path []string
} }
@ -43,7 +36,7 @@ func (r *RouteStub) GetDestination() *net.IPNet {
} }
func (r *RouteStub) GetHopCount() int { func (r *RouteStub) GetHopCount() int {
return r.HopCount return len(r.Path)
} }
func (r *RouteStub) GetPath() []string { func (r *RouteStub) GetPath() []string {
@ -81,6 +74,10 @@ func NodeEquals(node1, node2 MeshNode) bool {
key1, _ := node1.GetPublicKey() key1, _ := node1.GetPublicKey()
key2, _ := node2.GetPublicKey() key2, _ := node2.GetPublicKey()
if node1 == nil || node2 == nil {
return false
}
return key1.String() == key2.String() return key1.String() == key2.String()
} }

View File

@ -6,9 +6,9 @@ import (
"strings" "strings"
"github.com/jmespath/go-jmespath" "github.com/jmespath/go-jmespath"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
) )
// Querier queries a data store for the given data // Querier queries a data store for the given data
@ -17,20 +17,24 @@ type Querier interface {
Query(meshId string, queryParams string) ([]byte, error) Query(meshId string, queryParams string) ([]byte, error)
} }
// JmesQuerier: queries the datstore in JMESPath syntax
type JmesQuerier struct { type JmesQuerier struct {
manager mesh.MeshManager manager mesh.MeshManager
} }
// QueryError: query error if something went wrong
type QueryError struct { type QueryError struct {
msg string msg string
} }
// QuerRoute: represents a route in the query
type QueryRoute struct { type QueryRoute struct {
Destination string `json:"destination"` Destination string `json:"destination"`
HopCount int `json:"hopCount"` HopCount int `json:"hopCount"`
Path string `json:"path"` Path string `json:"path"`
} }
// QueryNode: represents a single node in the query
type QueryNode struct { type QueryNode struct {
HostEndpoint string `json:"hostEndpoint"` HostEndpoint string `json:"hostEndpoint"`
PublicKey string `json:"publicKey"` PublicKey string `json:"publicKey"`
@ -48,7 +52,7 @@ func (m *QueryError) Error() string {
return m.msg return m.msg
} }
// Query: queries the data // Query: queries the the datastore at the given meshid
func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) { func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
mesh, ok := j.manager.GetMeshes()[meshId] mesh, ok := j.manager.GetMeshes()[meshId]
@ -74,6 +78,7 @@ func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
return bytes, err return bytes, err
} }
// MeshNodeToQuerynode: convert the mesh node into a query abstraction
func MeshNodeToQueryNode(node mesh.MeshNode) *QueryNode { func MeshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode := new(QueryNode) queryNode := new(QueryNode)
queryNode.HostEndpoint = node.GetHostEndpoint() queryNode.HostEndpoint = node.GetHostEndpoint()

View File

@ -4,20 +4,22 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"strconv" "slices"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/smegmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/smegmesh/pkg/rpc"
) )
// IpcHandler: represents a handler for ipc calls
type IpcHandler struct { type IpcHandler struct {
Server ctrlserver.CtrlServer Server ctrlserver.CtrlServer
} }
// getOverrideConfiguration: override any specific WireGuard configuration
func getOverrideConfiguration(args *ipc.WireGuardArgs) conf.WgConfiguration { func getOverrideConfiguration(args *ipc.WireGuardArgs) conf.WgConfiguration {
overrideConf := conf.WgConfiguration{} overrideConf := conf.WgConfiguration{}
@ -40,20 +42,17 @@ func getOverrideConfiguration(args *ipc.WireGuardArgs) conf.WgConfiguration {
return overrideConf return overrideConf
} }
// CreateMesh: create a new mesh network
func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error { func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
overrideConf := getOverrideConfiguration(&args.WgArgs) overrideConf := getOverrideConfiguration(&args.WgArgs)
if overrideConf.Role != nil && *overrideConf.Role == conf.CLIENT_ROLE {
return fmt.Errorf("cannot create a mesh with no public endpoint")
}
meshId, err := n.Server.GetMeshManager().CreateMesh(&mesh.CreateMeshParams{ meshId, err := n.Server.GetMeshManager().CreateMesh(&mesh.CreateMeshParams{
Port: args.WgArgs.WgPort, Port: args.WgArgs.WgPort,
Conf: &overrideConf, Conf: &overrideConf,
}) })
if err != nil { if err != nil {
return err return errors.New("could not create mesh")
} }
err = n.Server.GetMeshManager().AddSelf(&mesh.AddSelfParams{ err = n.Server.GetMeshManager().AddSelf(&mesh.AddSelfParams{
@ -63,13 +62,14 @@ func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
}) })
if err != nil { if err != nil {
return err return errors.New("could not create mesh")
} }
*reply = meshId *reply = meshId
return err return err
} }
// ListMeshes: list mesh networks
func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error { func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
meshNames := make([]string, len(n.Server.GetMeshManager().GetMeshes())) meshNames := make([]string, len(n.Server.GetMeshManager().GetMeshes()))
@ -79,29 +79,35 @@ func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
i++ i++
} }
slices.Sort(meshNames)
*reply = ipc.ListMeshReply{Meshes: meshNames} *reply = ipc.ListMeshReply{Meshes: meshNames}
return nil return nil
} }
func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { // JoinMesh: join a mesh network
func (n *IpcHandler) JoinMesh(args *ipc.JoinMeshArgs, reply *string) error {
overrideConf := getOverrideConfiguration(&args.WgArgs) overrideConf := getOverrideConfiguration(&args.WgArgs)
peerConnection, err := n.Server.GetConnectionManager().GetConnection(args.IpAdress) if n.Server.GetMeshManager().GetMesh(args.MeshId) != nil {
return fmt.Errorf("user is already apart of the mesh")
}
peerConnection, err := n.Server.GetConnectionManager().GetConnection(args.IpAddress)
if err != nil { if err != nil {
return err return fmt.Errorf("could not join mesh %s", args.MeshId)
} }
client, err := peerConnection.GetClient() client, err := peerConnection.GetClient()
if err != nil { if err != nil {
return err return fmt.Errorf("could not join mesh %s", args.MeshId)
} }
c := rpc.NewMeshCtrlServerClient(client) c := rpc.NewMeshCtrlServerClient(client)
if err != nil { if err != nil {
return err return fmt.Errorf("could not join mesh %s", args.MeshId)
} }
configuration := n.Server.GetConfiguration() configuration := n.Server.GetConfiguration()
@ -112,7 +118,7 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
meshReply, err := c.GetMesh(ctx, &rpc.GetMeshRequest{MeshId: args.MeshId}) meshReply, err := c.GetMesh(ctx, &rpc.GetMeshRequest{MeshId: args.MeshId})
if err != nil { if err != nil {
return err return fmt.Errorf("could not join mesh %s", args.MeshId)
} }
err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{ err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{
@ -123,7 +129,7 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
}) })
if err != nil { if err != nil {
return err return fmt.Errorf("could not join mesh %s", args.MeshId)
} }
err = n.Server.GetMeshManager().AddSelf(&mesh.AddSelfParams{ err = n.Server.GetMeshManager().AddSelf(&mesh.AddSelfParams{
@ -133,24 +139,24 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
}) })
if err != nil { if err != nil {
return err return fmt.Errorf("could not join mesh %s", args.MeshId)
} }
*reply = strconv.FormatBool(true) *reply = fmt.Sprintf("Successfully Joined: %s", args.MeshId)
return nil return nil
} }
// LeaveMesh leaves a mesh network // LeaveMesh: leaves a mesh network
func (n *IpcHandler) LeaveMesh(meshId string, reply *string) error { func (n *IpcHandler) LeaveMesh(meshId string, reply *string) error {
err := n.Server.GetMeshManager().LeaveMesh(meshId) err := n.Server.GetMeshManager().LeaveMesh(meshId)
if err == nil { if err == nil {
*reply = fmt.Sprintf("Left Mesh %s", meshId) *reply = fmt.Sprintf("Left Mesh %s", meshId)
} }
return err return err
} }
// GetMesh: get a mesh network at the given meshid
func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error { func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
theMesh := n.Server.GetMeshManager().GetMesh(meshId) theMesh := n.Server.GetMeshManager().GetMesh(meshId)
@ -182,19 +188,7 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
return nil return nil
} }
func (n *IpcHandler) GetDOT(meshId string, reply *string) error { // Query: perform a jmespath query
g := mesh.NewMeshDotConverter(n.Server.GetMeshManager())
result, err := g.Generate(meshId)
if err != nil {
return err
}
*reply = result
return nil
}
func (n *IpcHandler) Query(params ipc.QueryMesh, reply *string) error { func (n *IpcHandler) Query(params ipc.QueryMesh, reply *string) error {
queryResponse, err := n.Server.GetQuerier().Query(params.MeshId, params.Query) queryResponse, err := n.Server.GetQuerier().Query(params.MeshId, params.Query)
@ -206,50 +200,59 @@ func (n *IpcHandler) Query(params ipc.QueryMesh, reply *string) error {
return nil return nil
} }
func (n *IpcHandler) PutDescription(description string, reply *string) error { // PutDescription: change your description in the mesh
err := n.Server.GetMeshManager().SetDescription(description) func (n *IpcHandler) PutDescription(args ipc.PutDescriptionArgs, reply *string) error {
err := n.Server.GetMeshManager().SetDescription(args.MeshId, args.Description)
if err != nil { if err != nil {
return err return err
} }
*reply = fmt.Sprintf("Set description to %s", description) *reply = fmt.Sprintf("set description to %s for %s", args.Description, args.MeshId)
return nil return nil
} }
func (n *IpcHandler) PutAlias(alias string, reply *string) error { // PutAlias: put your aliasin the mesh
err := n.Server.GetMeshManager().SetAlias(alias) func (n *IpcHandler) PutAlias(args ipc.PutAliasArgs, reply *string) error {
if args.Alias == "" {
if err != nil { return fmt.Errorf("alias not provided")
return err
} }
*reply = fmt.Sprintf("Set alias to %s", alias) err := n.Server.GetMeshManager().SetAlias(args.MeshId, args.Alias)
if err != nil {
return fmt.Errorf("could not set alias: %s", args.Alias)
}
*reply = fmt.Sprintf("Set alias to %s", args.Alias)
return nil return nil
} }
// PutService: place a service in the mesh
func (n *IpcHandler) PutService(service ipc.PutServiceArgs, reply *string) error { func (n *IpcHandler) PutService(service ipc.PutServiceArgs, reply *string) error {
err := n.Server.GetMeshManager().SetService(service.Service, service.Value) err := n.Server.GetMeshManager().SetService(service.MeshId, service.Service, service.Value)
if err != nil { if err != nil {
return err return err
} }
*reply = "success" *reply = fmt.Sprintf("Set service %s in %s to %s", service.Service, service.MeshId, service.Value)
return nil return nil
} }
func (n *IpcHandler) DeleteService(service string, reply *string) error { // DeleteService: withtract a service in the mesh
err := n.Server.GetMeshManager().RemoveService(service) func (n *IpcHandler) DeleteService(service ipc.DeleteServiceArgs, reply *string) error {
err := n.Server.GetMeshManager().RemoveService(service.MeshId, service.Service)
if err != nil { if err != nil {
return err return err
} }
*reply = "success" *reply = fmt.Sprintf("Removed service %s from %s", service.Service, service.MeshId)
return nil return nil
} }
// RobinIpcParams: parameters required to construct a new mesh network
type RobinIpcParams struct { type RobinIpcParams struct {
CtrlServer ctrlserver.CtrlServer CtrlServer ctrlserver.CtrlServer
} }

View File

@ -3,10 +3,10 @@ package robin
import ( import (
"testing" "testing"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/smegmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
) )
func getRequester() *IpcHandler { func getRequester() *IpcHandler {

View File

@ -4,15 +4,17 @@ import (
"context" "context"
"errors" "errors"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/smegmesh/pkg/rpc"
) )
// WgRpc: represents a WireGuard rpc call
type WgRpc struct { type WgRpc struct {
rpc.UnimplementedMeshCtrlServerServer rpc.UnimplementedMeshCtrlServerServer
Server *ctrlserver.MeshCtrlServer Server *ctrlserver.MeshCtrlServer
} }
// GetMesh: serialise the mesh network into bytes
func (m *WgRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.GetMeshReply, error) { func (m *WgRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.GetMeshReply, error) {
mesh := m.Server.MeshManager.GetMesh(request.MeshId) mesh := m.Server.MeshManager.GetMesh(request.MeshId)

View File

@ -1 +0,0 @@
package robin

View File

@ -1,10 +1,11 @@
package route package route
import ( import (
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// RouteInstaller: install the routes to the given interface
type RouteInstaller interface { type RouteInstaller interface {
InstallRoutes(devName string, routes ...lib.Route) error InstallRoutes(devName string, routes ...lib.Route) error
} }
@ -19,6 +20,8 @@ func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...lib.Route)
return err return err
} }
defer rtnl.Close()
err = rtnl.DeleteRoutes(devName, unix.AF_INET6, routes...) err = rtnl.DeleteRoutes(devName, unix.AF_INET6, routes...)
if err != nil { if err != nil {

View File

@ -1,235 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.1
// protoc v3.21.12
// source: pkg/grpc/ctrlserver/authentication.proto
package rpc
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type JoinAuthMeshRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
MeshId string `protobuf:"bytes,1,opt,name=meshId,proto3" json:"meshId,omitempty"`
Alias string `protobuf:"bytes,2,opt,name=alias,proto3" json:"alias,omitempty"`
}
func (x *JoinAuthMeshRequest) Reset() {
*x = JoinAuthMeshRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *JoinAuthMeshRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JoinAuthMeshRequest) ProtoMessage() {}
func (x *JoinAuthMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JoinAuthMeshRequest.ProtoReflect.Descriptor instead.
func (*JoinAuthMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_authentication_proto_rawDescGZIP(), []int{0}
}
func (x *JoinAuthMeshRequest) GetMeshId() string {
if x != nil {
return x.MeshId
}
return ""
}
func (x *JoinAuthMeshRequest) GetAlias() string {
if x != nil {
return x.Alias
}
return ""
}
type JoinAuthMeshReply struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"`
Token *string `protobuf:"bytes,2,opt,name=token,proto3,oneof" json:"token,omitempty"`
}
func (x *JoinAuthMeshReply) Reset() {
*x = JoinAuthMeshReply{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *JoinAuthMeshReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JoinAuthMeshReply) ProtoMessage() {}
func (x *JoinAuthMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JoinAuthMeshReply.ProtoReflect.Descriptor instead.
func (*JoinAuthMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_authentication_proto_rawDescGZIP(), []int{1}
}
func (x *JoinAuthMeshReply) GetSuccess() bool {
if x != nil {
return x.Success
}
return false
}
func (x *JoinAuthMeshReply) GetToken() string {
if x != nil && x.Token != nil {
return *x.Token
}
return ""
}
var File_pkg_grpc_ctrlserver_authentication_proto protoreflect.FileDescriptor
var file_pkg_grpc_ctrlserver_authentication_proto_rawDesc = []byte{
0x0a, 0x28, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73,
0x65, 0x72, 0x76, 0x65, 0x72, 0x2f, 0x61, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61,
0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x72, 0x70, 0x63, 0x74,
0x79, 0x70, 0x65, 0x73, 0x22, 0x43, 0x0a, 0x13, 0x4a, 0x6f, 0x69, 0x6e, 0x41, 0x75, 0x74, 0x68,
0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x6d,
0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73,
0x68, 0x49, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x61, 0x6c, 0x69, 0x61, 0x73, 0x18, 0x02, 0x20, 0x01,
0x28, 0x09, 0x52, 0x05, 0x61, 0x6c, 0x69, 0x61, 0x73, 0x22, 0x52, 0x0a, 0x11, 0x4a, 0x6f, 0x69,
0x6e, 0x41, 0x75, 0x74, 0x68, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x18,
0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52,
0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x19, 0x0a, 0x05, 0x74, 0x6f, 0x6b, 0x65,
0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e,
0x88, 0x01, 0x01, 0x42, 0x08, 0x0a, 0x06, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x32, 0x5a, 0x0a,
0x0e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12,
0x48, 0x0a, 0x08, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x1d, 0x2e, 0x72, 0x70,
0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x41, 0x75, 0x74, 0x68, 0x4d,
0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x41, 0x75, 0x74, 0x68, 0x4d, 0x65,
0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x42, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67,
0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_pkg_grpc_ctrlserver_authentication_proto_rawDescOnce sync.Once
file_pkg_grpc_ctrlserver_authentication_proto_rawDescData = file_pkg_grpc_ctrlserver_authentication_proto_rawDesc
)
func file_pkg_grpc_ctrlserver_authentication_proto_rawDescGZIP() []byte {
file_pkg_grpc_ctrlserver_authentication_proto_rawDescOnce.Do(func() {
file_pkg_grpc_ctrlserver_authentication_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_grpc_ctrlserver_authentication_proto_rawDescData)
})
return file_pkg_grpc_ctrlserver_authentication_proto_rawDescData
}
var file_pkg_grpc_ctrlserver_authentication_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_pkg_grpc_ctrlserver_authentication_proto_goTypes = []interface{}{
(*JoinAuthMeshRequest)(nil), // 0: rpctypes.JoinAuthMeshRequest
(*JoinAuthMeshReply)(nil), // 1: rpctypes.JoinAuthMeshReply
}
var file_pkg_grpc_ctrlserver_authentication_proto_depIdxs = []int32{
0, // 0: rpctypes.Authentication.JoinMesh:input_type -> rpctypes.JoinAuthMeshRequest
1, // 1: rpctypes.Authentication.JoinMesh:output_type -> rpctypes.JoinAuthMeshReply
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_pkg_grpc_ctrlserver_authentication_proto_init() }
func file_pkg_grpc_ctrlserver_authentication_proto_init() {
if File_pkg_grpc_ctrlserver_authentication_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*JoinAuthMeshRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*JoinAuthMeshReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[1].OneofWrappers = []interface{}{}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_pkg_grpc_ctrlserver_authentication_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_pkg_grpc_ctrlserver_authentication_proto_goTypes,
DependencyIndexes: file_pkg_grpc_ctrlserver_authentication_proto_depIdxs,
MessageInfos: file_pkg_grpc_ctrlserver_authentication_proto_msgTypes,
}.Build()
File_pkg_grpc_ctrlserver_authentication_proto = out.File
file_pkg_grpc_ctrlserver_authentication_proto_rawDesc = nil
file_pkg_grpc_ctrlserver_authentication_proto_goTypes = nil
file_pkg_grpc_ctrlserver_authentication_proto_depIdxs = nil
}

View File

@ -1,105 +0,0 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.2.0
// - protoc v3.21.12
// source: pkg/grpc/ctrlserver/authentication.proto
package rpc
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// AuthenticationClient is the client API for Authentication service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type AuthenticationClient interface {
JoinMesh(ctx context.Context, in *JoinAuthMeshRequest, opts ...grpc.CallOption) (*JoinAuthMeshReply, error)
}
type authenticationClient struct {
cc grpc.ClientConnInterface
}
func NewAuthenticationClient(cc grpc.ClientConnInterface) AuthenticationClient {
return &authenticationClient{cc}
}
func (c *authenticationClient) JoinMesh(ctx context.Context, in *JoinAuthMeshRequest, opts ...grpc.CallOption) (*JoinAuthMeshReply, error) {
out := new(JoinAuthMeshReply)
err := c.cc.Invoke(ctx, "/rpctypes.Authentication/JoinMesh", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// AuthenticationServer is the server API for Authentication service.
// All implementations must embed UnimplementedAuthenticationServer
// for forward compatibility
type AuthenticationServer interface {
JoinMesh(context.Context, *JoinAuthMeshRequest) (*JoinAuthMeshReply, error)
mustEmbedUnimplementedAuthenticationServer()
}
// UnimplementedAuthenticationServer must be embedded to have forward compatible implementations.
type UnimplementedAuthenticationServer struct {
}
func (UnimplementedAuthenticationServer) JoinMesh(context.Context, *JoinAuthMeshRequest) (*JoinAuthMeshReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method JoinMesh not implemented")
}
func (UnimplementedAuthenticationServer) mustEmbedUnimplementedAuthenticationServer() {}
// UnsafeAuthenticationServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to AuthenticationServer will
// result in compilation errors.
type UnsafeAuthenticationServer interface {
mustEmbedUnimplementedAuthenticationServer()
}
func RegisterAuthenticationServer(s grpc.ServiceRegistrar, srv AuthenticationServer) {
s.RegisterService(&Authentication_ServiceDesc, srv)
}
func _Authentication_JoinMesh_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(JoinAuthMeshRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(AuthenticationServer).JoinMesh(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/rpctypes.Authentication/JoinMesh",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(AuthenticationServer).JoinMesh(ctx, req.(*JoinAuthMeshRequest))
}
return interceptor(ctx, in, info, handler)
}
// Authentication_ServiceDesc is the grpc.ServiceDesc for Authentication service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var Authentication_ServiceDesc = grpc.ServiceDesc{
ServiceName: "rpctypes.Authentication",
HandlerType: (*AuthenticationServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "JoinMesh",
Handler: _Authentication_JoinMesh_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "pkg/grpc/ctrlserver/authentication.proto",
}

View File

@ -1,201 +1,279 @@
package sync package sync
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"math/rand" "math/rand"
"sync"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn" "github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
) )
// Syncer: picks random nodes from the meshs // Syncer: picks random nodes from the meshs
type Syncer interface { type Syncer interface {
Sync(meshId string) error Sync(theMesh mesh.MeshProvider) (bool, error)
SyncMeshes() error SyncMeshes() error
} }
// SyncerImpl: implementation of a syncer to sync meshes
type SyncerImpl struct { type SyncerImpl struct {
manager mesh.MeshManager meshManager mesh.MeshManager
requester SyncRequester requester SyncRequester
infectionCount int infectionCount int
syncCount int syncCount int
cluster conn.ConnCluster cluster conn.ConnCluster
conf *conf.DaemonConfiguration configuration *conf.DaemonConfiguration
lastSync map[string]uint64 lastSync map[string]int64
lastPoll map[string]int64
lastSyncLock sync.RWMutex
lastPollLock sync.RWMutex
} }
// Sync: Sync random nodes // Sync: Sync with random nodes. Returns true if there was changes false otherwise
func (s *SyncerImpl) Sync(meshId string) error { func (s *SyncerImpl) Sync(correspondingMesh mesh.MeshProvider) (bool, error) {
// Self can be nil if the node is removed if correspondingMesh == nil {
self, _ := s.manager.GetSelf(meshId) return false, fmt.Errorf("mesh provided was nil cannot sync nil mesh")
}
correspondingMesh := s.manager.GetMesh(meshId) // Self can be nil if the node is removed
selfID := s.meshManager.GetPublicKey()
self, _ := correspondingMesh.GetNode(selfID.String())
correspondingMesh.Prune() correspondingMesh.Prune()
if self != nil && self.GetType() == conf.PEER_ROLE && !s.manager.HasChanges(meshId) && s.infectionCount == 0 { if correspondingMesh.HasChanges() {
logging.Log.WriteInfof("No changes for %s", meshId) logging.Log.WriteInfof("meshes %s has changes", correspondingMesh.GetMeshId())
}
// If not synchronised in certain pull from random neighbour // If removed sync with other nodes to gossip the node is removed
if uint64(time.Now().Unix())-s.lastSync[meshId] > 20 { if self != nil && self.GetType() == conf.PEER_ROLE && !correspondingMesh.HasChanges() && s.infectionCount == 0 {
return s.Pull(meshId) logging.Log.WriteInfof("no changes for %s", correspondingMesh.GetMeshId())
// If not synchronised in certain time pull from random neighbour
if s.configuration.PullInterval != 0 && time.Now().Unix()-s.lastSync[correspondingMesh.GetMeshId()] > int64(s.configuration.PullInterval) {
return s.Pull(self, correspondingMesh)
} }
return nil return false, nil
} }
before := time.Now() before := time.Now()
s.manager.GetRouteManager().UpdateRoutes()
publicKey := s.manager.GetPublicKey()
logging.Log.WriteInfof(publicKey.String())
publicKey := s.meshManager.GetPublicKey()
nodeNames := correspondingMesh.GetPeers() nodeNames := correspondingMesh.GetPeers()
if self != nil { nodeNames = lib.Filter(nodeNames, func(s string) bool {
nodeNames = lib.Filter(nodeNames, func(s string) bool { // Filter our only public key out so we dont sync with ourself
return s != mesh.NodeID(self) return s != publicKey.String()
}) })
}
var gossipNodes []string var gossipNodes []string
// Clients always pings its peer for configuration // Clients always pings its peer for configuration
if self != nil && self.GetType() == conf.CLIENT_ROLE { if self != nil && self.GetType() == conf.CLIENT_ROLE && len(nodeNames) > 1 {
keyFunc := lib.HashString neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
bucketFunc := lib.HashString
neighbour := lib.ConsistentHash(nodeNames, publicKey.String(), keyFunc, bucketFunc) if len(neighbours) == 0 {
gossipNodes = make([]string, 1) return false, nil
gossipNodes[0] = neighbour }
// Peer with 2 nodes so that there is redundnacy in
// the situation the node leaves pre-emptively
redundancyLength := min(len(neighbours), 2)
gossipNodes = neighbours[:redundancyLength]
} else { } else {
neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String()) neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
gossipNodes = lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate) gossipNodes = lib.RandomSubsetOfLength(neighbours, s.configuration.Branch)
if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance { if len(nodeNames) > s.configuration.ClusterSize && rand.Float64() < s.configuration.InterClusterChance {
gossipNodes[len(gossipNodes)-1] = s.cluster.GetInterCluster(nodeNames, publicKey.String()) gossipNodes[len(gossipNodes)-1] = s.cluster.GetInterCluster(nodeNames, publicKey.String())
} }
} }
var succeeded bool = false var succeeded bool = false
// Do this synchronously to conserve bandwidth var wait sync.WaitGroup
for _, node := range gossipNodes {
correspondingPeer := s.manager.GetNode(meshId, node)
if correspondingPeer == nil { for index, node := range gossipNodes {
logging.Log.WriteErrorf("node %s does not exist", node) wait.Add(1)
syncNode := func(i int) {
correspondingPeer, err := correspondingMesh.GetNode(node)
defer wait.Done()
if correspondingPeer == nil || err != nil {
logging.Log.WriteErrorf("node %s does not exist", node)
return
}
err = s.requester.SyncMesh(correspondingMesh, correspondingPeer)
if err == nil || err == io.EOF {
succeeded = true
}
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
} }
err := s.requester.SyncMesh(meshId, correspondingPeer) go syncNode(index)
if err == nil || err == io.EOF {
succeeded = true
} else {
// If the synchronisation operation has failed them mark a gravestone
// preventing the peer from being re-contacted until it has updated
// itself
s.manager.GetMesh(meshId).Mark(node)
}
} }
s.syncCount++ wait.Wait()
logging.Log.WriteInfof("SYNC TIME: %v", time.Since(before))
logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount)
s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount) s.syncCount++
logging.Log.WriteInfof("sync time: %v", time.Since(before))
logging.Log.WriteInfof("number of syncs: %d", s.syncCount)
s.infectionCount = ((s.configuration.InfectionCount + s.infectionCount - 1) % s.configuration.InfectionCount)
if !succeeded { if !succeeded {
// If could not gossip with anyone then repeat.
s.infectionCount++ s.infectionCount++
} }
s.manager.GetMesh(meshId).SaveChanges() changes := correspondingMesh.HasChanges()
s.lastSync[meshId] = uint64(time.Now().Unix()) correspondingMesh.SaveChanges()
logging.Log.WriteInfof("UPDATING WG CONF") s.lastSyncLock.Lock()
err := s.manager.ApplyConfig() s.lastSync[correspondingMesh.GetMeshId()] = time.Now().Unix()
s.lastSyncLock.Unlock()
if err != nil { return changes, nil
logging.Log.WriteInfof("Failed to update config %w", err)
}
return nil
} }
// Pull one node in the cluster, if there has not been message dissemination // Pull one node in the cluster, if there has not been message dissemination
// in a certain period of time pull a random node within the cluster // in a certain period of time pull a random node within the cluster
func (s *SyncerImpl) Pull(meshId string) error { func (s *SyncerImpl) Pull(self mesh.MeshNode, mesh mesh.MeshProvider) (bool, error) {
mesh := s.manager.GetMesh(meshId) peers := mesh.GetPeers()
self, err := s.manager.GetSelf(meshId)
if err != nil {
return err
}
pubKey, _ := self.GetPublicKey() pubKey, _ := self.GetPublicKey()
if mesh == nil {
return errors.New("mesh is nil, invalid operation")
}
peers := mesh.GetPeers()
neighbours := s.cluster.GetNeighbours(peers, pubKey.String()) neighbours := s.cluster.GetNeighbours(peers, pubKey.String())
neighbour := lib.RandomSubsetOfLength(neighbours, 1) neighbour := lib.RandomSubsetOfLength(neighbours, 1)
if len(neighbour) == 0 { if len(neighbour) == 0 {
logging.Log.WriteInfof("no neighbours") logging.Log.WriteInfof("no neighbours")
return nil return false, nil
} }
logging.Log.WriteInfof("PULLING from node %s", neighbour[0]) logging.Log.WriteInfof("pulling from node %s", neighbour[0])
pullNode, err := mesh.GetNode(neighbour[0]) pullNode, err := mesh.GetNode(neighbour[0])
if err != nil || pullNode == nil { if err != nil || pullNode == nil {
return fmt.Errorf("node %s does not exist in the mesh", neighbour[0]) return false, fmt.Errorf("node %s does not exist in the mesh", neighbour[0])
} }
err = s.requester.SyncMesh(meshId, pullNode) err = s.requester.SyncMesh(mesh, pullNode)
if err == nil || err == io.EOF { if err == nil || err == io.EOF {
s.lastSync[meshId] = uint64(time.Now().Unix()) s.lastSync[mesh.GetMeshId()] = time.Now().Unix()
} else { } else {
return err return false, err
} }
s.syncCount++ s.syncCount++
return nil
changes := mesh.HasChanges()
return changes, nil
} }
// SyncMeshes: Sync all meshes // SyncMeshes: Sync all meshes
func (s *SyncerImpl) SyncMeshes() error { func (s *SyncerImpl) SyncMeshes() error {
for meshId := range s.manager.GetMeshes() { var wg sync.WaitGroup
err := s.Sync(meshId)
if err != nil { meshes := s.meshManager.GetMeshes()
logging.Log.WriteErrorf(err.Error())
s.lastPollLock.Lock()
meshesToSync := lib.Filter(lib.MapValues(meshes), func(mesh mesh.MeshProvider) bool {
return time.Now().Unix()-s.lastPoll[mesh.GetMeshId()] >= int64(s.configuration.SyncInterval)
})
s.lastPollLock.Unlock()
changes := make(chan bool, len(meshesToSync))
for i := 0; i < len(meshesToSync); {
wg.Add(1)
sync := func(index int) {
defer wg.Done()
var hasChanges bool = false
mesh := meshesToSync[index]
hasChanges, err := s.Sync(mesh)
changes <- hasChanges
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
s.lastPollLock.Lock()
s.lastPoll[mesh.GetMeshId()] = time.Now().Unix()
s.lastPollLock.Unlock()
}
go sync(i)
i++
}
wg.Wait()
hasChanges := false
for i := 0; i < len(changes); i++ {
if <-changes {
hasChanges = true
} }
} }
return nil var err error
if hasChanges {
logging.Log.WriteInfof("updating the WireGuard configuration")
err = s.meshManager.ApplyConfig()
if err != nil {
logging.Log.WriteErrorf("failed to update config %s", err.Error())
}
err = s.meshManager.GetRouteManager().UpdateRoutes()
if err != nil {
logging.Log.WriteErrorf("update routes failed %s", err.Error())
}
}
return err
} }
func NewSyncer(m mesh.MeshManager, conf *conf.DaemonConfiguration, r SyncRequester) Syncer { type NewSyncerParams struct {
cluster, _ := conn.NewConnCluster(conf.ClusterSize) MeshManager mesh.MeshManager
ConnectionManager conn.ConnectionManager
Configuration *conf.DaemonConfiguration
Requester SyncRequester
}
func NewSyncer(params *NewSyncerParams) Syncer {
cluster, _ := conn.NewConnCluster(params.Configuration.ClusterSize)
syncRequester := NewSyncRequester(NewSyncRequesterParams{
MeshManager: params.MeshManager,
ConnectionManager: params.ConnectionManager,
Configuration: params.Configuration,
})
return &SyncerImpl{ return &SyncerImpl{
manager: m, meshManager: params.MeshManager,
conf: conf, configuration: params.Configuration,
requester: r, requester: syncRequester,
infectionCount: 0, infectionCount: 0,
syncCount: 0, syncCount: 0,
cluster: cluster, cluster: cluster,
lastSync: make(map[string]uint64)} lastSync: make(map[string]int64),
lastPoll: make(map[string]int64)}
} }

View File

@ -1,41 +1,60 @@
package sync package sync
import ( import (
logging "github.com/tim-beatham/wgmesh/pkg/log" "github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/mesh" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
// SyncErrorHandler: Handles errors when attempting to sync // SyncErrorHandler: Handles errors when attempting to sync
type SyncErrorHandler interface { type SyncErrorHandler interface {
Handle(meshId string, endpoint string, err error) bool Handle(mesh mesh.MeshProvider, endpoint string, err error) bool
} }
// SyncErrorHandlerImpl Is an implementation of the SyncErrorHandler // SyncErrorHandlerImpl Is an implementation of the SyncErrorHandler
type SyncErrorHandlerImpl struct { type SyncErrorHandlerImpl struct {
meshManager mesh.MeshManager meshManager mesh.MeshManager
connManager conn.ConnectionManager
} }
func (s *SyncErrorHandlerImpl) handleFailed(meshId string, nodeId string) bool { func (s *SyncErrorHandlerImpl) handleFailed(mesh mesh.MeshProvider, nodeId string) bool {
mesh := s.meshManager.GetMesh(meshId)
mesh.Mark(nodeId) mesh.Mark(nodeId)
node, err := mesh.GetNode(nodeId)
if err != nil {
s.connManager.RemoveConnection(node.GetHostEndpoint())
}
return true return true
} }
func (s *SyncErrorHandlerImpl) Handle(meshId string, nodeId string, err error) bool { func (s *SyncErrorHandlerImpl) handleDeadlineExceeded(mesh mesh.MeshProvider, nodeId string) bool {
node, err := mesh.GetNode(nodeId)
if err != nil {
return false
}
s.connManager.RemoveConnection(node.GetHostEndpoint())
return true
}
func (s *SyncErrorHandlerImpl) Handle(mesh mesh.MeshProvider, nodeId string, err error) bool {
errStatus, _ := status.FromError(err) errStatus, _ := status.FromError(err)
logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message()) logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message())
switch errStatus.Code() { switch errStatus.Code() {
case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound: case codes.Unavailable, codes.Unknown, codes.Internal, codes.NotFound:
return s.handleFailed(meshId, nodeId) return s.handleFailed(mesh, nodeId)
case codes.DeadlineExceeded:
return s.handleDeadlineExceeded(mesh, nodeId)
} }
return false return false
} }
func NewSyncErrorHandler(m mesh.MeshManager) SyncErrorHandler { func NewSyncErrorHandler(m mesh.MeshManager, conn conn.ConnectionManager) SyncErrorHandler {
return &SyncErrorHandlerImpl{meshManager: m} return &SyncErrorHandlerImpl{meshManager: m, connManager: conn}
} }

View File

@ -2,76 +2,44 @@ package sync
import ( import (
"context" "context"
"errors"
"io" "io"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/conf"
logging "github.com/tim-beatham/wgmesh/pkg/log" "github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/mesh" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/rpc"
) )
// SyncRequester: coordinates the syncing of meshes // SyncRequester: coordinates the syncing of meshes
type SyncRequester interface { type SyncRequester interface {
GetMesh(meshId string, ifName string, port int, endPoint string) error SyncMesh(mesh mesh.MeshProvider, meshNode mesh.MeshNode) error
SyncMesh(meshid string, meshNode mesh.MeshNode) error
} }
type SyncRequesterImpl struct { type SyncRequesterImpl struct {
server *ctrlserver.MeshCtrlServer manager mesh.MeshManager
errorHdlr SyncErrorHandler connectionManager conn.ConnectionManager
configuration *conf.DaemonConfiguration
errorHdlr SyncErrorHandler
} }
// GetMesh: Retrieves the local state of the mesh at the endpoint // handleErr: handleGrpc errors
func (s *SyncRequesterImpl) GetMesh(meshId string, ifName string, port int, endPoint string) error { func (s *SyncRequesterImpl) handleErr(mesh mesh.MeshProvider, pubKey string, err error) error {
peerConnection, err := s.server.ConnectionManager.GetConnection(endPoint) ok := s.errorHdlr.Handle(mesh, pubKey, err)
if err != nil {
return err
}
client, err := peerConnection.GetClient()
if err != nil {
return err
}
c := rpc.NewSyncServiceClient(client)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
reply, err := c.GetConf(ctx, &rpc.GetConfRequest{MeshId: meshId})
if err != nil {
return err
}
err = s.server.MeshManager.AddMesh(&mesh.AddMeshParams{
MeshId: meshId,
WgPort: port,
MeshBytes: reply.Mesh,
})
return err
}
func (s *SyncRequesterImpl) handleErr(meshId, pubKey string, err error) error {
ok := s.errorHdlr.Handle(meshId, pubKey, err)
if ok { if ok {
return nil return nil
} }
return err return err
} }
// SyncMesh: Proactively send a sync request to the other mesh // SyncMesh: Proactively send a sync request to the other mesh
func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) error { func (s *SyncRequesterImpl) SyncMesh(mesh mesh.MeshProvider, meshNode mesh.MeshNode) error {
endpoint := meshNode.GetHostEndpoint() endpoint := meshNode.GetHostEndpoint()
pubKey, _ := meshNode.GetPublicKey() pubKey, _ := meshNode.GetPublicKey()
peerConnection, err := s.server.ConnectionManager.GetConnection(endpoint) peerConnection, err := s.connectionManager.GetConnection(endpoint)
if err != nil { if err != nil {
return err return err
@ -83,15 +51,9 @@ func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) erro
return err return err
} }
mesh := s.server.MeshManager.GetMesh(meshId)
if mesh == nil {
return errors.New("mesh does not exist")
}
c := rpc.NewSyncServiceClient(client) c := rpc.NewSyncServiceClient(client)
syncTimeOut := float64(s.server.Conf.SyncRate) * float64(time.Second) syncTimeOut := float64(s.configuration.SyncInterval) * float64(time.Second)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(syncTimeOut)) ctx, cancel := context.WithTimeout(context.Background(), time.Duration(syncTimeOut))
defer cancel() defer cancel()
@ -99,11 +61,11 @@ func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) erro
err = s.syncMesh(mesh, ctx, c) err = s.syncMesh(mesh, ctx, c)
if err != nil { if err != nil {
return s.handleErr(meshId, pubKey.String(), err) s.handleErr(mesh, pubKey.String(), err)
} }
logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId) logging.Log.WriteInfof("synced with node: %s meshId: %s\n", endpoint, mesh.GetMeshId())
return nil return err
} }
func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context, client rpc.SyncServiceClient) error { func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context, client rpc.SyncServiceClient) error {
@ -127,7 +89,7 @@ func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context
in, err := stream.Recv() in, err := stream.Recv()
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
logging.Log.WriteInfof("Stream recv error: %s\n", err.Error()) logging.Log.WriteInfof("stream recv error: %s\n", err.Error())
return err return err
} }
@ -136,7 +98,7 @@ func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context
} }
if err != nil { if err != nil {
logging.Log.WriteInfof("Syncer recv error: %s\n", err.Error()) logging.Log.WriteInfof("syncer recv error: %s\n", err.Error())
return err return err
} }
@ -150,7 +112,17 @@ func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context
return nil return nil
} }
func NewSyncRequester(s *ctrlserver.MeshCtrlServer) SyncRequester { type NewSyncRequesterParams struct {
errorHdlr := NewSyncErrorHandler(s.MeshManager) MeshManager mesh.MeshManager
return &SyncRequesterImpl{server: s, errorHdlr: errorHdlr} ConnectionManager conn.ConnectionManager
Configuration *conf.DaemonConfiguration
}
func NewSyncRequester(params NewSyncRequesterParams) SyncRequester {
errorHdlr := NewSyncErrorHandler(params.MeshManager, params.ConnectionManager)
return &SyncRequesterImpl{manager: params.MeshManager,
connectionManager: params.ConnectionManager,
configuration: params.Configuration,
errorHdlr: errorHdlr,
}
} }

View File

@ -1,18 +0,0 @@
package sync
import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
// Run implements SyncScheduler.
func syncFunction(syncer Syncer) lib.TimerFunc {
return func() error {
syncer.SyncMeshes()
return nil
}
}
func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester, syncer Syncer) *lib.Timer {
return lib.NewTimer(syncFunction(syncer), s.Conf.SyncRate)
}

View File

@ -6,19 +6,18 @@ import (
"errors" "errors"
"io" "io"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/rpc"
"github.com/tim-beatham/wgmesh/pkg/rpc"
) )
type SyncServiceImpl struct { type SyncServiceImpl struct {
rpc.UnimplementedSyncServiceServer rpc.UnimplementedSyncServiceServer
Server *ctrlserver.MeshCtrlServer MeshManager mesh.MeshManager
} }
// GetMesh: Gets a nodes local mesh configuration as a CRDT // GetMesh: Gets a nodes local mesh configuration as a CRDT
func (s *SyncServiceImpl) GetConf(context context.Context, request *rpc.GetConfRequest) (*rpc.GetConfReply, error) { func (s *SyncServiceImpl) GetConf(context context.Context, request *rpc.GetConfRequest) (*rpc.GetConfReply, error) {
mesh := s.Server.MeshManager.GetMesh(request.MeshId) mesh := s.MeshManager.GetMesh(request.MeshId)
if mesh == nil { if mesh == nil {
return nil, errors.New("mesh does not exist") return nil, errors.New("mesh does not exist")
@ -56,7 +55,7 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error
if len(meshId) == 0 { if len(meshId) == 0 {
meshId = in.MeshId meshId = in.MeshId
mesh := s.Server.MeshManager.GetMesh(meshId) mesh := s.MeshManager.GetMesh(meshId)
if mesh == nil { if mesh == nil {
return errors.New("mesh does not exist") return errors.New("mesh does not exist")
@ -92,7 +91,3 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error
} }
} }
} }
func NewSyncService(server *ctrlserver.MeshCtrlServer) *SyncServiceImpl {
return &SyncServiceImpl{Server: server}
}

View File

@ -1,15 +0,0 @@
package timer
import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
)
func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
timerFunc := func() error {
logging.Log.WriteInfof("Updated Timestamp")
return ctrlServer.MeshManager.UpdateTimeStamp()
}
return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime)
}

View File

@ -5,8 +5,8 @@ import (
"crypto/rand" "crypto/rand"
"fmt" "fmt"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )